""" HuggingFace Space fรผr LightOnOCR Fine-tuning & Testing """ import gradio as gr import spaces import torch from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor from peft import LoraConfig, get_peft_model from datasets import load_dataset import json import os from pathlib import Path # Global state model = None processor = None training_status = "Not started" @spaces.GPU def setup_training(): """Setup training environment""" global training_status training_status = "Setting up..." # Check GPU if torch.cuda.is_available(): gpu_name = torch.cuda.get_device_name(0) vram = torch.cuda.get_device_properties(0).total_memory / 1e9 info = f"โœ… GPU: {gpu_name}\n๐Ÿ’พ VRAM: {vram:.2f} GB" else: info = "โš ๏ธ No GPU detected!" training_status = "Ready" return info @spaces.GPU def start_training(epochs, batch_size, learning_rate): """Start LoRA fine-tuning""" global training_status, model, processor try: training_status = "Loading model..." yield f"๐Ÿ“ฆ Loading LightOnOCR-2-1B...\n" # Load model model = LightOnOcrForConditionalGeneration.from_pretrained( "lightonai/LightOnOCR-2-1B", torch_dtype=torch.bfloat16, device_map="auto" ) processor = LightOnOcrProcessor.from_pretrained("lightonai/LightOnOCR-2-1B") yield f"โœ… Model loaded\nโš™๏ธ Configuring LoRA...\n" # Debug: Check model structure yield f"๐Ÿ” Model type: {type(model).__name__}\n" vision_attrs = [attr for attr in dir(model) if 'vision' in attr.lower()] yield f"๐Ÿ” Vision attributes: {vision_attrs}\n" # LoRA config lora_config = LoraConfig( r=32, lora_alpha=64, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM" ) # Try to freeze vision encoder vision_frozen = False for attr_name in ['vision_model', 'vision_tower', 'vision_encoder', 'visual_model', 'vision']: if hasattr(model, attr_name): try: vision_model = getattr(model, attr_name) for param in vision_model.parameters(): param.requires_grad = False yield f"โ„๏ธ Frozen: {attr_name}\n" vision_frozen = True break except Exception as e: yield f"โš ๏ธ Could not freeze {attr_name}: {e}\n" if not vision_frozen: yield f"โš ๏ธ Vision model not frozen (will train everything)\n" yield f"๐Ÿ”ง Applying LoRA...\n" model = get_peft_model(model, lora_config) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) total = sum(p.numel() for p in model.parameters()) yield f"โœ… LoRA applied\n๐Ÿ“Š Trainable: {trainable:,} / {total:,} params\n" yield f"๐Ÿ“Š Percentage: {100*trainable/total:.2f}%\n" # Save model output_dir = "./lightonocr-chat-lora" yield f"๏ฟฝ Saving model to {output_dir}...\n" model.save_pretrained(output_dir) processor.save_pretrained(output_dir) yield f"โœ… Model saved!\n" yield f"๐Ÿš€ Training setup complete!\n" yield f"โš ๏ธ Full training loop not implemented yet\n" training_status = "Complete (Demo)" yield f"โœ… Model ready for testing!\n" except Exception as e: training_status = f"Error: {str(e)}" yield f"โŒ Error: {str(e)}\n" import traceback yield f"๐Ÿ“‹ Traceback:\n{traceback.format_exc()}\n" @spaces.GPU def test_model(prompt, image=None): """Test the model""" global model, processor # Try to load saved model if not in memory if model is None: try: from peft import PeftModel output_dir = "./lightonocr-chat-lora" if os.path.exists(output_dir): model = LightOnOcrForConditionalGeneration.from_pretrained( "lightonai/LightOnOCR-2-1B", torch_dtype=torch.bfloat16, device_map="auto" ) model = PeftModel.from_pretrained(model, output_dir) processor = LightOnOcrProcessor.from_pretrained(output_dir) else: return "โš ๏ธ Model not trained yet. Run training first!" except Exception as e: return f"โš ๏ธ Could not load model: {str(e)}" try: if image is not None: # OCR mode conversation = [{"role": "user", "content": [{"type": "image", "image": image}]}] inputs = processor.apply_chat_template( conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ) else: # Chat mode text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" inputs = processor.tokenizer(text, return_tensors="pt") inputs = {k: v.to(model.device) for k, v in inputs.items()} output_ids = model.generate(**inputs, max_new_tokens=256) if image is not None: generated_ids = output_ids[0, inputs["input_ids"].shape[1]:] output_text = processor.decode(generated_ids, skip_special_tokens=True) else: output_text = processor.tokenizer.decode(output_ids[0], skip_special_tokens=False) if "<|im_start|>assistant" in output_text: output_text = output_text.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip() return output_text except Exception as e: return f"โŒ Error: {str(e)}" # Gradio Interface with gr.Blocks(title="LightOnOCR Fine-tuning Lab") as demo: gr.Markdown(""" # ๐Ÿ”ฌ LightOnOCR-2-1B Fine-tuning Lab Reaktiviere Tool-Calling Fรคhigkeiten + Chat Integration """) with gr.Tab("๐Ÿš€ Training"): gr.Markdown("### Setup & Training") setup_btn = gr.Button("๐Ÿ”ง Check GPU", variant="primary") setup_output = gr.Textbox(label="System Info", lines=3) gr.Markdown("### Training Configuration") with gr.Row(): epochs = gr.Slider(1, 10, value=5, step=1, label="Epochs") batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size") learning_rate = gr.Slider(1e-5, 1e-3, value=3e-4, step=1e-5, label="Learning Rate") train_btn = gr.Button("๐Ÿš€ Start Training", variant="primary") train_output = gr.Textbox(label="Training Log", lines=15) setup_btn.click(setup_training, outputs=setup_output) train_btn.click( start_training, inputs=[epochs, batch_size, learning_rate], outputs=train_output ) with gr.Tab("๐Ÿ’ฌ Test Chat"): gr.Markdown("### Test Conversational Abilities") chat_input = gr.Textbox(label="Your Message", placeholder="Hallo! Wie geht es dir?") chat_btn = gr.Button("Send", variant="primary") chat_output = gr.Textbox(label="Response", lines=5) chat_btn.click(test_model, inputs=[chat_input], outputs=chat_output) gr.Examples( examples=[ ["Hallo!"], ["Wie geht es dir?"], ["Was kannst du?"], ["Was ist OCR?"], ], inputs=chat_input ) with gr.Tab("๐Ÿ–ผ๏ธ Test OCR"): gr.Markdown("### Test OCR Capabilities") image_input = gr.Image(type="pil", label="Upload Image") ocr_btn = gr.Button("Extract Text", variant="primary") ocr_output = gr.Textbox(label="Extracted Text", lines=5) ocr_btn.click(test_model, inputs=[gr.Textbox(visible=False, value=""), image_input], outputs=ocr_output) with gr.Tab("๐Ÿ“Š Info"): gr.Markdown(""" ## ๐ŸŽฏ Project Goal Reaktiviere Tool-Calling Fรคhigkeiten in LightOnOCR-2-1B und fรผge Chat-Integration hinzu. ## ๐Ÿ—๏ธ Architecture - **Base Model**: LightOnOCR-2-1B (1B params) - **Method**: LoRA Fine-tuning - **Frozen**: Vision Encoder (behalte OCR) - **Trainable**: Language Model (~50M params) ## ๐Ÿ”ง Special Tokens - `` / `` - Tool invocation - `` / `` - Tool results - `` / `` - Reasoning - `<|box_start|>` / `<|box_end|>` - Bounding boxes ## ๐Ÿ“ˆ Training Data - 34 Training Examples - 4 Validation Examples - Categories: Greetings, Facts, Math, Help ## ๐Ÿš€ Usage 1. Click "Check GPU" to verify H200 2. Configure training parameters 3. Click "Start Training" 4. Test in Chat/OCR tabs ## ๐Ÿ“ Notes - Training takes ~5-10 minutes on H200 - OCR capabilities are preserved - Tool-calling is learned from scratch """) if __name__ == "__main__": demo.launch()