embodied_explainer / handler.py
ccclemenfff's picture
ggg
4f22f1b
raw
history blame
5.32 kB
import os
import torch
import base64
from PIL import Image
from io import BytesIO
from typing import Dict, Any
from transformers import LlamaTokenizer, GenerationConfig
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import tempfile
# Start/end marker pairs that are spliced into the prompt in place of the
# "<image>" and "<video>" placeholders when visual input is supplied.
DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN = "<img>", "</img>"
DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN = "<vid>", "</vid>"
class EndpointHandler:
    """Request handler that runs a Husky conditional-generation model.

    A request is a dict with a text prompt under ``"inputs"`` and, optionally,
    a base64-encoded ``"image"`` or ``"video"``.  Calling the handler runs
    ``preprocess`` -> ``inference`` -> ``postprocess`` and returns
    ``{"output": <generated text>}``.
    """

    # Side length (pixels) of the square centre crop applied to every frame.
    _CROP_SIZE = 224
    # Shorter-side resize applied to still images before cropping.  The
    # original code derived it as int(224 / (224 / 256)), which is 256.
    _IMAGE_RESIZE = 256
    # Shorter-side resize applied to video frames before cropping.
    _VIDEO_RESIZE = 224
    # Per-channel normalization statistics (the standard ImageNet values).
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_path: str = "."):
        """Load tokenizer, model and generation settings from *model_path*.

        The model is placed on CUDA in float16 when a GPU is available,
        otherwise on CPU in float32, and switched to eval mode.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        self.model = HuskyForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device).eval()
        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Handle one request dict and return ``{"output": text}``."""
        inputs = self.preprocess(data)
        prediction = self.inference(inputs)
        return self.postprocess(prediction)

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the request into a prompt and optional pixel tensor.

        Args:
            request: Must contain ``"inputs"`` (the prompt).  May contain
                ``"image"`` or ``"video"`` as base64 strings; when both are
                present the image takes precedence and the video is ignored.

        Returns:
            ``{"prompt": str, "pixel_values": Tensor | None}``.  The tensor is
            already on the model device (and in half precision on CUDA).

        Raises:
            KeyError: If ``"inputs"`` is missing from *request*.
        """
        prompt = request["inputs"]
        image_b64 = request.get("image", None)
        video_b64 = request.get("video", None)
        pixel_values = None

        if image_b64:
            image_bytes = base64.b64decode(image_b64)
            # Add the batch dimension: [3, 224, 224] -> [1, 3, 224, 224].
            pixel_values = self._to_model_device(
                self._load_image(image_bytes).unsqueeze(0)
            )
            prompt = prompt.replace(
                "<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
            )
        elif video_b64:
            video_bytes = base64.b64decode(video_b64)
            pixel_values = self._to_model_device(self._load_video(video_bytes))
            prompt = prompt.replace(
                "<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
            )

        return {
            "prompt": prompt,
            "pixel_values": pixel_values,
        }

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Generate text for a preprocessed request.

        Uses the full multimodal model when ``pixel_values`` is present and
        the bare language model otherwise.

        NOTE(review): on the text-only path the decoded sequence presumably
        still contains the prompt tokens, as is usual for decoder-only
        ``generate`` -- confirm whether callers expect that.
        """
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]

        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

        # Only ``.sequences`` is consumed below, so per-step scores are not
        # requested; keeping them would hold one logits tensor per new token.
        if pixel_values is not None:
            output = self.model.generate(
                **model_inputs,
                pixel_values=pixel_values,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
            )
        else:
            output = self.model.language_model.generate(
                **model_inputs,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
            )

        return self.tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    def postprocess(self, output: str) -> Dict[str, str]:
        """Wrap the whitespace-stripped generated text in the response dict."""
        return {"output": output.strip()}

    def _to_model_device(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Cast to half precision on CUDA and move the tensor to the device."""
        if self.device == "cuda":
            pixel_values = pixel_values.half()
        return pixel_values.to(self.device)

    @classmethod
    def _build_transform(cls, resize_to: int) -> "T.Compose":
        """Return the resize / centre-crop / to-tensor / normalize pipeline."""
        return T.Compose([
            T.Resize(resize_to, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(cls._CROP_SIZE),
            T.ToTensor(),
            T.Normalize(mean=cls._NORM_MEAN, std=cls._NORM_STD),
        ])

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode encoded image bytes into a normalized [3, 224, 224] tensor."""
        image = Image.open(BytesIO(image_bytes)).convert('RGB')
        return self._build_transform(self._IMAGE_RESIZE)(image)

    def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
        """Sample frames from encoded video bytes.

        The bytes are written to a temporary ``.mp4`` file for the video
        reader; the file is removed again before returning, even on error.

        Args:
            video_bytes: Raw bytes of the video file.
            num_segments: Number of frames to sample (see ``get_index``).

        Returns:
            Tensor of shape [1, 3, num_segments, 224, 224].

        Raises:
            ValueError: If the video has no frames or *num_segments* < 1.
        """
        tmpfile = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        video_path = tmpfile.name
        try:
            with tmpfile:
                tmpfile.write(video_bytes)
            vr = VideoReader(video_path, ctx=cpu(0))
            indices = self.get_index(len(vr), num_segments)
            frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
            # Drop the reader so it no longer holds the file when we delete it.
            del vr
        finally:
            # Best-effort cleanup: a failed delete must not mask the real
            # error (or a successful decode).
            try:
                os.remove(video_path)
            except OSError:
                pass

        transform = self._build_transform(self._VIDEO_RESIZE)
        processed = [transform(frame) for frame in frames]  # each: [3, 224, 224]
        video_tensor = torch.stack(processed, dim=0)  # [T, 3, 224, 224]
        video_tensor = video_tensor.permute(1, 0, 2, 3)  # [3, T, 224, 224]
        return video_tensor.unsqueeze(0)  # [1, 3, T, 224, 224]

    def get_index(self, num_frames: int, num_segments: int):
        """Pick *num_segments* frame indices spread evenly over a video.

        When the video is shorter than *num_segments*, every frame is used
        once and the last frame is repeated to pad the result.

        Args:
            num_frames: Total number of frames available; must be >= 1.
            num_segments: Number of indices to return; must be >= 1.

        Returns:
            A list of *num_segments* ints in ``[0, num_frames)``.

        Raises:
            ValueError: If either argument is smaller than 1.
        """
        if num_frames < 1:
            raise ValueError("num_frames must be at least 1 (video has no frames)")
        if num_segments < 1:
            raise ValueError("num_segments must be at least 1")
        if num_frames < num_segments:
            return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
        # Integer floor of i * num_frames / num_segments; avoids float rounding.
        return [i * num_frames // num_segments for i in range(num_segments)]