import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer


class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for MiniCPM-V-2.6 (int4).

    Accepts a JSON payload of the form
        {"inputs": {"image": "<base64>", "question": "...", "stream": bool}}
    and returns {"output": <generated text>} on success or
    {"error": <message>} on failure.
    """

    def __init__(self, model_dir=None):
        # model_dir is part of the Inference Endpoints handler contract but is
        # unused here: the model is always pulled from the Hub by name.
        self.load_model()

    def load_model(self):
        """Load the tokenizer and model from the Hugging Face Hub.

        Raises:
            Exception: re-raised after logging if the download/load fails.
        """
        model_name = "openbmb/MiniCPM-V-2_6-int4"
        print(f"Loading model: {model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True
            )
            # trust_remote_code is required: MiniCPM-V ships custom modeling code.
            self.model = AutoModel.from_pretrained(
                model_name, trust_remote_code=True, torch_dtype=torch.float16
            )
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def load_image(self, image_base64):
        """Decode a base64 string into an RGB PIL image.

        Args:
            image_base64: base64-encoded image bytes (any format PIL can open).

        Returns:
            A PIL.Image.Image converted to RGB mode.

        Raises:
            ValueError: if the string cannot be decoded or opened as an image.
        """
        try:
            print("Decoding base64 image...")
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            print("Image loaded successfully.")
            return image
        except Exception as e:
            print(f"Failed to decode or open image: {e}")
            raise ValueError(f"Failed to open image from base64 string: {e}")

    def predict(self, request):
        """Run visual question answering on one request.

        Args:
            request: dict with an "inputs" dict holding "image" (base64,
                required), "question" (str, optional), "stream" (bool,
                optional).

        Returns:
            {"output": str} on success, {"error": str} on any failure.
        """
        print(f"Received request: {request}")
        # Parse the inputs dict once instead of re-fetching it per field.
        inputs = request.get("inputs", {})
        image_base64 = inputs.get("image")
        question = inputs.get("question", "What is in the image?")
        stream = inputs.get("stream", False)

        if not image_base64:
            print("Missing image in the request.")
            return {"error": "Missing image."}

        try:
            image = self.load_image(image_base64)
            # MiniCPM-V chat API: the image travels inside the message
            # content list, so the top-level `image` kwarg stays None.
            msgs = [{"role": "user", "content": [image, question]}]
            print(f"Processing prediction with question: {question}")

            if stream:
                print("Starting stream prediction...")
                res = self.model.chat(
                    image=None,
                    msgs=msgs,
                    tokenizer=self.tokenizer,
                    sampling=True,
                    temperature=0.7,
                    stream=True,
                )
                # The endpoint returns a single JSON body, so streamed chunks
                # are accumulated; join avoids quadratic string concatenation.
                generated_text = "".join(res)
                print("Stream prediction completed.")
                return {"output": generated_text}

            # NOTE(review): the non-streaming path uses the model's default
            # decoding, unlike the sampling/temperature settings above —
            # confirm this asymmetry is intended.
            output = self.model.chat(
                image=None, msgs=msgs, tokenizer=self.tokenizer
            )
            print("Prediction completed.")
            return {"output": output}
        except Exception as e:
            print(f"Error during prediction: {e}")
            return {"error": str(e)}

    def __call__(self, data):
        """Inference Endpoints entry point: delegate to predict()."""
        return self.predict(data)


# Instantiated at import time, as the Inference Endpoints runtime expects a
# module-level handler. This triggers the (slow) model download/load on startup.
handler = EndpointHandler()