# Hugging Face Space: netprtony — Pokémon Card OCR (Gradio demo)
# Status: Runtime error
# Commit 372a329: Add initial implementation of Pokémon Card OCR with Gradio interface and model loading
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from inference.model_loader import load_model_and_tokenizer | |
| import json | |
# Pick the compute device once at import time: CUDA when present, CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")

# Fetch the base model + LoRA adapter from Hugging Face (see model_loader).
# NOTE(review): this runs at module import, so the app blocks until the
# download/load completes — intentional for a Space, but slow on first boot.
print("📥 Loading model from Hugging Face...")
model, processor = load_model_and_tokenizer(use_lora=True)
print("✅ Model loaded successfully!")
def extract_pokemon_card_info(image, max_tokens=256, temperature=0.7, top_p=0.9):
    """Extract card information from a Pokémon card image.

    Args:
        image: PIL Image of the card, or None when nothing was uploaded.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; sampling is disabled when 0.
        top_p: Nucleus sampling parameter.

    Returns:
        tuple[str, str]: (raw decoded model output, pretty-printed JSON
        string — empty when no braces were found, or a fallback message
        when the braces did not contain valid JSON).
    """
    try:
        if image is None:
            return "⚠️ Please upload an image first!", ""

        # Instruction steering the VLM toward structured JSON output.
        instruction = (
            "You are an OCR expert specialized in Pokémon cards. "
            "Extract the card name and card number in JSON format."
        )

        # Single-turn chat with one text part and one image placeholder;
        # the processor substitutes the actual image tokens below.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                # Greedy decoding when temperature is 0. (The UI slider's
                # minimum is 0.1, so sampling is effectively always on there.)
                do_sample=temperature > 0,
            )

        # NOTE(review): outputs[0] contains the prompt tokens too, so the raw
        # output echoes the instruction; consider slicing off
        # inputs["input_ids"].shape[-1] tokens before decoding — verify.
        decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
        return decoded_output, _parse_json_output(decoded_output)
    except Exception as e:
        import traceback
        error_msg = f"❌ Error during inference:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, ""


def _parse_json_output(decoded_output):
    """Best-effort: pretty-print the outermost {...} span of *decoded_output*.

    Returns "" when no brace-delimited span exists (matches the original
    behavior of leaving the JSON pane empty), or a fallback message when
    the span is not valid JSON.
    """
    json_start = decoded_output.find('{')
    json_end = decoded_output.rfind('}') + 1
    if json_start < 0 or json_end <= json_start:
        return ""
    try:
        parsed = json.loads(decoded_output[json_start:json_end])
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return "Could not parse output as JSON"
    return json.dumps(parsed, indent=2, ensure_ascii=False)
# Create Gradio interface: two-column layout — inputs/settings on the left,
# raw + parsed outputs on the right — with clickable example card URLs.
with gr.Blocks(title="Pokémon Card OCR Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎴 Pokémon Card OCR Demo")
    gr.Markdown("Extract card name and number from Pokémon card images using AI")
    gr.Markdown("**Models:** `unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit` + `netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA`")

    # Device info banner — `device` is set once at import time above.
    if device == "cpu":
        gr.Markdown("⚠️ **Warning:** Running on CPU - Processing will be VERY slow. GPU strongly recommended for production use.")
    else:
        gr.Markdown(f"✅ **Using GPU:** {torch.cuda.get_device_name(0)}")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📥 Input")
            image_input = gr.Image(type="pil", label="Upload Pokémon Card Image")

            # Settings — these sliders feed extract_pokemon_card_info directly.
            # NOTE(review): temperature minimum is 0.1, so greedy decoding is
            # unreachable from the UI.
            gr.Markdown("### ⚙️ Settings")
            max_tokens = gr.Slider(minimum=64, maximum=512, value=256, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")

            # Extract button
            extract_btn = gr.Button("🔍 Extract Information", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output section: raw model text plus the parsed-JSON view.
            gr.Markdown("### 🎯 Results")
            raw_output = gr.Textbox(label="Raw Output", lines=10, max_lines=20)
            json_output = gr.Code(label="Parsed JSON", language="json", lines=10)

    # Example images (remote URLs; Gradio downloads them when clicked).
    gr.Markdown("### 📋 Examples")
    gr.Examples(
        examples=[
            ["https://images.pokemontcg.io/base1/4.png"],
            ["https://images.pokemontcg.io/base1/16.png"],
            ["https://images.pokemontcg.io/xy1/1.png"],
        ],
        inputs=[image_input],
        label="Click to load example images"
    )

    # Event handlers: button click runs inference and fills both output panes.
    extract_btn.click(
        fn=extract_pokemon_card_info,
        inputs=[image_input, max_tokens, temperature, top_p],
        outputs=[raw_output, json_output]
    )

    # Footer
    gr.Markdown("---")
    gr.Markdown("🎴 Pokémon Card OCR Demo | Powered by Llama-3.2-11B-Vision with LoRA fine-tuning")
    gr.Markdown("Base Model: [unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit](https://huggingface.co/unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit)")
    gr.Markdown("LoRA Adapter: [netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA](https://huggingface.co/netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA)")
# Launch the app only when run as a script (not when imported).
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # Allow external access
        "server_port": 7860,       # Default Gradio port
        "share": False,            # Set to True to create a public link
        "show_error": True,
    }
    demo.launch(**launch_options)