import os
import tempfile
from contextlib import contextmanager, suppress
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image

from inference import Chat  # Chat class imported directly from the local inference.py
from robohusky.conversation import get_conv_template


@contextmanager
def _temp_media_path(suffix: str):
    """Yield the path of a unique, empty temporary file and delete it on exit.

    A unique name per call avoids the collisions (and leftover files) that a
    fixed name in the working directory would cause.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    # Only the path is needed; writers below reopen the file themselves.
    os.close(fd)
    try:
        yield path
    finally:
        with suppress(FileNotFoundError):
            os.remove(path)


class EndpointHandler:
    """Callable handler that answers a text query, optionally with an image or video.

    Each call resets the conversation and any visual feature, so requests do
    not share context with one another.
    """

    def __init__(self, path: str = "."):
        """Load the ``Chat`` model found at *path*.

        Args:
            path: Location passed to ``Chat`` as ``model_path``.
        """
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=path,
            device=self.device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False,
        )
        self.vision_feature = None
        self.modal_type = "text"
        self.conv = get_conv_template("husky").copy()

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Reset per-request state and compute the visual embedding, if any.

        Args:
            inputs: Request payload. ``"inputs"`` holds the text query
                (defaults to ``""``). An optional ``"image"`` or ``"video"``
                entry holds the raw bytes of the media; ``"image"`` takes
                precedence when both are present.

        Returns:
            A dict with the single key ``"query"``.
        """
        query = inputs.get("inputs", "")

        # Start every request from a fresh conversation with no visual input.
        self.conv = get_conv_template("husky").copy()
        self.vision_feature = None
        self.modal_type = "text"

        # NOTE(review): the temporary file is removed right after the embedding
        # call returns; this assumes the Chat embedding helpers read the file
        # eagerly -- confirm against inference.py.
        if "image" in inputs:
            image = Image.open(BytesIO(inputs["image"])).convert("RGB")
            with _temp_media_path(".jpg") as image_path:
                image.save(image_path)
                self.vision_feature = self.chat.get_image_embedding(image_path)
            self.modal_type = "image"
        elif "video" in inputs:
            with _temp_media_path(".mp4") as video_path:
                with open(video_path, "wb") as f:
                    f.write(inputs["video"])
                self.vision_feature = self.chat.get_video_embedding(video_path)
            self.modal_type = "video"

        return {"query": query}

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run one request through the model.

        Args:
            inputs: Request payload, see :meth:`preprocess`.

        Returns:
            A dict with the single key ``"output"`` holding the stripped
            model answer.
        """
        processed = self.preprocess(inputs)
        query = processed["query"]

        conversations = self.chat.ask(
            text=query, conv=self.conv, modal_type=self.modal_type
        )
        outputs = self.chat.answer(
            conversations, self.vision_feature, modal_type=self.modal_type
        )
        answer = outputs.strip()

        # Record the answer in the last message slot of the conversation.
        self.conv.messages[-1][1] = answer
        return {"output": answer}