|
|
import os
import tempfile
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image

from inference import Chat
from robohusky.conversation import get_conv_template
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler for the Husky multimodal chat model.

    Accepts a JSON-style payload with a text prompt under ``"inputs"`` and,
    optionally, raw bytes of an attachment under ``"image"`` or ``"video"``.
    Each call starts a fresh conversation (no multi-turn state is kept
    across requests). NOTE(review): the handler mutates instance state per
    request, so it assumes one request at a time per instance — confirm the
    serving stack does not call it concurrently.
    """

    def __init__(self, path: str = "."):
        """Load the model from *path* and initialize per-request state.

        Args:
            path: Directory containing the model weights (the serving
                framework passes the model repository root).
        """
        # Prefer GPU when one is visible to torch; otherwise fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=path,
            device=self.device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False,
        )
        # Per-request state; reset at the top of every preprocess() call.
        self.vision_feature = None  # embedding of the attached image/video, if any
        self.modal_type = "text"    # one of "text" | "image" | "video"
        self.conv = get_conv_template("husky").copy()

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the text query and embed any attached image or video.

        Args:
            inputs: Request payload. ``"inputs"`` holds the text prompt;
                ``"image"`` / ``"video"`` (mutually exclusive, image wins)
                hold raw attachment bytes.

        Returns:
            ``{"query": <text prompt>}``. Side effects: sets
            ``self.vision_feature``, ``self.modal_type`` and resets
            ``self.conv`` for this request.
        """
        query = inputs.get("inputs", "")
        # Fresh conversation and modality for every request.
        self.conv = get_conv_template("husky").copy()
        self.vision_feature = None
        self.modal_type = "text"

        if "image" in inputs:
            image = Image.open(BytesIO(inputs["image"])).convert("RGB")
            # The Chat embedding API takes a file path, so spill the image to
            # a unique temp file (a fixed name like "temp.jpg" races under
            # concurrent requests) and always clean it up.
            fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
            os.close(fd)  # Pillow reopens the path itself
            try:
                image.save(tmp_path)
                self.vision_feature = self.chat.get_image_embedding(tmp_path)
            finally:
                os.remove(tmp_path)
            self.modal_type = "image"

        elif "video" in inputs:
            fd, tmp_path = tempfile.mkstemp(suffix=".mp4")
            try:
                with os.fdopen(fd, "wb") as f:
                    f.write(inputs["video"])
                self.vision_feature = self.chat.get_video_embedding(tmp_path)
            finally:
                os.remove(tmp_path)
            self.modal_type = "video"

        return {"query": query}

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run one chat turn and return the model's reply.

        Args:
            inputs: Request payload — see :meth:`preprocess`.

        Returns:
            ``{"output": <stripped model response>}``.
        """
        processed = self.preprocess(inputs)
        query = processed["query"]

        conversations = self.chat.ask(text=query, conv=self.conv, modal_type=self.modal_type)
        outputs = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)

        # Record the reply in the conversation template and return it.
        reply = outputs.strip()
        self.conv.messages[-1][1] = reply
        return {"output": reply}
|
|
|