# (removed web-scrape artifacts: file-size banner, git-blame hash gutter, and line-number gutter)
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from io import BytesIO
import base64
class EndpointHandler:
    """Inference endpoint handler for the MiniCPM-V-2.6 (int4) vision-language model.

    Loads the tokenizer and model once at construction time, then answers
    image+question requests via ``predict`` (or ``__call__``), optionally
    streaming the generated tokens.
    """

    def __init__(self, model_dir=None):
        # model_dir is accepted for compatibility with hosting frameworks that
        # pass a local path, but the model is always pulled by name.
        self.load_model()

    def load_model(self):
        """Load tokenizer and model weights.

        Re-raises on failure so the endpoint fails fast instead of serving
        with a half-initialized state.
        """
        model_name = "openbmb/MiniCPM-V-2_6-int4"
        print(f"Loading model: {model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(
                model_name, trust_remote_code=True, torch_dtype=torch.float16
            )
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def load_image(self, image_base64):
        """Decode a base64-encoded image payload into an RGB PIL image.

        Args:
            image_base64: base64-encoded image bytes (any format PIL can open).

        Returns:
            PIL.Image.Image converted to RGB.

        Raises:
            ValueError: if the payload is not valid base64 or not a decodable image.
        """
        try:
            print("Decoding base64 image...")
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            print("Image loaded successfully.")
            return image
        except Exception as e:
            print(f"Failed to decode or open image: {e}")
            # Chain the original exception so the root cause is preserved.
            raise ValueError(f"Failed to open image from base64 string: {e}") from e

    def predict(self, request):
        """Run a single image-question inference.

        Args:
            request: dict with an "inputs" dict containing "image" (base64 string,
                required), "question" (str, optional) and "stream" (bool, optional).

        Returns:
            {"output": str} on success, {"error": str} on failure.
        """
        print(f"Received request: {request}")
        inputs = request.get("inputs", {})  # read the payload dict once
        image_base64 = inputs.get("image")
        question = inputs.get("question", "What is in the image?")
        stream = inputs.get("stream", False)
        if not image_base64:
            print("Missing image in the request.")
            return {"error": "Missing image."}
        try:
            image = self.load_image(image_base64)  # decode base64 and load image
            msgs = [{"role": "user", "content": [image, question]}]
            print(f"Processing prediction with question: {question}")
            if stream:
                print("Starting stream prediction...")
                res = self.model.chat(
                    image=None,
                    msgs=msgs,
                    tokenizer=self.tokenizer,
                    sampling=True,
                    temperature=0.7,
                    stream=True,
                )
                # Join the streamed chunks in one pass instead of quadratic +=.
                generated_text = "".join(res)
                print("Stream prediction completed.")
                return {"output": generated_text}
            output = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer)
            print("Prediction completed.")
            return {"output": output}
        except Exception as e:
            # Surface any inference/decoding failure as a structured error
            # payload rather than crashing the endpoint.
            print(f"Error during prediction: {e}")
            return {"error": str(e)}

    def __call__(self, data):
        """Entry point used by the serving framework; delegates to predict()."""
        return self.predict(data)
# Example usage
# NOTE(review): this runs at import time, so merely importing the module
# triggers the full model download/load in __init__ — presumably intended so
# the hosting framework finds a ready module-level `handler`; confirm the
# serving stack expects this side effect.
handler = EndpointHandler()