# netprtony — initial implementation of Pokémon Card OCR with Gradio
# interface and model loading (commit 372a329).
import gradio as gr
from PIL import Image
import torch
from inference.model_loader import load_model_and_tokenizer
import json
# ── Environment & model bootstrap ──────────────────────────────────────
# Prefer a CUDA device when one is present; CPU works but is very slow
# for an 11B vision model.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🖥️ Using device: {device}")

# Download/load the base model plus the LoRA adapter once at import time
# so every Gradio request reuses the same weights.
print("📥 Loading model from Hugging Face...")
model, processor = load_model_and_tokenizer(use_lora=True)
print("✅ Model loaded successfully!")
def extract_pokemon_card_info(image, max_tokens=256, temperature=0.7, top_p=0.9):
    """Extract card information from a Pokémon card image.

    Args:
        image: PIL Image of the card; ``None`` yields a user-facing warning.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; 0 switches to greedy decoding.
        top_p: Nucleus sampling parameter.

    Returns:
        tuple[str, str]: (raw text generated by the model, pretty-printed
        JSON extracted from it — or a parse-failure message).
    """
    try:
        if image is None:
            return "⚠️ Please upload an image first!", ""

        # Single-turn conversation: instruction text plus one image slot.
        instruction = "You are an OCR expert specialized in Pokémon cards. Extract the card name and card number in JSON format."
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image"},
                ],
            },
        ]

        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,  # greedy when temperature == 0
            )

        # BUG FIX: outputs[0] echoes the prompt tokens before the answer.
        # Decoding the whole sequence put the instruction text into the
        # "Raw Output" box; slice past the input length to keep only what
        # the model actually generated.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded_output = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        json_output = _extract_json_block(decoded_output)
        return decoded_output, json_output
    except Exception as e:
        # Top-level boundary for the Gradio handler: report the failure in
        # the UI instead of crashing the request.
        import traceback
        error_msg = f"❌ Error during inference:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, ""


def _extract_json_block(text):
    """Return the first {...} span of *text* pretty-printed, or a failure note."""
    json_start = text.find('{')
    json_end = text.rfind('}') + 1
    if json_start < 0 or json_end <= json_start:
        return "Could not parse output as JSON"
    try:
        parsed = json.loads(text[json_start:json_end])
    except Exception as json_error:
        # BUG FIX: the original bound json_error but never used it,
        # discarding the reason the parse failed.
        return f"Could not parse output as JSON: {json_error}"
    return json.dumps(parsed, indent=2, ensure_ascii=False)
# Create Gradio interface.
# Layout: header markdown, a two-column row (inputs/settings on the left,
# results on the right), clickable example images, and a footer with model
# links. Component creation order defines the rendered layout.
with gr.Blocks(title="Pokémon Card OCR Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎴 Pokémon Card OCR Demo")
    gr.Markdown("Extract card name and number from Pokémon card images using AI")
    gr.Markdown("**Models:** `unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit` + `netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA`")
    # Device info — warn up front if inference will run on CPU.
    if device == "cpu":
        gr.Markdown("⚠️ **Warning:** Running on CPU - Processing will be VERY slow. GPU strongly recommended for production use.")
    else:
        gr.Markdown(f"✅ **Using GPU:** {torch.cuda.get_device_name(0)}")
    with gr.Row():
        with gr.Column(scale=1):
            # Input section: card image upload.
            gr.Markdown("### 📥 Input")
            image_input = gr.Image(type="pil", label="Upload Pokémon Card Image")
            # Settings: generation hyperparameters forwarded to
            # extract_pokemon_card_info (defaults match its signature).
            gr.Markdown("### ⚙️ Settings")
            max_tokens = gr.Slider(minimum=64, maximum=512, value=256, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")
            # Extract button — wired to the handler below.
            extract_btn = gr.Button("🔍 Extract Information", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output section: raw model text plus the parsed JSON view.
            gr.Markdown("### 🎯 Results")
            raw_output = gr.Textbox(label="Raw Output", lines=10, max_lines=20)
            json_output = gr.Code(label="Parsed JSON", language="json", lines=10)
    # Example images (remote URLs from the Pokémon TCG image CDN).
    gr.Markdown("### 📋 Examples")
    gr.Examples(
        examples=[
            ["https://images.pokemontcg.io/base1/4.png"],
            ["https://images.pokemontcg.io/base1/16.png"],
            ["https://images.pokemontcg.io/xy1/1.png"],
        ],
        inputs=[image_input],
        label="Click to load example images"
    )
    # Event handlers: button click runs inference and fills both outputs.
    extract_btn.click(
        fn=extract_pokemon_card_info,
        inputs=[image_input, max_tokens, temperature, top_p],
        outputs=[raw_output, json_output]
    )
    # Footer with links to the base model and the LoRA adapter.
    gr.Markdown("---")
    gr.Markdown("🎴 Pokémon Card OCR Demo | Powered by Llama-3.2-11B-Vision with LoRA fine-tuning")
    gr.Markdown("Base Model: [unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit](https://huggingface.co/unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit)")
    gr.Markdown("LoRA Adapter: [netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA](https://huggingface.co/netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA)")
# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # allow external access
        "server_port": 7860,       # default Gradio port
        "share": False,            # set to True to create a public link
        "show_error": True,
    }
    demo.launch(**launch_options)