File size: 5,848 Bytes
372a329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
from PIL import Image
import torch
from inference.model_loader import load_model_and_tokenizer
import json

# Check device: prefer CUDA when a GPU is visible, otherwise fall back to CPU.
# `device` is read later by extract_pokemon_card_info to place input tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️  Using device: {device}")

# Load model and its processor once at import time so every request reuses them.
# use_lora=True applies the fine-tuned LoRA adapter on top of the base model.
# NOTE(review): despite the helper's name, the second return value is used as a
# multimodal processor (apply_chat_template / image+text preprocessing) below.
print("📥 Loading model from Hugging Face...")
model, processor = load_model_and_tokenizer(use_lora=True)
print("✅ Model loaded successfully!")

def extract_pokemon_card_info(image, max_tokens=256, temperature=0.7, top_p=0.9):
    """
    Extract card information from a Pokémon card image.

    Args:
        image: PIL Image to run OCR on; None returns an upload warning.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        top_p: Nucleus sampling parameter.

    Returns:
        Tuple of (raw_output, json_output):
            raw_output: Decoded text of the newly generated tokens, or a
                warning/error message.
            json_output: Pretty-printed JSON extracted from the output, a
                parse-failure note, or "" when no JSON span was found.
    """
    try:
        if image is None:
            return "⚠️ Please upload an image first!", ""

        # Instruction the LoRA adapter was fine-tuned to follow.
        instruction = "You are an OCR expert specialized in Pokémon cards. Extract the card name and card number in JSON format."

        # Single-turn conversation: the text instruction plus an image slot
        # that the processor binds to `image` below.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image"},
                ],
            },
        ]

        # Render the conversation into the model's chat prompt format.
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

        # Tokenize the prompt, preprocess the image, and move tensors to device.
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

        # Inference only — no gradient tracking.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,  # greedy decoding at temperature 0
            )

        # Decode ONLY the newly generated tokens. `generate` returns the prompt
        # followed by the completion; slicing off the prompt keeps the chat
        # template's text (and any braces in it) out of both the raw output and
        # the JSON extraction below.
        prompt_len = inputs["input_ids"].shape[-1]
        decoded_output = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        # Best effort: pull the outermost {...} span and pretty-print it.
        json_output = ""
        try:
            json_start = decoded_output.find('{')
            json_end = decoded_output.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                result_json = json.loads(decoded_output[json_start:json_end])
                json_output = json.dumps(result_json, indent=2, ensure_ascii=False)
        except (json.JSONDecodeError, ValueError):
            json_output = "Could not parse output as JSON"

        return decoded_output, json_output

    except Exception as e:
        # UI boundary: surface the failure in the textbox instead of crashing.
        import traceback
        error_msg = f"❌ Error during inference:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, ""

# Create Gradio interface. Statement order inside the Blocks context defines
# the on-page layout top-to-bottom; `demo` is launched by the __main__ guard.
with gr.Blocks(title="Pokémon Card OCR Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎴 Pokémon Card OCR Demo")
    gr.Markdown("Extract card name and number from Pokémon card images using AI")
    gr.Markdown("**Models:** `unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit` + `netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA`")
    
    # Device info banner — warn loudly when an 11B vision model runs on CPU.
    if device == "cpu":
        gr.Markdown("⚠️ **Warning:** Running on CPU - Processing will be VERY slow. GPU strongly recommended for production use.")
    else:
        gr.Markdown(f"✅ **Using GPU:** {torch.cuda.get_device_name(0)}")
    
    # Two equal-width columns: inputs/settings on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📥 Input")
            image_input = gr.Image(type="pil", label="Upload Pokémon Card Image")
            
            # Settings — slider values map 1:1 onto the keyword defaults of
            # extract_pokemon_card_info (max_tokens / temperature / top_p).
            gr.Markdown("### ⚙️ Settings")
            max_tokens = gr.Slider(minimum=64, maximum=512, value=256, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")
            
            # Extract button
            extract_btn = gr.Button("🔍 Extract Information", variant="primary", size="lg")
        
        with gr.Column(scale=1):
            # Output section: raw model text plus the parsed-JSON view.
            gr.Markdown("### 🎯 Results")
            raw_output = gr.Textbox(label="Raw Output", lines=10, max_lines=20)
            json_output = gr.Code(label="Parsed JSON", language="json", lines=10)
    
    # Example images — URLs are fetched by Gradio and loaded into image_input.
    gr.Markdown("### 📋 Examples")
    gr.Examples(
        examples=[
            ["https://images.pokemontcg.io/base1/4.png"],
            ["https://images.pokemontcg.io/base1/16.png"],
            ["https://images.pokemontcg.io/xy1/1.png"],
        ],
        inputs=[image_input],
        label="Click to load example images"
    )
    
    # Event handlers: button click runs inference and fills both output panes.
    extract_btn.click(
        fn=extract_pokemon_card_info,
        inputs=[image_input, max_tokens, temperature, top_p],
        outputs=[raw_output, json_output]
    )
    
    # Footer
    gr.Markdown("---")
    gr.Markdown("🎴 Pokémon Card OCR Demo | Powered by Llama-3.2-11B-Vision with LoRA fine-tuning")
    gr.Markdown("Base Model: [unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit](https://huggingface.co/unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit)")
    gr.Markdown("LoRA Adapter: [netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA](https://huggingface.co/netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA)")

# Launch the app when run as a script (not when imported as a module).
if __name__ == "__main__":
    # Bind on all interfaces so the demo is reachable from outside the host.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,   # default Gradio port
        "share": False,        # flip to True for a temporary public gradio.live link
        "show_error": True,    # show tracebacks in the UI instead of a generic error
    }
    demo.launch(**launch_options)