|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
from io import BytesIO |
|
|
import base64 |
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint handler for the MiniCPM-V-2.6 (int4) vision-language model.

    Loads the tokenizer and model once at construction time, then answers
    image+question requests through ``predict`` / ``__call__``.
    """

    def __init__(self, model_dir=None):
        # model_dir is accepted for interface compatibility with the hosting
        # framework but is currently unused — the checkpoint id is hard-coded
        # in load_model(). NOTE(review): confirm whether model_dir should
        # override the default checkpoint.
        self.load_model()

    def load_model(self):
        """Load the tokenizer and model from the Hugging Face Hub.

        Raises:
            Exception: re-raises any ``from_pretrained`` failure after
                logging it, so the endpoint fails fast at startup.
        """
        model_name = "openbmb/MiniCPM-V-2_6-int4"
        print(f"Loading model: {model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True
            )
            self.model = AutoModel.from_pretrained(
                model_name, trust_remote_code=True, torch_dtype=torch.float16
            )
            # Inference only: disables dropout and similar training-mode layers.
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def load_image(self, image_base64):
        """Decode a base64-encoded image into an RGB PIL image.

        Args:
            image_base64: Base64 string of the raw image bytes.

        Returns:
            A ``PIL.Image.Image`` converted to RGB mode.

        Raises:
            ValueError: if the string cannot be decoded or opened as an image.
        """
        try:
            print("Decoding base64 image...")
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            print("Image loaded successfully.")
            return image
        except Exception as e:
            print(f"Failed to decode or open image: {e}")
            # Chain the original exception so callers see the root cause.
            raise ValueError(f"Failed to open image from base64 string: {e}") from e

    def predict(self, request):
        """Answer a question about a base64-encoded image.

        Expected request shape::

            {"inputs": {"image": <base64 str>,
                        "question": <str, optional>,
                        "stream": <bool, optional>}}

        Returns:
            ``{"output": <answer str>}`` on success, ``{"error": <msg>}``
            on missing input or prediction failure.
        """
        print(f"Received request: {request}")

        # Hoist the nested lookup once, and tolerate "inputs" being absent
        # or explicitly null instead of raising AttributeError on None.get().
        inputs = request.get("inputs") or {}
        image_base64 = inputs.get("image")
        question = inputs.get("question", "What is in the image?")
        stream = inputs.get("stream", False)

        if not image_base64:
            print("Missing image in the request.")
            return {"error": "Missing image."}

        try:
            image = self.load_image(image_base64)
            # MiniCPM-V chat format: the image travels inside the message
            # content list, so the top-level `image` argument stays None.
            msgs = [{"role": "user", "content": [image, question]}]

            print(f"Processing prediction with question: {question}")

            if stream:
                print("Starting stream prediction...")
                res = self.model.chat(
                    image=None,
                    msgs=msgs,
                    tokenizer=self.tokenizer,
                    sampling=True,
                    temperature=0.7,
                    stream=True,
                )
                # Accumulate streamed chunks into one response string
                # (join instead of repeated += concatenation).
                generated_text = "".join(res)
                print("Stream prediction completed.")
                return {"output": generated_text}
            else:
                output = self.model.chat(
                    image=None, msgs=msgs, tokenizer=self.tokenizer
                )
                print("Prediction completed.")
                return {"output": output}

        except Exception as e:
            print(f"Error during prediction: {e}")
            return {"error": str(e)}

    def __call__(self, data):
        """Entry point invoked by the serving framework; delegates to predict()."""
        return self.predict(data)
|
|
|
|
|
|
|
|
# Module-level singleton expected by the serving framework: it imports this
# object and calls handler(request) per request. Note that importing this
# module has a heavy side effect — the model is downloaded/loaded here.
handler = EndpointHandler()
|
|
|