"""Gradio demo: extract card name/number from Pokémon card images.

Loads a Llama-3.2-11B-Vision model (4-bit base + LoRA adapter) once at
module import, then serves a single-image OCR UI on port 7860.
"""

import json

import gradio as gr
import torch
from PIL import Image  # noqa: F401  (re-exported type used by gr.Image(type="pil"))

from inference.model_loader import load_model_and_tokenizer

# Pick the inference device up front; generation on CPU is possible but very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")

# Load model + processor once at import so every request reuses them.
print("📥 Loading model from Hugging Face...")
model, processor = load_model_and_tokenizer(use_lora=True)
print("✅ Model loaded successfully!")


def extract_pokemon_card_info(image, max_tokens=256, temperature=0.7, top_p=0.9):
    """
    Extract card information from a Pokémon card image.

    Args:
        image: PIL Image (``None`` when the user has not uploaded anything).
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; values > 0 enable sampling.
        top_p: Nucleus sampling parameter.

    Returns:
        tuple[str, str]: (raw generated text, pretty-printed JSON string —
        or a short notice when no JSON could be parsed from the output).
    """
    try:
        if image is None:
            return "⚠️ Please upload an image first!", ""

        # Prepare instruction
        instruction = "You are an OCR expert specialized in Pokémon cards. Extract the card name and card number in JSON format."

        # Single-turn conversation: instruction text + the image placeholder.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image"},
                ],
            },
        ]

        # Apply chat template to get the model-specific prompt string.
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

        # Tokenize text + preprocess image, then move tensors to the device.
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

        # Generate output (no grad needed for inference).
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
            )

        # generate() returns prompt + completion; slice off the prompt tokens
        # so the raw output shows only what the model produced.
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        decoded_output = processor.decode(generated_tokens, skip_special_tokens=True)

        # Try to parse as JSON for better display.
        json_output = ""
        try:
            # Grab the outermost {...} span from the generated text.
            json_start = decoded_output.find('{')
            json_end = decoded_output.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = decoded_output[json_start:json_end]
                result_json = json.loads(json_str)
                json_output = json.dumps(result_json, indent=2, ensure_ascii=False)
        except Exception:
            # Malformed JSON is expected occasionally — report, don't fail.
            json_output = "Could not parse output as JSON"

        return decoded_output, json_output

    except Exception as e:
        import traceback
        error_msg = f"❌ Error during inference:\n{str(e)}\n\n{traceback.format_exc()}"
        return error_msg, ""


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Pokémon Card OCR Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎴 Pokémon Card OCR Demo")
    gr.Markdown("Extract card name and number from Pokémon card images using AI")
    gr.Markdown("**Models:** `unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit` + `netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA`")

    # Device info banner.
    if device == "cpu":
        gr.Markdown("⚠️ **Warning:** Running on CPU - Processing will be VERY slow. GPU strongly recommended for production use.")
    else:
        gr.Markdown(f"✅ **Using GPU:** {torch.cuda.get_device_name(0)}")

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📥 Input")
            image_input = gr.Image(type="pil", label="Upload Pokémon Card Image")

            # Generation settings
            gr.Markdown("### ⚙️ Settings")
            max_tokens = gr.Slider(minimum=64, maximum=512, value=256, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top P")

            # Extract button
            extract_btn = gr.Button("🔍 Extract Information", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output section
            gr.Markdown("### 🎯 Results")
            raw_output = gr.Textbox(label="Raw Output", lines=10, max_lines=20)
            json_output = gr.Code(label="Parsed JSON", language="json", lines=10)

    # Example images (fetched by URL when clicked).
    gr.Markdown("### 📋 Examples")
    gr.Examples(
        examples=[
            ["https://images.pokemontcg.io/base1/4.png"],
            ["https://images.pokemontcg.io/base1/16.png"],
            ["https://images.pokemontcg.io/xy1/1.png"],
        ],
        inputs=[image_input],
        label="Click to load example images"
    )

    # Event handlers
    extract_btn.click(
        fn=extract_pokemon_card_info,
        inputs=[image_input, max_tokens, temperature, top_p],
        outputs=[raw_output, json_output]
    )

    # Footer
    gr.Markdown("---")
    gr.Markdown("🎴 Pokémon Card OCR Demo | Powered by Llama-3.2-11B-Vision with LoRA fine-tuning")
    gr.Markdown("Base Model: [unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit](https://huggingface.co/unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit)")
    gr.Markdown("LoRA Adapter: [netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA](https://huggingface.co/netprtony/Llama-3.2-11B-Vision-PokemonCard-OCR-LoRA)")

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,       # Default Gradio port
        share=False,            # Set to True to create a public link
        show_error=True,
    )