from transformers import AutoModel, AutoTokenizer, AutoModelForImageTextToText, AutoProcessor
from typing import Dict, List, Any, Union
import torch
import base64
import io
from io import BytesIO
from PIL import Image
import os
import tempfile


class EndpointHandler:
    """Callable handler that turns a base64-encoded image into markdown text.

    The handler loads an image-text-to-text model and its processor once at
    construction time, then serves requests through ``__call__``.
    """

    # Instruction sent to the model alongside every image.
    PROMPT = "Return content as markdown"

    def __init__(self, model_dir='scb10x/typhoon-ocr1.5-2b', max_new_tokens: int = 10000):
        """Load the model and processor.

        Args:
            model_dir: Model identifier or local path passed to
                ``from_pretrained`` for both the model and the processor.
            max_new_tokens: Upper bound on generated tokens per request.
        """
        model_path = model_dir
        self.max_new_tokens = max_new_tokens
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_path)

    def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, str]]:
        """Run the model on the image contained in ``data``.

        Accepted payload shapes, checked in this order:
            1. ``data`` itself is a string (raw base64 or a data URL).
            2. ``data["inputs"]`` is a string.
            3. ``data["inputs"]`` is a dict with a ``"base64"`` entry.
            4. ``data["base64"]`` at the root level.

        Returns:
            The decoded model output for the image on success.
            ``{"error": ...}`` when no base64 payload is found.
            The exception message as a plain string when anything else
            fails (decoding, image parsing, inference).
        """
        try:
            base64_string = self._extract_base64(data)
            if not base64_string:
                # Only dicts have keys; describe other payloads by type so
                # building the message can never raise on its own.
                if isinstance(data, dict):
                    available = str(data.keys())
                else:
                    available = f"(payload is {type(data).__name__}, not a dict)"
                return {
                    "error": "No base64 string found in input data. Available keys: "
                    + available
                }

            print("Found base64 string, length:", len(base64_string))

            image = self._decode_image(base64_string)
            output_text = self._generate(image)

            print(output_text)
            return output_text
        except Exception as e:
            # Top-level boundary: report the failure to the caller as text
            # instead of letting the exception propagate.
            print(f"Error processing image: {e}")
            return str(e)

    @staticmethod
    def _extract_base64(data: Any) -> Any:
        """Return the base64 payload found in ``data``, or ``None``.

        The returned value is whatever the payload holds at the matching
        location; it is not type-checked here.
        """
        # A raw string must be tested first: ``"key" in some_str`` is a
        # substring test, and indexing a str with a key raises TypeError.
        if isinstance(data, str):
            return data
        if not isinstance(data, dict):
            return None

        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            return inputs.get("base64")
        return data.get("base64")

    @staticmethod
    def _decode_image(base64_string: str) -> Image.Image:
        """Decode a base64 string (optionally a data URL) into a PIL image."""
        # Drop a data URL prefix ("data:image/png;base64,") when present.
        _, separator, payload = base64_string.partition(",")
        if separator:
            base64_string = payload

        image_data = base64.b64decode(base64_string)
        return Image.open(io.BytesIO(image_data))

    def _generate(self, image: Image.Image) -> str:
        """Run generation for one image and return the decoded new tokens."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": self.PROMPT},
                ],
            }
        ]

        # Preparation for inference
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # Inference: generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)

        # Slice off the first len(in_ids) positions of every output sequence
        # so only the tokens after the prompt are decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]