# minicpm-o-handler / handler.py
# Author: sreejith8100 — "Update handler.py" (commit 513ab80, verified)
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from io import BytesIO
import base64
class EndpointHandler:
    """Inference-endpoint handler for the MiniCPM-V-2.6 int4 vision-language model.

    Expects requests shaped like
    ``{"inputs": {"image": <base64 str>, "question": <str>, "stream": <bool>}}``
    and returns ``{"output": <str>}`` on success or ``{"error": <str>}`` on failure.
    """

    def __init__(self, model_dir=None):
        # model_dir is accepted for serving-framework compatibility but unused:
        # weights are always pulled from the hub by name in load_model().
        self.load_model()

    def load_model(self):
        """Load tokenizer and int4-quantized model from the Hugging Face hub.

        Raises:
            Exception: re-raises whatever ``from_pretrained`` fails with, after
                logging it, so the endpoint fails fast instead of serving a
                half-initialized handler.
        """
        model_name = "openbmb/MiniCPM-V-2_6-int4"
        print(f"Loading model: {model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)
            self.model.eval()  # inference mode: disables dropout, etc.
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def load_image(self, image_base64):
        """Decode a base64-encoded image payload into an RGB PIL image.

        Args:
            image_base64: base64 string of the raw image bytes.

        Returns:
            PIL.Image.Image in RGB mode.

        Raises:
            ValueError: if the payload cannot be decoded or opened as an image.
        """
        try:
            print("Decoding base64 image...")
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            print("Image loaded successfully.")
            return image
        except Exception as e:
            print(f"Failed to decode or open image: {e}")
            raise ValueError(f"Failed to open image from base64 string: {e}")

    def predict(self, request):
        """Answer a visual question about one base64-encoded image.

        Returns ``{"output": <answer>}`` or ``{"error": <message>}``; never raises.
        """
        print(f"Received request: {request}")
        # Hoisted: was looked up three times; `or {}` also guards an explicit
        # `"inputs": None` in the payload, which previously raised AttributeError.
        inputs = request.get("inputs", {}) or {}
        image_base64 = inputs.get("image")
        question = inputs.get("question", "What is in the image?")
        stream = inputs.get("stream", False)
        if not image_base64:
            print("Missing image in the request.")
            return {"error": "Missing image."}
        try:
            image = self.load_image(image_base64)
            msgs = [{"role": "user", "content": [image, question]}]
            print(f"Processing prediction with question: {question}")
            # Inference only — no_grad avoids building autograd graphs.
            with torch.no_grad():
                if stream:
                    print("Starting stream prediction...")
                    res = self.model.chat(
                        image=None,
                        msgs=msgs,
                        tokenizer=self.tokenizer,
                        sampling=True,
                        temperature=0.7,
                        stream=True,
                    )
                    # NOTE(review): chunks are concatenated and returned once;
                    # true token streaming would need a generator response.
                    # join() replaces the original quadratic `+=` loop.
                    generated_text = "".join(res)
                    print("Stream prediction completed.")
                    return {"output": generated_text}
                output = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer)
                print("Prediction completed.")
                return {"output": output}
        except Exception as e:
            print(f"Error during prediction: {e}")
            return {"error": str(e)}

    def __call__(self, data):
        """Serving frameworks invoke the handler as a callable; delegate to predict."""
        return self.predict(data)
# Example usage
# NOTE(review): instantiating here runs load_model() at import time (hub
# download + weight load) — presumably required by the serving framework's
# module-level-handler convention; verify against the deployment docs.
handler = EndpointHandler()