from transformers import AutoModel, AutoTokenizer, AutoModelForImageTextToText
import base64
import os
import tempfile
from io import BytesIO
from typing import Dict, List, Any

import torch
from PIL import Image
from transformers import AutoProcessor
class EndpointHandler:
    """Inference endpoint handler for an image-text-to-text OCR model.

    Loads the model and processor once at startup, then accepts requests
    carrying a base64-encoded image in any of several shapes and returns
    the model's decoded text.
    """

    # Fallback instruction sent with the image when the request does not
    # supply one. NOTE(review): confirm this wording matches the prompt the
    # OCR model was tuned with.
    DEFAULT_PROMPT = "Read all the text in the image."

    def __init__(self, model_dir: str = "scb10x/typhoon-ocr1.5-2b"):
        """Load model weights and processor from *model_dir* (hub id or path)."""
        model_path = model_dir
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path, dtype="auto", device_map="auto"
        )
        # BUG FIX: original read `selfprocessor = ...` (missing dot), which
        # bound a throwaway local and left `self.processor` undefined, so
        # every __call__ crashed with AttributeError.
        self.processor = AutoProcessor.from_pretrained(model_path)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run OCR on a base64-encoded image.

        Accepted request shapes:
          * ``{"inputs": "<base64>"}``
          * ``{"inputs": {"base64": "<base64>", "prompt": "<optional>"}}``
          * ``{"base64": "<base64>"}``
          * a raw base64 string as the whole payload

        Returns the generated text on success; on failure returns the error
        as a string (or, for a missing image, a dict with an ``"error"`` key —
        kept for backward compatibility with existing callers).
        """
        try:
            base64_string = None
            prompt = self.DEFAULT_PROMPT
            # Case 1: base64 directly under "inputs"
            if "inputs" in data and isinstance(data["inputs"], str):
                base64_string = data["inputs"]
            # Case 2: base64 in a nested inputs dictionary
            elif "inputs" in data and isinstance(data["inputs"], dict):
                base64_string = data["inputs"].get("base64")
                # Allow the caller to override the instruction text.
                prompt = data["inputs"].get("prompt", prompt)
            # Case 3: direct base64 at root level
            elif "base64" in data:
                base64_string = data["base64"]
            # Case 4: the raw payload itself is a base64 string
            elif isinstance(data, str):
                base64_string = data
            if not base64_string:
                return {"error": "No base64 string found in input data. Available keys: " + str(data.keys())}
            print("Found base64 string, length:", len(base64_string))
            # Strip a data-URL prefix ("data:image/png;base64,....") if present.
            if "," in base64_string:
                base64_string = base64_string.split(",")[1]
            # BUG FIX: the processor expects a PIL image, not raw bytes —
            # decode and open the image (Image/BytesIO were imported but
            # unused in the original).
            image = Image.open(BytesIO(base64.b64decode(base64_string)))
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                        },
                        {
                            "type": "text",
                            # BUG FIX: `prompt` was referenced but never
                            # defined in the original (NameError on every call).
                            "text": prompt,
                        },
                    ],
                }
            ]
            # Tokenize the chat-formatted request for generation.
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)
            generated_ids = self.model.generate(**inputs, max_new_tokens=10000)
            # Drop the prompt tokens so only newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            print(output_text[0])
            return output_text[0]
        except Exception as e:
            # Broad catch is intentional at this service boundary: report the
            # failure to the client instead of crashing the endpoint worker.
            print(f"Error processing image: {e}")
            return str(e)