from transformers import AutoModel, AutoTokenizer, AutoModelForImageTextToText, AutoProcessor
from typing import Dict, List, Any
import torch
import base64
from io import BytesIO
from PIL import Image
import os
import tempfile

# Default OCR instruction sent to the model together with the image.
# NOTE(review): placeholder wording - confirm against the prompt recommended on the
# scb10x/typhoon-ocr1.5-2b model card.
DEFAULT_PROMPT = (
    "Extract all text from the image and return it as clean Markdown. "
    "Do not include any explanation or extra text."
)


class EndpointHandler:
    """Run OCR on a base64-encoded image with an image-text-to-text model.

    Accepted request payloads (checked in this order):
      1. a raw base64 string,
      2. ``{"inputs": "<base64>"}``,
      3. ``{"inputs": {"base64": "<base64>", "prompt": "..."}}``,
      4. ``{"base64": "<base64>", "prompt": "..."}``.
    A ``data:`` URL prefix (``data:image/png;base64,...``) is stripped if present.
    The optional ``prompt`` key overrides the handler's default instruction.
    """

    def __init__(self, model_dir='scb10x/typhoon-ocr1.5-2b', prompt=DEFAULT_PROMPT, max_new_tokens=10000):
        """Load the model and its processor.

        Args:
            model_dir: Hub id or local path of the model.
            prompt: Default text instruction sent alongside every image.
            max_new_tokens: Upper bound on generated tokens per request.
        """
        self.model = AutoModelForImageTextToText.from_pretrained(model_dir, dtype="auto", device_map="auto")
        # Was `selfprocessor = ...` (typo), which left `self.processor` undefined.
        self.processor = AutoProcessor.from_pretrained(model_dir)
        self.prompt = prompt
        self.max_new_tokens = max_new_tokens

    @staticmethod
    def _extract_base64(data):
        """Return the base64 payload from a request, or None if none is present."""
        # Check for a raw string first: `"inputs" in data` on a str is a substring
        # test, and indexing a str with "inputs" would raise TypeError.
        if isinstance(data, str):
            return data
        if not isinstance(data, dict):
            return None
        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            return inputs.get("base64")
        return data.get("base64")

    @staticmethod
    def _extract_prompt(data, default):
        """Return a per-request prompt override if one is supplied, else *default*."""
        if isinstance(data, dict):
            inputs = data.get("inputs")
            if isinstance(inputs, dict) and inputs.get("prompt"):
                return inputs["prompt"]
            if data.get("prompt"):
                return data["prompt"]
        return default

    @staticmethod
    def _decode_image(base64_string):
        """Decode a (possibly data-URL-prefixed) base64 string into an RGB PIL image."""
        # Strip a data-URL prefix such as "data:image/png;base64,".
        if ',' in base64_string:
            base64_string = base64_string.split(',', 1)[1]
        image_bytes = base64.b64decode(base64_string)
        # The processor expects an image object, not raw bytes. convert() forces the
        # lazy decode and normalises palette/RGBA/greyscale inputs to RGB.
        return Image.open(BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run OCR on the image contained in *data*.

        Returns:
            The decoded model output (str) on success. On a missing payload it
            returns a dict ``{"error": ...}``. On any other failure it returns
            the exception message as a str.
        """
        try:
            base64_string = self._extract_base64(data)
            if not base64_string:
                available = str(data.keys()) if isinstance(data, dict) else type(data).__name__
                return {"error": "No base64 string found in input data. Available keys: " + available}

            print("Found base64 string, length:", len(base64_string))
            image = self._decode_image(base64_string)
            prompt = self._extract_prompt(data, self.prompt)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            # Preparation for inference
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)

            # Inference: generate, then drop the echoed prompt tokens from each sequence.
            generated_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            print(output_text[0])
            return output_text[0]
        except Exception as e:
            # Endpoint boundary: report the failure to the caller instead of crashing the worker.
            print(f"Error processing image: {e}")
            return str(e)