netprtony committed on
Commit 372a329 · 1 Parent(s): a748d78

Add initial implementation of Pokémon Card OCR with Gradio interface and model loading

Files changed (3)
  1. app.py +152 -0
  2. inference/model_loader.py +86 -0
  3. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,152 @@
+ import gradio as gr
+ from PIL import Image
+ import torch
+ from inference.model_loader import load_model_and_tokenizer
+ import json
+
+ # Check device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"🖥️ Using device: {device}")
+
+ # Load model
+ print("📥 Loading model from Hugging Face...")
+ model, processor = load_model_and_tokenizer(use_lora=True)
+ print("✅ Model loaded successfully!")
+
+ def extract_pokemon_card_info(image, max_tokens=256, temperature=0.7, top_p=0.9):
+     """
+     Extract card information from Pokémon card image
+
+     Args:
+         image: PIL Image
+         max_tokens: Maximum number of tokens to generate
+         temperature: Sampling temperature
+         top_p: Nucleus sampling parameter
+
+     Returns:
+         raw_output: Raw text output from model
+         json_output: Parsed JSON output (if available)
+     """
+     try:
+         if image is None:
+             return "⚠️ Please upload an image first!", ""
+
+         # Prepare instruction
+         instruction = "You are an OCR expert specialized in Pokémon cards. Extract the card name and card number in JSON format."
+
+         # Prepare conversation
+         conversation = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": instruction},
+                     {"type": "image"},
+                 ],
+             },
+         ]
+
+         # Apply chat template
+         prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+         # Process inputs
+         inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
+
+         # Generate output
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True if temperature > 0 else False,
+             )
+
+         # Decode output
+         decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
+
+         # Try to parse as JSON for better display
+         json_output = ""
+         try:
+             # Extract JSON from output
+             json_start = decoded_output.find('{')
+             json_end = decoded_output.rfind('}') + 1
+             if json_start >= 0 and json_end > json_start:
+                 json_str = decoded_output[json_start:json_end]
+                 result_json = json.loads(json_str)
+                 json_output = json.dumps(result_json, indent=2, ensure_ascii=False)
+         except Exception as json_error:
+             json_output = "Could not parse output as JSON"
+
+         return decoded_output, json_output
+
+     except Exception as e:
+         import traceback
+         error_msg = f"❌ Error during inference:\n{str(e)}\n\n{traceback.format_exc()}"
+         return error_msg, ""
+
+ # Create Gradio interface
+ with gr.Blocks(title="Pokémon Card OCR Demo", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🎴 Pokémon Card OCR Demo")
+     gr.Markdown("Extract card name and number from Pokémon card images using AI")
+     gr.Markdown("**Models:** `unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit` + `netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA`")
+
+     # Device info
+     if device == "cpu":
+         gr.Markdown("⚠️ **Warning:** Running on CPU - Processing will be VERY slow. GPU strongly recommended for production use.")
+     else:
+         gr.Markdown(f"✅ **Using GPU:** {torch.cuda.get_device_name(0)}")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Input section
+             gr.Markdown("### 📥 Input")
+             image_input = gr.Image(type="pil", label="Upload Pokémon Card Image")
+
+             # Settings
+             gr.Markdown("### ⚙️ Settings")
+             max_tokens = gr.Slider(minimum=64, maximum=512, value=256, step=1, label="Max Tokens")
+             temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
+             top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")
+
+             # Extract button
+             extract_btn = gr.Button("🔍 Extract Information", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             # Output section
+             gr.Markdown("### 🎯 Results")
+             raw_output = gr.Textbox(label="Raw Output", lines=10, max_lines=20)
+             json_output = gr.Code(label="Parsed JSON", language="json", lines=10)
+
+     # Example images
+     gr.Markdown("### 📋 Examples")
+     gr.Examples(
+         examples=[
+             ["https://images.pokemontcg.io/base1/4.png"],
+             ["https://images.pokemontcg.io/base1/16.png"],
+             ["https://images.pokemontcg.io/xy1/1.png"],
+         ],
+         inputs=[image_input],
+         label="Click to load example images"
+     )
+
+     # Event handlers
+     extract_btn.click(
+         fn=extract_pokemon_card_info,
+         inputs=[image_input, max_tokens, temperature, top_p],
+         outputs=[raw_output, json_output]
+     )
+
+     # Footer
+     gr.Markdown("---")
+     gr.Markdown("🎴 Pokémon Card OCR Demo | Powered by Llama-3.2-11B-Vision with LoRA fine-tuning")
+     gr.Markdown("Base Model: [unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit](https://huggingface.co/unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit)")
+     gr.Markdown("LoRA Adapter: [netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA](https://huggingface.co/netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA)")
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",  # Allow external access
+         server_port=7860,       # Default Gradio port
+         share=False,            # Set to True to create a public link
+         show_error=True,
+     )
inference/model_loader.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import os
+ from transformers import MllamaForConditionalGeneration, AutoProcessor
+ from peft import PeftModel
+
+ # Use Hugging Face model IDs
+ BASE_MODEL_ID = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit"
+ LORA_MODEL_ID = "netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA"
+
+ def load_model_and_tokenizer(use_lora=True):
+     """
+     Load the base model and apply LoRA adapter for inference from Hugging Face
+
+     Args:
+         use_lora: Whether to load and apply LoRA adapter
+
+     Returns:
+         model: The fine-tuned model ready for inference
+         processor: The processor (tokenizer)
+     """
+     try:
+         # Check device availability
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         print(f"🖥️ Using device: {device}")
+
+         if device == "cpu":
+             print("⚠️ Warning: Running on CPU. This will be very slow. GPU strongly recommended.")
+
+         # Load base model from Hugging Face
+         print("📥 Loading base model from Hugging Face...")
+         print(f"📌 Model: {BASE_MODEL_ID}")
+
+         model = MllamaForConditionalGeneration.from_pretrained(
+             BASE_MODEL_ID,
+             torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+             device_map="auto" if device == "cuda" else "cpu",
+             trust_remote_code=True,
+         )
+
+         # Load processor (tokenizer) from Hugging Face
+         print("📥 Loading processor from Hugging Face...")
+         processor = AutoProcessor.from_pretrained(
+             BASE_MODEL_ID,
+             trust_remote_code=True,
+         )
+
+         # Load LoRA adapter from Hugging Face if requested
+         if use_lora:
+             print("📥 Loading LoRA adapter from Hugging Face...")
+             print(f"📌 LoRA Model: {LORA_MODEL_ID}")
+             try:
+                 model = PeftModel.from_pretrained(model, LORA_MODEL_ID)
+                 print("✅ LoRA adapter loaded successfully!")
+             except Exception as lora_error:
+                 print(f"⚠️ Warning: Could not load LoRA adapter: {str(lora_error)}")
+                 print("📌 Using base model without fine-tuning")
+         else:
+             print("📌 Using base model without LoRA adapter")
+
+         # Set to eval mode
+         model.eval()
+
+         print("✅ Model loaded successfully!")
+         return model, processor
+
+     except Exception as e:
+         print(f"❌ Error loading model: {str(e)}")
+         import traceback
+         traceback.print_exc()
+         raise
+
+ def prepare_inputs(image, processor, device):
+     """
+     Prepare inputs for the model from the image using the processor.
+
+     Args:
+         image: PIL Image
+         processor: The processor (tokenizer)
+         device: Device to move tensors to
+     Returns:
+         inputs: Prepared inputs for the model
+     """
+     inputs = processor(images=image, return_tensors="pt").to(device)
+     return inputs
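A quick way to sanity-check the loader outside the Gradio UI is a small standalone script along the lines below. This is only a sketch, not part of this commit: the file name smoke_test.py, the local image path card.png, and the fixed generation length are assumptions; it simply reuses load_model_and_tokenizer and mirrors the chat-template prompt flow already used in app.py.

# smoke_test.py — illustrative sketch, not included in this commit.
# Loads the (LoRA-adapted) model via inference/model_loader.py and runs the
# same chat-template + generate flow that app.py uses, on one local image.
import torch
from PIL import Image

from inference.model_loader import load_model_and_tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model, processor = load_model_and_tokenizer(use_lora=True)

image = Image.open("card.png")  # assumed local test image
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "You are an OCR expert specialized in Pokémon cards. "
                                     "Extract the card name and card number in JSON format."},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

with torch.no_grad():
    # Greedy decoding by default; app.py exposes temperature/top_p via sliders instead.
    outputs = model.generate(**inputs, max_new_tokens=256)

print(processor.decode(outputs[0], skip_special_tokens=True))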
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ transformers
+ Pillow
+ requests
+ tqdm
+ unsloth
+ datasets
+ trl
+ bitsandbytes
+ peft
+ accelerate
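To try the Space locally, installing the dependencies listed above and running the entry script should be enough (a CUDA GPU is assumed for reasonable latency): pip install -r requirements.txt, then python app.py. As configured in app.py, Gradio listens on 0.0.0.0:7860, and setting share=True in demo.launch creates a public link instead.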