# pokemon-card-ocr / inference / model_loader.py
# Author: netprtony
# Commit 372a329: Add initial implementation of Pokémon Card OCR with
# Gradio interface and model loading.
import torch
import os
from transformers import MllamaForConditionalGeneration, AutoProcessor
from peft import PeftModel
# Use Hugging Face model IDs
BASE_MODEL_ID = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit"
LORA_MODEL_ID = "netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA"
def load_model_and_tokenizer(use_lora=True):
    """
    Download the base vision model (and optionally its LoRA adapter) from
    Hugging Face and return them ready for inference.

    Args:
        use_lora: When True, attempt to layer the fine-tuned LoRA adapter
            on top of the base model; on failure, fall back to the base model.

    Returns:
        tuple: (model, processor) where `model` is in eval mode and
        `processor` is the matching AutoProcessor (tokenizer + image prep).

    Raises:
        Exception: Re-raises any error from base-model/processor loading
        after printing a traceback (LoRA failures are swallowed with a warning).
    """
    try:
        # Decide placement once; every branch below keys off CUDA availability.
        has_cuda = torch.cuda.is_available()
        device = "cuda" if has_cuda else "cpu"
        print(f"🖥️ Using device: {device}")
        if not has_cuda:
            print("⚠️ Warning: Running on CPU. This will be very slow. GPU strongly recommended.")

        # Base model: fp16 + auto device map on GPU, fp32 pinned to CPU otherwise.
        print("📥 Loading base model from Hugging Face...")
        print(f"📌 Model: {BASE_MODEL_ID}")
        dtype = torch.float16 if has_cuda else torch.float32
        placement = "auto" if has_cuda else "cpu"
        model = MllamaForConditionalGeneration.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=dtype,
            device_map=placement,
            trust_remote_code=True,
        )

        # Matching processor (handles both text tokenization and image prep).
        print("📥 Loading processor from Hugging Face...")
        processor = AutoProcessor.from_pretrained(
            BASE_MODEL_ID,
            trust_remote_code=True,
        )

        # Optionally stack the fine-tuned LoRA adapter; a failure here is
        # non-fatal — we degrade to the base model with a warning.
        if not use_lora:
            print("📌 Using base model without LoRA adapter")
        else:
            print("📥 Loading LoRA adapter from Hugging Face...")
            print(f"📌 LoRA Model: {LORA_MODEL_ID}")
            try:
                model = PeftModel.from_pretrained(model, LORA_MODEL_ID)
                print("✅ LoRA adapter loaded successfully!")
            except Exception as adapter_err:
                print(f"⚠️ Warning: Could not load LoRA adapter: {str(adapter_err)}")
                print("📌 Using base model without fine-tuning")

        # Inference only — disable dropout etc.
        model.eval()
        print("✅ Model loaded successfully!")
        return model, processor
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        import traceback
        traceback.print_exc()
        raise
def prepare_inputs(image, processor, device, text=None):
    """
    Prepare model inputs from an image (and optional text prompt) using the
    processor, moving all tensors to the target device.

    Args:
        image: PIL Image to run OCR on.
        processor: The processor (tokenizer) from load_model_and_tokenizer.
        device: Device to move tensors to (e.g. "cuda" or "cpu").
        text: Optional text prompt to pass alongside the image. Defaults to
            None, which preserves the original image-only call exactly.

    Returns:
        inputs: Processor output (a BatchFeature-like object) on `device`.
    """
    # NOTE(review): vision-instruct processors such as Mllama's generally
    # expect a text prompt containing the image placeholder token — callers
    # should supply `text`; the image-only path is kept for compatibility.
    if text is None:
        inputs = processor(images=image, return_tensors="pt")
    else:
        inputs = processor(images=image, text=text, return_tensors="pt")
    return inputs.to(device)