# embodied_explainer/handler.py
import os
import tempfile
from io import BytesIO
from typing import Dict, Any

import torch
from PIL import Image
from transformers import LlamaTokenizer, GenerationConfig
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from robohusky.video_transformers import (
    GroupNormalize, GroupScale, GroupCenterCrop,
    Stack, ToTorchFormatTensor, get_index,
)
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
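
# Special tokens that delimit image / video embeddings in the prompt;
# preprocess() substitutes them for the "<image>" / "<video>" placeholders
# in the incoming text.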
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"


class EndpointHandler:
    """Inference Endpoints handler for the Husky embodied vision-language model.

    Accepts a text prompt plus an optional image or video and returns the
    generated answer.
    """

    def __init__(self, model_path: str = "."):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        self.model = HuskyForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device).eval()
        # Sampling defaults shared by every request; bos_token_id=1 is the
        # LLaMA tokenizer's BOS id.
        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # Entry point invoked by the Inference Endpoints runtime; `data` is the
        # raw request payload.
        inputs = self.preprocess(data)
        prediction = self.inference(inputs)
        return self.postprocess(prediction)

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        prompt = request["inputs"]
        # "image" / "video" are expected as raw bytes (see _load_image / _load_video).
        image = request.get("image")
        video = request.get("video")
        if image:
            pixel_values = self._load_image(image).unsqueeze(0).to(self.device)
            prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
        elif video:
            pixel_values = self._load_video(video).unsqueeze(0).to(self.device)
            prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
        else:
            pixel_values = None
        return {
            "prompt": prompt,
            "pixel_values": pixel_values,
        }

    def inference(self, inputs: Dict[str, Any]) -> str:
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]
        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        # LlamaTokenizer can emit token_type_ids, which generate() does not accept.
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
        if pixel_values is not None:
            # Multimodal request: generate conditioned on the visual features.
            output = self.model.generate(
                **model_inputs,
                pixel_values=pixel_values,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
                output_scores=True,
            )
        else:
            # Text-only request: call the wrapped language model directly,
            # bypassing the vision encoder.
            output = self.model.language_model.generate(
                **model_inputs,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
                output_scores=True,
            )
        return self.tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    def postprocess(self, output: str) -> Dict[str, str]:
        return {"output": output.strip()}

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        # ImageNet-style eval preprocessing: resize the short side to
        # 256 (= 224 / crop_pct), then center-crop to 224.
        crop_pct = 224 / 256
        size = int(224 / crop_pct)
        transform = T.Compose([
            T.Resize(size, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        return transform(image)

    def _load_video(self, video_bytes: bytes, num_segments: int = 8) -> torch.Tensor:
        # Use a unique temp file so concurrent requests do not clobber each other.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
            f.write(video_bytes)
        try:
            vr = VideoReader(f.name, ctx=cpu(0))
            frame_indices = get_index(len(vr), num_segments)  # sample frames evenly
            frames = [Image.fromarray(vr[idx].asnumpy()) for idx in frame_indices]
        finally:
            os.remove(f.name)
        transform = T.Compose([
            GroupScale(224),
            GroupCenterCrop(224),
            Stack(),
            ToTorchFormatTensor(),
            # CLIP normalization statistics.
            GroupNormalize([0.48145466, 0.4578275, 0.40821073],
                           [0.26862954, 0.26130258, 0.27577711]),
        ])
        return transform(frames)
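

if __name__ == "__main__":
    # Minimal local smoke test, not used by the Inference Endpoints runtime.
    # Assumes the model weights sit in the current directory and that an
    # example image exists at the hypothetical path "example.jpg".
    handler = EndpointHandler(".")
    with open("example.jpg", "rb") as img_file:
        payload = {
            "inputs": "Describe what is happening in this scene. <image>",
            "image": img_file.read(),
        }
    print(handler(payload))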