"""
Custom handler for LightOnOCR-2-1B on HuggingFace Inference Endpoints.
Requires transformers >= 5.0.0

Deployment options:
  A) Fork lightonai/LightOnOCR-2-1B and add this file → uses model_dir
  B) New repo with just handler.py + requirements.txt → loads from Hub
"""
import base64
import io
import os
from typing import Any, Dict

import torch
from PIL import Image
from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor

MODEL_ID = "lightonai/LightOnOCR-2-1B"


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs: Any):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16 if device == "cuda" else torch.float32

        self.device = device
        self.dtype = dtype

        # Use model_dir if it contains model weights (fork), otherwise load from Hub
        config_path = os.path.join(model_dir, "config.json")
        source = model_dir if os.path.exists(config_path) else MODEL_ID

        self.model = LightOnOcrForConditionalGeneration.from_pretrained(
            source, dtype=dtype  # transformers >= 5 uses `dtype` (formerly `torch_dtype`)
        ).to(device)
        self.processor = LightOnOcrProcessor.from_pretrained(source)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs_data = data.get("inputs", data)

        # --- Handle image input ---
        image = None
        image_url = None

        try:
            if isinstance(inputs_data, str):
                # Direct base64 string
                image = Image.open(io.BytesIO(base64.b64decode(inputs_data))).convert("RGB")
            elif isinstance(inputs_data, dict):
                if "image" in inputs_data:
                    img_input = inputs_data["image"]
                    if isinstance(img_input, str) and img_input.startswith(("http://", "https://")):
                        image_url = img_input
                    else:
                        image = Image.open(io.BytesIO(base64.b64decode(img_input))).convert("RGB")
                elif "url" in inputs_data:
                    image_url = inputs_data["url"]
        except (ValueError, OSError) as exc:
            # binascii.Error (bad base64) subclasses ValueError; PIL raises OSError
            return {"error": f"Could not decode image: {exc}"}

        if image is None and image_url is None:
            return {"error": "No image provided. Send 'image' (base64 or URL) or 'url' in inputs."}

        # --- Build conversation ---
        prompt = inputs_data.get("prompt", None) if isinstance(inputs_data, dict) else None
        content = []
        if image_url:
            content.append({"type": "image", "url": image_url})
        elif image:
            content.append({"type": "image", "image": image})

        if prompt:
            content.append({"type": "text", "text": prompt})

        conversation = [{"role": "user", "content": content}]

        # --- Process & generate ---
        max_tokens = int(inputs_data.get("max_new_tokens", 4096)) if isinstance(inputs_data, dict) else 4096

        inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = {
            k: v.to(device=self.device, dtype=self.dtype) if v.is_floating_point() else v.to(self.device)
            for k, v in inputs.items()
        }

        with torch.inference_mode():
            output_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
        output_text = self.processor.decode(generated_ids, skip_special_tokens=True)

        return {"generated_text": output_text}