import base64
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration


class EndpointHandler:
    """Answer a text prompt about a single image using a Qwen2.5-VL model.

    NOTE(review): the ``__init__(path)`` / ``__call__(data)`` shape looks like a
    Hugging Face Inference Endpoints custom handler -- confirm against the
    deployment config.
    """

    def __init__(self, path=""):
        """Load the model and its processor.

        Args:
            path: Model directory or hub id passed to both ``from_pretrained`` calls.
        """
        # Reduced precision only when a CUDA device is present; CPU stays float32.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _open_bytes(raw):
        """Decode encoded image bytes into an RGB ``PIL.Image``."""
        return Image.open(BytesIO(raw)).convert("RGB")

    def _load_base64_image(self, b64_ref):
        """Decode a base64-encoded image, with or without a ``data:...,`` prefix.

        Raises:
            ValueError: if ``b64_ref`` is not a string, is a data URL without a
                ``,`` separator, or is not valid base64.
        """
        if not isinstance(b64_ref, str):
            raise ValueError("`inputs.image_base64` must be a string.")
        if b64_ref.startswith("data:"):
            # Data URLs look like "data:image/png;base64,<payload>".
            _, sep, b64_ref = b64_ref.partition(",")
            if not sep:
                raise ValueError("Malformed data URL: missing ',' before the base64 payload.")
        try:
            raw = base64.b64decode(b64_ref)
        except ValueError as err:  # binascii.Error is a subclass of ValueError
            raise ValueError("Image data is not valid base64.") from err
        return self._open_bytes(raw)

    def _load_image(self, image_ref):
        """Load an RGB image from an http(s) URL, a ``data:image`` URL, or a local path.

        Raises:
            ValueError: if ``image_ref`` is None.
            requests.HTTPError: if fetching a URL returns an error status.
        """
        if image_ref is None:
            raise ValueError("Missing image. Please provide `inputs.image_url` or `inputs.image_base64`.")
        if isinstance(image_ref, str) and image_ref.startswith(("http://", "https://")):
            # NOTE(review): fetches any caller-supplied URL; consider an allow-list
            # if this endpoint is publicly reachable (SSRF risk).
            resp = requests.get(image_ref, timeout=30)
            resp.raise_for_status()
            return self._open_bytes(resp.content)
        if isinstance(image_ref, str) and image_ref.startswith("data:image"):
            return self._load_base64_image(image_ref)
        # Otherwise treat the reference as a local file path.
        # NOTE(review): this lets request payloads open server-side paths.
        with Image.open(image_ref) as img:
            return img.convert("RGB")

    def __call__(self, data):
        """Run one image+prompt generation request.

        Args:
            data: Request dict whose ``"inputs"`` entry is an object with an
                optional ``prompt``, one of ``image_url`` / ``image_base64``, and
                an optional ``max_new_tokens`` (default 256).

        Returns:
            ``{"generated_text": str}`` containing only the newly generated text.

        Raises:
            ValueError: if ``inputs`` is not an object, no image is given, or
                the image data cannot be decoded.
        """
        payload = data.get("inputs", {}) or {}
        if not isinstance(payload, dict):
            raise ValueError("`inputs` must be an object with `image_url` or `image_base64`.")
        prompt = payload.get("prompt", "Please analyze this image and infer its location.")
        image_url = payload.get("image_url")
        image_base64 = payload.get("image_base64")
        max_new_tokens = int(payload.get("max_new_tokens", 256))

        if image_url:
            image = self._load_image(image_url)
        elif image_base64:
            # Raw base64 has no scheme/prefix, so it must not reach the
            # local-path fallback in _load_image.
            image = self._load_base64_image(image_base64)
        else:
            image = self._load_image(None)  # raises the "Missing image" ValueError

        # The {"type": "image"} entry is a placeholder; the pixels are passed to
        # the processor via `images=` below.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.processor(
            text=[text],
            images=[image],
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )

        # generate() returns prompt + completion; drop the prompt tokens so only
        # the model's answer is decoded.
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]
        return {"generated_text": output_text}