| import base64 |
| from io import BytesIO |
|
|
| import requests |
| import torch |
| from PIL import Image |
| from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration |
|
|
|
|
class EndpointHandler:
    """Serve a Qwen2.5-VL vision-language model for single-image prompts.

    The model and processor are loaded once from ``path`` in ``__init__``.
    Each call takes one image plus a text prompt and returns the decoded
    model reply. The ``__init__(path)`` / ``__call__(data)`` shape presumably
    follows the Hugging Face Inference Endpoints custom-handler contract
    (confirm against the deployment).
    """

    _DEFAULT_PROMPT = "Please analyze this image and infer its location."
    _DEFAULT_MAX_NEW_TOKENS = 256
    _MISSING_IMAGE_MESSAGE = (
        "Missing image. Please provide `inputs.image_url` or `inputs.image_base64`."
    )

    def __init__(self, path=""):
        """Load the model and processor.

        Args:
            path: Location passed to ``from_pretrained`` for both the model and
                the processor.
        """
        # bfloat16 when a CUDA device is present (half the memory of float32);
        # float32 otherwise.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _decode_base64_payload(encoded):
        """Return the bytes encoded by a raw base64 string or a base64 ``data:`` URI.

        Raises:
            ValueError: if ``encoded`` is not a string, a data URI has no ','
                separator, or the payload has invalid base64 padding.
        """
        if not isinstance(encoded, str):
            raise ValueError("`image_base64` must be a string.")
        if encoded.startswith("data:"):
            # "data:<mime>;base64,<payload>" -> keep only the payload.
            _, separator, encoded = encoded.partition(",")
            if not separator:
                raise ValueError(
                    "Malformed data URI: missing ',' before the base64 payload."
                )
        # binascii.Error (raised on bad padding) is a ValueError subclass.
        return base64.b64decode(encoded)

    @staticmethod
    def _image_from_bytes(raw):
        """Decode encoded image bytes (any format PIL can identify) into RGB."""
        return Image.open(BytesIO(raw)).convert("RGB")

    @classmethod
    def _parse_max_new_tokens(cls, value):
        """Coerce the requested generation budget to a positive int.

        ``None`` (key absent or explicit null) selects the default of 256.

        Raises:
            ValueError: if ``value`` is not int-convertible or is below 1.
        """
        if value is None:
            return cls._DEFAULT_MAX_NEW_TOKENS
        try:
            tokens = int(value)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"`max_new_tokens` must be an integer, got {value!r}."
            ) from err
        if tokens < 1:
            raise ValueError(f"`max_new_tokens` must be >= 1, got {tokens}.")
        return tokens

    def _load_image(self, image_ref):
        """Open ``image_ref`` as an RGB PIL image.

        Accepted forms:
          * ``http://`` / ``https://`` URL: downloaded with ``requests``
            (30 s timeout).
          * ``data:`` URI: the base64 payload after the first ',' is decoded.
          * anything else: passed straight to ``Image.open`` (a local path or
            file-like object).

        Raises:
            ValueError: if ``image_ref`` is None or a data URI is malformed.
            requests exceptions propagate (e.g. ``HTTPError`` from
            ``raise_for_status`` on an error status).
        """
        if image_ref is None:
            raise ValueError(self._MISSING_IMAGE_MESSAGE)

        if isinstance(image_ref, str):
            if image_ref.startswith(("http://", "https://")):
                # NOTE(review): fetches an arbitrary caller-supplied URL with no
                # host allow-list or response-size cap (SSRF / memory risk);
                # confirm this is acceptable for the deployment.
                resp = requests.get(image_ref, timeout=30)
                resp.raise_for_status()
                return self._image_from_bytes(resp.content)
            if image_ref.startswith("data:"):
                return self._image_from_bytes(self._decode_base64_payload(image_ref))

        # NOTE(review): a string reaching this point is opened as a server-side
        # file path. When it comes from a request body, clients can read local
        # images; confirm this fallback is intended.
        return Image.open(image_ref).convert("RGB")

    def __call__(self, data):
        """Run one generation request.

        Args:
            data: Request mapping. ``data["inputs"]`` must be a dict (or be
                missing/falsy) with optional keys:
                  * ``prompt``: question text (defaults to a location-inference
                    prompt).
                  * ``image_url``: http(s) URL, data URI, or local path.
                  * ``image_base64``: raw base64 string or base64 data URI;
                    used only when ``image_url`` is empty.
                  * ``max_new_tokens``: generation budget, default 256.

        Returns:
            ``{"generated_text": str}`` holding only the newly generated text.

        Raises:
            ValueError: on a malformed request (non-dict ``inputs``, no image,
                malformed base64 / data URI, or invalid ``max_new_tokens``).
        """
        payload = data.get("inputs", {}) or {}
        if not isinstance(payload, dict):
            raise ValueError(
                "`inputs` must be an object with `image_url` or `image_base64`."
            )

        prompt = payload.get("prompt", self._DEFAULT_PROMPT)
        image_url = payload.get("image_url")
        image_base64 = payload.get("image_base64")
        max_new_tokens = self._parse_max_new_tokens(payload.get("max_new_tokens"))

        # image_url wins when both are supplied. Base64 is decoded directly so a
        # raw base64 string is never mistaken for a file path.
        if image_url:
            image = self._load_image(image_url)
        elif image_base64:
            image = self._image_from_bytes(self._decode_base64_payload(image_base64))
        else:
            raise ValueError(self._MISSING_IMAGE_MESSAGE)

        messages = [
            {
                "role": "user",
                "content": [
                    # Image slot in the chat template; the pixels themselves
                    # are handed to the processor via `images=` below.
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        model_inputs = self.processor(
            text=[text],
            images=[image],
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )

        # Drop the echoed prompt tokens so only the continuation is decoded.
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]

        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]

        return {"generated_text": output_text}