import os
import tempfile
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image

from inference import Chat
from robohusky.conversation import get_conv_template
| |
|
class EndpointHandler:
    """Inference endpoint wrapper around a multimodal ``Chat`` model.

    Handles text-only, image, or video requests and returns the model's
    answer as ``{"output": str}``. Conversation state is reset on every
    request in :meth:`preprocess`, so each call is a fresh single-turn
    exchange.
    """

    def __init__(self, path: str = "."):
        """Load the model from ``path`` and initialize per-request state.

        NOTE(review): requires ``import torch`` at module level — the
        original file used ``torch`` here without importing it.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=path,
            device=self.device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False,
        )
        self.vision_feature = None  # embedding of the current image/video, if any
        self.modal_type = "text"    # one of "text" | "image" | "video"
        self.conv = get_conv_template("husky").copy()

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Reset conversation state and extract the query plus optional media.

        ``inputs`` may contain ``"inputs"`` (the text query) and at most one
        of ``"image"`` or ``"video"`` as raw bytes. Because the ``Chat``
        embedding API is path-based, media bytes are written to a unique
        temporary file (fixing the original's fixed ``temp.jpg``/``temp.mp4``
        names, which collide under concurrent requests) and the file is
        removed once the embedding has been computed.

        Returns a dict with the single key ``"query"``.
        """
        query = inputs.get("inputs", "")
        self.conv = get_conv_template("husky").copy()
        self.vision_feature = None
        self.modal_type = "text"

        if "image" in inputs:
            image = Image.open(BytesIO(inputs["image"])).convert("RGB")
            # Unique temp file so concurrent requests cannot clobber each
            # other's media; removed after the embedding is extracted.
            fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
            try:
                with os.fdopen(fd, "wb") as f:
                    image.save(f, format="JPEG")
                self.vision_feature = self.chat.get_image_embedding(tmp_path)
            finally:
                os.remove(tmp_path)
            self.modal_type = "image"

        elif "video" in inputs:
            fd, tmp_path = tempfile.mkstemp(suffix=".mp4")
            try:
                with os.fdopen(fd, "wb") as f:
                    f.write(inputs["video"])
                self.vision_feature = self.chat.get_video_embedding(tmp_path)
            finally:
                os.remove(tmp_path)
            self.modal_type = "video"

        return {"query": query}

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run one full request: preprocess, ask the model, return the answer."""
        processed = self.preprocess(inputs)
        query = processed["query"]

        conversations = self.chat.ask(text=query, conv=self.conv, modal_type=self.modal_type)
        outputs = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)
        answer = outputs.strip()
        # Record the model reply in the conversation template so the
        # transcript stays consistent (mirrors the original behavior).
        self.conv.messages[-1][1] = answer
        return {"output": answer}
| |
|