"""Custom handler for LightOnOCR-2-1B on HuggingFace Inference Endpoints.

Requires transformers >= 5.0.0

Deployment options:
  A) Fork lightonai/LightOnOCR-2-1B and add this file → uses model_dir
  B) New repo with just handler.py + requirements.txt → loads from Hub
"""

import base64
import io
import os
from typing import Any, Dict, Optional, Tuple

import torch
from PIL import Image
from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor

MODEL_ID = "lightonai/LightOnOCR-2-1B"

# Generation budget used when the request does not supply "max_new_tokens".
DEFAULT_MAX_NEW_TOKENS = 4096

_URL_PREFIXES = ("http://", "https://")


def _decode_base64_image(payload: str) -> Image.Image:
    """Decode a base64 payload (bare or wrapped in a data URI) to an RGB image.

    Raises:
        ValueError: the payload is not valid base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
        OSError: the decoded bytes are not a readable image (Pillow's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    if payload[:5].lower() == "data:":
        # Drop a "data:image/png;base64," style header. b64decode() in its
        # default non-validating mode would silently discard ':' ';' ',' but
        # still decode the header's letters as data, corrupting the image.
        _, _, payload = payload.partition(",")
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw)).convert("RGB")


def _parse_image_reference(value: Any) -> Tuple[Optional[Image.Image], Optional[str]]:
    """Turn a request value into ``(image, None)`` or ``(None, url)``.

    Strings starting with ``http://`` or ``https://`` are returned as URLs;
    every other string is decoded as base64 image data.

    Raises:
        ValueError: ``value`` is not a string, or is not valid base64.
        OSError: the decoded bytes are not a readable image.
    """
    if not isinstance(value, str):
        raise ValueError(
            f"expected a base64 string or URL, got {type(value).__name__}"
        )
    value = value.strip()
    if value.startswith(_URL_PREFIXES):
        return None, value
    return _decode_base64_image(value), None


class EndpointHandler:
    """Inference Endpoints handler wrapping the LightOnOCR model and processor."""

    def __init__(self, model_dir: str, **kwargs: Any):
        """Load the model and processor onto the best available device.

        Args:
            model_dir: Repository directory supplied by the endpoint runtime.
                It is used as the weight source only when it contains a
                ``config.json``; otherwise ``MODEL_ID`` is loaded from the Hub.
            **kwargs: Accepted for interface compatibility; unused.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16 if device == "cuda" else torch.float32
        self.device = device
        self.dtype = dtype

        # Use model_dir if it contains model weights (fork), otherwise load from Hub
        config_path = os.path.join(model_dir, "config.json")
        source = model_dir if os.path.exists(config_path) else MODEL_ID

        self.model = LightOnOcrForConditionalGeneration.from_pretrained(
            source, torch_dtype=dtype
        ).to(device)
        self.processor = LightOnOcrProcessor.from_pretrained(source)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run OCR on one image and return the generated text.

        Args:
            data: Request payload. The value under ``"inputs"`` (or ``data``
                itself when that key is absent) may be:

                * a string: an http(s) URL or base64-encoded image data;
                * a dict with ``"image"`` (base64 or http(s) URL) or ``"url"``,
                  and optionally ``"prompt"`` (extra text appended to the user
                  turn) and ``"max_new_tokens"`` (defaults to 4096).

        Returns:
            ``{"generated_text": ...}`` on success, or ``{"error": ...}`` when
            the request carries no usable image or has invalid parameters.
        """
        inputs_data = data.get("inputs", data)
        options = inputs_data if isinstance(inputs_data, dict) else {}

        # --- Handle image input ---
        image = None
        image_url = None
        try:
            if isinstance(inputs_data, str):
                image, image_url = _parse_image_reference(inputs_data)
            elif options.get("image"):
                image, image_url = _parse_image_reference(options["image"])
            elif options.get("url"):
                url_value = options["url"]
                if not isinstance(url_value, str):
                    raise ValueError(
                        f"'url' must be a string, got {type(url_value).__name__}"
                    )
                image_url = url_value
        except (ValueError, OSError) as exc:
            # Report bad client input the same way as a missing image instead
            # of letting the exception escape the handler.
            return {"error": f"Could not read image: {exc}"}

        if image is None and image_url is None:
            return {"error": "No image provided. Send 'image' (base64 or URL) or 'url' in inputs."}

        # Validate the generation budget before doing any model work.
        try:
            max_tokens = int(options.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
        except (TypeError, ValueError):
            return {"error": "'max_new_tokens' must be an integer."}
        if max_tokens < 1:
            return {"error": "'max_new_tokens' must be a positive integer."}

        # --- Build conversation ---
        prompt = options.get("prompt")

        content = []
        if image_url:
            content.append({"type": "image", "url": image_url})
        elif image is not None:
            content.append({"type": "image", "image": image})

        if prompt:
            content.append({"type": "text", "text": prompt})

        conversation = [{"role": "user", "content": content}]

        # --- Process & generate ---
        inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        # Cast only floating-point tensors to the model dtype; integer tensors
        # are moved to the device with their dtype unchanged.
        inputs = {
            k: v.to(device=self.device, dtype=self.dtype)
            if v.is_floating_point()
            else v.to(self.device)
            for k, v in inputs.items()
        }

        output_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        # Drop the prompt tokens so only the newly generated continuation is decoded.
        generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
        output_text = self.processor.decode(generated_ids, skip_special_tokens=True)

        return {"generated_text": output_text}