"""
HuggingFace Space fรผr LightOnOCR Fine-tuning & Testing
"""
import gradio as gr
import spaces
import torch
from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import json
import os
from pathlib import Path
# Global state
model = None
processor = None
training_status = "Not started"
@spaces.GPU
def setup_training():
    """Check the execution environment and report GPU availability.

    Returns:
        str: Human-readable info about the detected CUDA device (name and
        total VRAM in GB), or a warning when no GPU is present.
    """
    global training_status
    training_status = "Setting up..."
    # Probe the first CUDA device; the Space is expected to run on a GPU.
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
        info = f"✅ GPU: {gpu_name}\n💾 VRAM: {vram:.2f} GB"
    else:
        info = "⚠️ No GPU detected!"
    training_status = "Ready"
    return info
@spaces.GPU
def start_training(epochs, batch_size, learning_rate):
    """Load LightOnOCR-2-1B, freeze the vision encoder, apply LoRA and save.

    NOTE: the actual optimization loop is not implemented yet; ``epochs``,
    ``batch_size`` and ``learning_rate`` are accepted for interface
    stability but currently unused.

    Yields:
        str: Incremental log lines streamed into the Gradio textbox.
    """
    global training_status, model, processor
    try:
        training_status = "Loading model..."
        yield "📦 Loading LightOnOCR-2-1B...\n"
        # Load the base model in bf16 and let HF place it on the GPU.
        model = LightOnOcrForConditionalGeneration.from_pretrained(
            "lightonai/LightOnOCR-2-1B",
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
        processor = LightOnOcrProcessor.from_pretrained("lightonai/LightOnOCR-2-1B")
        yield "✅ Model loaded\n⚙️ Configuring LoRA...\n"
        # Debug output: inspect the model to locate its vision encoder.
        yield f"🔍 Model type: {type(model).__name__}\n"
        vision_attrs = [attr for attr in dir(model) if 'vision' in attr.lower()]
        yield f"🔍 Vision attributes: {vision_attrs}\n"
        # LoRA on all attention + MLP projections of the language model.
        lora_config = LoraConfig(
            r=32,
            lora_alpha=64,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        # Freeze the vision encoder so OCR capabilities are preserved;
        # the attribute name varies between model implementations.
        vision_frozen = False
        for attr_name in ['vision_model', 'vision_tower', 'vision_encoder', 'visual_model', 'vision']:
            if hasattr(model, attr_name):
                try:
                    vision_model = getattr(model, attr_name)
                    for param in vision_model.parameters():
                        param.requires_grad = False
                    yield f"❄️ Frozen: {attr_name}\n"
                    vision_frozen = True
                    break
                except Exception as e:
                    yield f"⚠️ Could not freeze {attr_name}: {e}\n"
        if not vision_frozen:
            yield "⚠️ Vision model not frozen (will train everything)\n"
        yield "🔧 Applying LoRA...\n"
        model = get_peft_model(model, lora_config)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        yield f"✅ LoRA applied\n📊 Trainable: {trainable:,} / {total:,} params\n"
        yield f"📊 Percentage: {100*trainable/total:.2f}%\n"
        # Persist adapter + processor so the test tab can reload them later.
        output_dir = "./lightonocr-chat-lora"
        yield f"💾 Saving model to {output_dir}...\n"
        model.save_pretrained(output_dir)
        processor.save_pretrained(output_dir)
        yield "✅ Model saved!\n"
        yield "🎉 Training setup complete!\n"
        yield "⚠️ Full training loop not implemented yet\n"
        training_status = "Complete (Demo)"
        yield "✅ Model ready for testing!\n"
    except Exception as e:
        training_status = f"Error: {str(e)}"
        yield f"❌ Error: {str(e)}\n"
        import traceback
        yield f"📋 Traceback:\n{traceback.format_exc()}\n"
@spaces.GPU
def test_model(prompt, image=None):
    """Run the (fine-tuned) model in chat or OCR mode.

    Args:
        prompt: User text for chat mode (ignored when ``image`` is given).
        image: Optional PIL image; when provided, the model runs OCR on it.

    Returns:
        str: Generated text, or a warning/error message on failure.
    """
    global model, processor
    # Lazily reload the saved LoRA adapter if training ran in a previous
    # GPU session and the in-memory globals were lost.
    if model is None:
        try:
            from peft import PeftModel
            output_dir = "./lightonocr-chat-lora"
            if os.path.exists(output_dir):
                model = LightOnOcrForConditionalGeneration.from_pretrained(
                    "lightonai/LightOnOCR-2-1B",
                    torch_dtype=torch.bfloat16,
                    device_map="auto"
                )
                model = PeftModel.from_pretrained(model, output_dir)
                processor = LightOnOcrProcessor.from_pretrained(output_dir)
            else:
                return "⚠️ Model not trained yet. Run training first!"
        except Exception as e:
            return f"⚠️ Could not load model: {str(e)}"
    try:
        if image is not None:
            # OCR mode: feed the image through the chat template.
            conversation = [{"role": "user", "content": [{"type": "image", "image": image}]}]
            inputs = processor.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
        else:
            # Chat mode: build the ChatML prompt manually.
            text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
            inputs = processor.tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        output_ids = model.generate(**inputs, max_new_tokens=256)
        if image is not None:
            # Strip the prompt tokens; decode only the generated suffix.
            generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
            output_text = processor.decode(generated_ids, skip_special_tokens=True)
        else:
            # Keep special tokens so the assistant span can be located below.
            output_text = processor.tokenizer.decode(output_ids[0], skip_special_tokens=False)
        # In chat mode the full prompt echo is included; cut out just the
        # assistant turn. (OCR output has special tokens stripped, so the
        # marker is absent and this is a no-op.)
        if "<|im_start|>assistant" in output_text:
            output_text = output_text.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
        return output_text
    except Exception as e:
        return f"❌ Error: {str(e)}"
# Gradio interface: three functional tabs (training, chat test, OCR test)
# plus a static info tab.
with gr.Blocks(title="LightOnOCR Fine-tuning Lab") as demo:
    gr.Markdown("""
    # 🔬 LightOnOCR-2-1B Fine-tuning Lab
    Reaktiviere Tool-Calling Fähigkeiten + Chat Integration
    """)

    with gr.Tab("🚀 Training"):
        gr.Markdown("### Setup & Training")
        setup_btn = gr.Button("🔧 Check GPU", variant="primary")
        setup_output = gr.Textbox(label="System Info", lines=3)
        gr.Markdown("### Training Configuration")
        with gr.Row():
            epochs = gr.Slider(1, 10, value=5, step=1, label="Epochs")
            batch_size = gr.Slider(1, 16, value=8, step=1, label="Batch Size")
            learning_rate = gr.Slider(1e-5, 1e-3, value=3e-4, step=1e-5, label="Learning Rate")
        train_btn = gr.Button("🚀 Start Training", variant="primary")
        train_output = gr.Textbox(label="Training Log", lines=15)
        setup_btn.click(setup_training, outputs=setup_output)
        train_btn.click(
            start_training,
            inputs=[epochs, batch_size, learning_rate],
            outputs=train_output
        )

    with gr.Tab("💬 Test Chat"):
        gr.Markdown("### Test Conversational Abilities")
        chat_input = gr.Textbox(label="Your Message", placeholder="Hallo! Wie geht es dir?")
        chat_btn = gr.Button("Send", variant="primary")
        chat_output = gr.Textbox(label="Response", lines=5)
        chat_btn.click(test_model, inputs=[chat_input], outputs=chat_output)
        gr.Examples(
            examples=[
                ["Hallo!"],
                ["Wie geht es dir?"],
                ["Was kannst du?"],
                ["Was ist OCR?"],
            ],
            inputs=chat_input
        )

    with gr.Tab("🖼️ Test OCR"):
        gr.Markdown("### Test OCR Capabilities")
        image_input = gr.Image(type="pil", label="Upload Image")
        ocr_btn = gr.Button("Extract Text", variant="primary")
        ocr_output = gr.Textbox(label="Extracted Text", lines=5)
        # Hidden empty prompt so test_model's (prompt, image) signature is
        # satisfied; declared explicitly instead of inline in .click(), which
        # would otherwise insert an anonymous component into the layout here.
        ocr_prompt = gr.Textbox(visible=False, value="")
        ocr_btn.click(test_model, inputs=[ocr_prompt, image_input], outputs=ocr_output)

    with gr.Tab("📊 Info"):
        gr.Markdown("""
        ## 🎯 Project Goal
        Reaktiviere Tool-Calling Fähigkeiten in LightOnOCR-2-1B und füge Chat-Integration hinzu.

        ## 🏗️ Architecture
        - **Base Model**: LightOnOCR-2-1B (1B params)
        - **Method**: LoRA Fine-tuning
        - **Frozen**: Vision Encoder (behalte OCR)
        - **Trainable**: Language Model (~50M params)

        ## 🔧 Special Tokens
        <!-- NOTE(review): the first three token pairs were lost to markup
        stripping; names below are reconstructed from context. Verify against
        the tokenizer's special-tokens map. -->
        - `<tool_call>` / `</tool_call>` - Tool invocation
        - `<tool_result>` / `</tool_result>` - Tool results
        - `<think>` / `</think>` - Reasoning
        - `<|box_start|>` / `<|box_end|>` - Bounding boxes

        ## 📚 Training Data
        - 34 Training Examples
        - 4 Validation Examples
        - Categories: Greetings, Facts, Math, Help

        ## 🚀 Usage
        1. Click "Check GPU" to verify H200
        2. Configure training parameters
        3. Click "Start Training"
        4. Test in Chat/OCR tabs

        ## 📝 Notes
        - Training takes ~5-10 minutes on H200
        - OCR capabilities are preserved
        - Tool-calling is learned from scratch
        """)
if __name__ == "__main__":
demo.launch()