GeoAgent / handler.py
EugeneZhao's picture
Create handler.py
854eade verified
raw
history blame
2.74 kB
import base64
from io import BytesIO
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
class EndpointHandler:
    """Inference handler that answers a text prompt about a single image.

    The handler loads a Qwen2.5-VL model and its processor once at
    construction time and, on each call, builds a single-turn chat message
    (one image + one text prompt), generates a continuation and returns the
    decoded text.

    Expected request shape::

        {
            "inputs": {
                "prompt": "...",            # optional
                "image_url": "https://...", # or "image_base64"
                "image_base64": "...",      # data URI or raw base64
                "max_new_tokens": 256       # optional, positive integer
            }
        }
    """

    # Prompt used when the request does not provide `inputs.prompt`.
    DEFAULT_PROMPT = "Please analyze this image and infer its location."
    # Generation budget used when the request does not provide one.
    DEFAULT_MAX_NEW_TOKENS = 256
    # Timeout (seconds) for downloading an image from `inputs.image_url`.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self, path=""):
        """Load the model and processor from *path*.

        Args:
            path: Model directory or identifier passed to ``from_pretrained``.
        """
        # bfloat16 when CUDA is available, float32 otherwise.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _decode_raw_base64(text):
        """Return the bytes encoded by *text* if it is strict base64, else ``None``."""
        try:
            # validate=True rejects characters outside the base64 alphabet, so
            # ordinary file paths (containing ".", "\\", ":" ...) are not
            # mistaken for encoded data. binascii.Error subclasses ValueError.
            return base64.b64decode(text.strip(), validate=True)
        except ValueError:
            return None

    def _load_image(self, image_ref):
        """Resolve *image_ref* to an RGB ``PIL.Image``.

        Supported references, tried in this order:

        1. ``http://`` / ``https://`` URL - downloaded with ``requests``.
        2. ``data:image...,<base64>`` URI.
        3. Raw base64-encoded image bytes (no ``data:`` prefix).
        4. Anything else is handed to ``Image.open`` as a local path.

        Args:
            image_ref: The image reference taken from the request payload.

        Returns:
            The decoded image converted to RGB mode.

        Raises:
            ValueError: If *image_ref* is ``None`` or a malformed data URI.
            requests.HTTPError: If the download returns an error status.
            OSError: If the referenced data/file cannot be opened as an image.
        """
        if image_ref is None:
            raise ValueError("Missing image. Please provide `inputs.image_url` or `inputs.image_base64`.")

        if isinstance(image_ref, str):
            if image_ref.startswith(("http://", "https://")):
                # NOTE(security): the URL comes straight from the request, so
                # this fetch can reach any host visible to the server. Restrict
                # allowed hosts upstream if the endpoint is publicly exposed.
                resp = requests.get(image_ref, timeout=self.REQUEST_TIMEOUT_SECONDS)
                resp.raise_for_status()
                return Image.open(BytesIO(resp.content)).convert("RGB")

            if image_ref.startswith("data:image"):
                _, sep, b64data = image_ref.partition(",")
                if not sep:
                    raise ValueError(
                        "Malformed data URI: expected `data:image/<type>;base64,<data>`."
                    )
                return Image.open(BytesIO(base64.b64decode(b64data))).convert("RGB")

            # Raw base64 without a data-URI prefix. Previously such input was
            # treated as a file path and always failed.
            raw = self._decode_raw_base64(image_ref)
            if raw:
                try:
                    decoded = Image.open(BytesIO(raw))
                except OSError:
                    # Valid base64 alphabet but not image data (e.g. a short
                    # file name such as "abcd"); fall back to the path branch.
                    decoded = None
                if decoded is not None:
                    return decoded.convert("RGB")

        # Default: treat the reference as a local path.
        # NOTE(security): kept for backward compatibility; it lets a request
        # reference image files on the server's filesystem.
        with Image.open(image_ref) as img:
            # convert() returns a new, fully loaded image, so it stays valid
            # after the underlying file is closed.
            return img.convert("RGB")

    def __call__(self, data):
        """Run one image + prompt request through the model.

        Args:
            data: Request dict; the parameters are read from ``data["inputs"]``
                (see the class docstring for the accepted keys).

        Returns:
            ``{"generated_text": <str>}`` with the decoded model output.

        Raises:
            ValueError: If ``inputs`` is not an object, ``max_new_tokens`` is
                not a positive integer, or no image was provided.
        """
        payload = data.get("inputs", {}) or {}
        if not isinstance(payload, dict):
            raise ValueError(
                "`inputs` must be an object with `prompt` and `image_url` or `image_base64`."
            )

        prompt = payload.get("prompt", self.DEFAULT_PROMPT)
        image_url = payload.get("image_url")
        image_base64 = payload.get("image_base64")

        # Validate the generation budget before doing any expensive work
        # (image download, preprocessing, generation).
        try:
            max_new_tokens = int(payload.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS))
        except (TypeError, ValueError) as exc:
            raise ValueError("`inputs.max_new_tokens` must be an integer.") from exc
        if max_new_tokens < 1:
            raise ValueError("`inputs.max_new_tokens` must be a positive integer.")

        # `image_url` takes precedence when both references are supplied.
        image = self._load_image(image_url or image_base64)

        # Single user turn: an image placeholder followed by the text prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.processor(
            text=[text],
            images=[image],
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )

        # Drop the leading prompt-length slice of each output sequence so only
        # the newly generated tokens are decoded.
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]
        return {"generated_text": output_text}