# Gradio demo: Pokémon card OCR with Llama-3.2-11B-Vision + LoRA fine-tuning.
import gradio as gr
from PIL import Image
import torch
from inference.model_loader import load_model_and_tokenizer
import json
# Select the compute device: prefer a CUDA GPU, fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🖥️ Using device: {device}")

# Load the base model plus LoRA adapter and its processor from Hugging Face.
print("📥 Loading model from Hugging Face...")
model, processor = load_model_and_tokenizer(use_lora=True)
print("✅ Model loaded successfully!")
def parse_json_from_text(text):
    """Extract the first {...} span from *text* and pretty-print it as JSON.

    Args:
        text: Raw decoded model output that may contain a JSON object.

    Returns:
        The JSON object re-serialized with ``indent=2`` (non-ASCII preserved),
        ``""`` when no brace-delimited span is present, or the string
        ``"Could not parse output as JSON"`` when the span is not valid JSON.
    """
    json_start = text.find('{')
    json_end = text.rfind('}') + 1
    if json_start < 0 or json_end <= json_start:
        # No JSON-looking span at all — mirror the original's empty result.
        return ""
    try:
        result_json = json.loads(text[json_start:json_end])
        return json.dumps(result_json, indent=2, ensure_ascii=False)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return "Could not parse output as JSON"


def extract_pokemon_card_info(image, max_tokens=256, temperature=0.7, top_p=0.9):
    """
    Extract card information from a Pokémon card image.

    Args:
        image: PIL Image of the card; ``None`` yields a user-facing warning.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (0 falls back to greedy decoding).
        top_p: Nucleus sampling parameter.

    Returns:
        raw_output: Decoded text from the model (generated tokens only).
        json_output: Pretty-printed JSON parsed from the output, or a
            fallback/empty string when parsing is not possible.
    """
    try:
        if image is None:
            return "⚠️ Please upload an image first!", ""

        # Task instruction for the vision-language model.
        instruction = "You are an OCR expert specialized in Pokémon cards. Extract the card name and card number in JSON format."

        # Chat-style conversation with a single image placeholder.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image"},
                ],
            },
        ]

        # Render the chat template and prepare model inputs on the device.
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,  # greedy decode when temperature is 0
            )

        # Decode only the newly generated tokens: outputs[0] includes the
        # input prompt tokens, which would otherwise be echoed in raw_output
        # (and could confuse the JSON extraction below).
        prompt_len = inputs["input_ids"].shape[-1]
        decoded_output = processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        return decoded_output, parse_json_from_text(decoded_output)
    except Exception as e:
        import traceback
        error_msg = f"❌ Error during inference:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, ""
# Create the Gradio interface (indentation restored — the scraped source was
# flattened and not valid Python).
with gr.Blocks(title="Pokémon Card OCR Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎴 Pokémon Card OCR Demo")
    gr.Markdown("Extract card name and number from Pokémon card images using AI")
    gr.Markdown("**Models:** `unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit` + `netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA`")

    # Device info banner.
    if device == "cpu":
        gr.Markdown("⚠️ **Warning:** Running on CPU - Processing will be VERY slow. GPU strongly recommended for production use.")
    else:
        gr.Markdown(f"✅ **Using GPU:** {torch.cuda.get_device_name(0)}")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section: card image upload.
            gr.Markdown("### 📥 Input")
            image_input = gr.Image(type="pil", label="Upload Pokémon Card Image")

            # Generation settings exposed to the user.
            gr.Markdown("### ⚙️ Settings")
            max_tokens = gr.Slider(minimum=64, maximum=512, value=256, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")

            # Trigger button.
            extract_btn = gr.Button("🔍 Extract Information", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output section: raw model text and parsed JSON.
            gr.Markdown("### 🎯 Results")
            raw_output = gr.Textbox(label="Raw Output", lines=10, max_lines=20)
            json_output = gr.Code(label="Parsed JSON", language="json", lines=10)

    # Clickable example card images (fetched from pokemontcg.io).
    gr.Markdown("### 📋 Examples")
    gr.Examples(
        examples=[
            ["https://images.pokemontcg.io/base1/4.png"],
            ["https://images.pokemontcg.io/base1/16.png"],
            ["https://images.pokemontcg.io/xy1/1.png"],
        ],
        inputs=[image_input],
        label="Click to load example images"
    )

    # Wire the button to the inference function.
    extract_btn.click(
        fn=extract_pokemon_card_info,
        inputs=[image_input, max_tokens, temperature, top_p],
        outputs=[raw_output, json_output]
    )

    # Footer with model attribution links.
    gr.Markdown("---")
    gr.Markdown("🎴 Pokémon Card OCR Demo | Powered by Llama-3.2-11B-Vision with LoRA fine-tuning")
    gr.Markdown("Base Model: [unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit](https://huggingface.co/unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit)")
    gr.Markdown("LoRA Adapter: [netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA](https://huggingface.co/netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA)")
# Launch the app when run as a script (trailing scrape artifact "|" removed).
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # Bind all interfaces to allow external access
        server_port=7860,  # Default Gradio port
        share=False,  # Set to True to create a public link
        show_error=True,  # Surface Python errors in the UI
    )