|
|
import os |
|
|
import torch |
|
|
import base64 |
|
|
from PIL import Image |
|
|
from io import BytesIO |
|
|
from typing import Dict, Any |
|
|
from transformers import LlamaTokenizer, GenerationConfig |
|
|
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration |
|
|
from decord import VideoReader, cpu |
|
|
import torchvision.transforms as T |
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
import tempfile |
|
|
|
|
|
# Placeholder token pairs substituted into the prompt in place of the
# user-facing "<image>" / "<video>" markers (see EndpointHandler.preprocess).
DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN = "<img>", "</img>"
DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN = "<vid>", "</vid>"
|
|
|
|
|
class EndpointHandler:
    """Request handler that runs a Husky model on text, image or video input.

    A request is a dict with a required ``"inputs"`` prompt string and an
    optional base64-encoded ``"image"`` or ``"video"`` payload. Calling the
    handler runs ``preprocess`` -> ``inference`` -> ``postprocess`` and
    returns ``{"output": <generated text>}``.
    """

    def __init__(self, model_path: str = "."):
        """Load the tokenizer, model and generation settings.

        Args:
            model_path: Directory (or hub id) passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)

        # Half precision only on GPU; CPU inference stays in float32.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = (
            HuskyForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype)
            .to(self.device)
            .eval()
        )

        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run the full pipeline for one request and return ``{"output": text}``."""
        inputs = self.preprocess(data)
        prediction = self.inference(inputs)
        return self.postprocess(prediction)

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the request payload into a prompt and optional pixel tensor.

        If both ``"image"`` and ``"video"`` are supplied, the image takes
        precedence and the video is ignored.

        Args:
            request: Dict with key ``"inputs"`` (prompt) and optionally
                ``"image"`` or ``"video"`` holding base64-encoded bytes.

        Returns:
            Dict with ``"prompt"`` (markers replaced by the model's
            placeholder tokens) and ``"pixel_values"`` (tensor or ``None``).

        Raises:
            KeyError: If ``"inputs"`` is missing from the request.
        """
        prompt = request["inputs"]
        image_b64 = request.get("image", None)
        video_b64 = request.get("video", None)

        pixel_values = None

        if image_b64:
            image_bytes = base64.b64decode(image_b64)
            # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
            pixel_values = self._to_model_device(
                self._load_image(image_bytes).unsqueeze(0)
            )
            prompt = prompt.replace(
                "<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
            )
        elif video_b64:
            video_bytes = base64.b64decode(video_b64)
            pixel_values = self._to_model_device(self._load_video(video_bytes))
            prompt = prompt.replace(
                "<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
            )

        return {
            "prompt": prompt,
            "pixel_values": pixel_values,
        }

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Generate text for a preprocessed request.

        Args:
            inputs: Dict produced by ``preprocess`` with ``"prompt"`` and
                ``"pixel_values"``.

        Returns:
            The decoded generation with special tokens removed.
        """
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]

        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

        # Per-step scores are never consumed, so they are not requested;
        # keeping them would hold up to max_new_tokens logit tensors alive.
        if pixel_values is not None:
            output = self.model.generate(
                **model_inputs,
                pixel_values=pixel_values,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
            )
            generated = output.sequences[0]
        else:
            output = self.model.language_model.generate(
                **model_inputs,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
            )
            # A decoder-only model called with input_ids returns the prompt
            # tokens followed by the new tokens; drop the prompt so the
            # response does not echo the request.
            prompt_len = model_inputs["input_ids"].shape[1]
            generated = output.sequences[0][prompt_len:]

        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def postprocess(self, output: str) -> Dict[str, str]:
        """Wrap the generated text, stripped of surrounding whitespace."""
        return {"output": output.strip()}

    def _to_model_device(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Cast to half precision on CUDA and move the tensor to ``self.device``."""
        if self.device == "cuda":
            pixel_values = pixel_values.half()
        return pixel_values.to(self.device)

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode image bytes into a normalized ``(3, 224, 224)`` tensor.

        The image is converted to RGB, resized (shorter side) with bicubic
        interpolation, center-cropped to 224x224 and normalized with fixed
        per-channel mean/std constants.
        """
        image = Image.open(BytesIO(image_bytes)).convert('RGB')
        crop_pct = 224 / 256
        size = int(224 / crop_pct)  # resize target before the 224 crop (256)
        transform = T.Compose([
            T.Resize(size, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        return transform(image)

    def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
        """Decode video bytes into a ``(1, C, num_segments, 224, 224)`` tensor.

        The bytes are written to a temporary ``.mp4`` file because the video
        reader is opened from a path; the file is removed once the sampled
        frames have been read, whether or not decoding succeeds.

        Args:
            video_bytes: Raw encoded video.
            num_segments: Number of frames to sample across the video.

        Raises:
            ValueError: If the video contains no frames.
        """
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            tmpfile.write(video_bytes)
            video_path = tmpfile.name

        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            total_frames = len(vr)
            indices = self.get_index(total_frames, num_segments)
            frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
        finally:
            # delete=False means nothing else cleans this file up; without
            # this every video request would leave a file behind.
            try:
                os.remove(video_path)
            except OSError:
                # Best-effort cleanup: never mask a decoding error.
                pass

        transform = T.Compose([
            T.Resize(224, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        processed = [transform(frame) for frame in frames]

        video_tensor = torch.stack(processed, dim=0)      # (T, C, H, W)
        video_tensor = video_tensor.permute(1, 0, 2, 3)   # (C, T, H, W)
        video_tensor = video_tensor.unsqueeze(0)          # (1, C, T, H, W)
        return video_tensor

    def get_index(self, num_frames: int, num_segments: int):
        """Pick ``num_segments`` frame indices spread evenly over a video.

        When the video has fewer frames than requested, every frame is used
        and the last frame index is repeated to pad the list.

        Args:
            num_frames: Total number of frames available.
            num_segments: Number of indices to return.

        Returns:
            A list of ``num_segments`` integer frame indices.

        Raises:
            ValueError: If ``num_frames`` is not positive (there is no frame
                to sample or to pad with).
        """
        if num_frames <= 0:
            raise ValueError("video contains no frames")
        if num_frames < num_segments:
            return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
        interval = num_frames / num_segments
        return [int(interval * i) for i in range(num_segments)]
|
|
|