# embodied_explainer/handler.py
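"""Custom inference handler for the Husky embodied multimodal model.

Written in the Hugging Face Inference Endpoints ``EndpointHandler`` style:
the endpoint receives a JSON payload with "inputs" (the text prompt) and an
optional base64-encoded "image" or "video", and returns {"output": <text>}.
"""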
import base64
import os
import tempfile
from io import BytesIO
from typing import Any, Dict

import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import GenerationConfig, LlamaTokenizer

from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"
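# preprocess() rewrites the user-facing "<image>" / "<video>" placeholders into
# these start/end marker pairs; the Husky model is expected to inject the
# visual embeddings derived from `pixel_values` between them (an assumption
# about modeling_husky_embody2's internals).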
class EndpointHandler:
def __init__(self, model_path: str = "."):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
self.model = HuskyForConditionalGeneration.from_pretrained(
model_path, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
).to(self.device).eval()
        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=False,  # greedy decoding; temperature below has no effect unless do_sample=True
            temperature=0.7,
            max_new_tokens=10240,
        )
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
inputs = self.preprocess(data)
prediction = self.inference(inputs)
return self.postprocess(prediction)
def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
prompt = request["inputs"]
image_b64 = request.get("image", None)
video_b64 = request.get("video", None)
pixel_values = None
if image_b64:
image_bytes = base64.b64decode(image_b64)
pixel_values = self._load_image(image_bytes).unsqueeze(0) # [1, 3, 224, 224]
if self.device == "cuda":
pixel_values = pixel_values.half()
pixel_values = pixel_values.to(self.device)
prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
elif video_b64:
video_bytes = base64.b64decode(video_b64)
pixel_values = self._load_video(video_bytes)
if self.device == "cuda":
pixel_values = pixel_values.half()
pixel_values = pixel_values.to(self.device)
prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
return {
"prompt": prompt,
"pixel_values": pixel_values
}
def inference(self, inputs: Dict[str, Any]) -> str:
prompt = inputs["prompt"]
pixel_values = inputs["pixel_values"]
        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        model_inputs.pop("token_type_ids", None)  # LLaMA models don't accept token_type_ids
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
        # With visual input, generate through the full multimodal model;
        # otherwise fall through to the bare language model.
        if pixel_values is not None:
output = self.model.generate(
**model_inputs,
pixel_values=pixel_values,
generation_config=self.gen_config,
return_dict_in_generate=True,
output_scores=True
)
else:
output = self.model.language_model.generate(
**model_inputs,
generation_config=self.gen_config,
return_dict_in_generate=True,
output_scores=True
)
        # Debug logging: inspect raw vs. cleaned decoder output.
        generated_ids = output.sequences[0]
        print("Generated token ids:", generated_ids.tolist())
        raw_text = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        clean_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        print("Output with special tokens:", raw_text)
        print("Output without special tokens:", clean_text)
        return clean_text  # return the cleaned version
def postprocess(self, output: str) -> Dict[str, str]:
return {"output": output.strip()}
def _load_image(self, image_bytes: bytes) -> torch.Tensor:
image = Image.open(BytesIO(image_bytes)).convert('RGB')
        # Resize so the shorter side is 256 (224 / 0.875 crop ratio), then center-crop to 224.
        crop_pct = 224 / 256
        size = int(224 / crop_pct)  # = 256
transform = T.Compose([
T.Resize(size, interpolation=InterpolationMode.BICUBIC),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
return transform(image)
def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
        # decord needs a real file path, so spill the bytes to a temp file first.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            tmpfile.write(video_bytes)
            video_path = tmpfile.name
        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            total_frames = len(vr)
            indices = self.get_index(total_frames, num_segments)
            frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
        finally:
            os.remove(video_path)  # don't leak the temp file
transform = T.Compose([
T.Resize(224, interpolation=InterpolationMode.BICUBIC),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
processed = [transform(frame) for frame in frames] # each: [3, 224, 224]
video_tensor = torch.stack(processed, dim=0) # [T, 3, 224, 224]
video_tensor = video_tensor.permute(1, 0, 2, 3) # [3, T, 224, 224]
        video_tensor = video_tensor.unsqueeze(0)  # [1, 3, T, 224, 224]
return video_tensor
    def get_index(self, num_frames: int, num_segments: int):
        # Uniformly sample one frame per segment; if the clip has fewer frames
        # than segments, pad by repeating the last frame index.
        if num_frames < num_segments:
            return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
        interval = num_frames / num_segments
        return [int(interval * i) for i in range(num_segments)]
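# --- Local smoke test --------------------------------------------------------
# A minimal sketch of how the handler is invoked. The payload shape ("inputs"
# plus optional base64 "image"/"video") matches preprocess() above; the file
# path "example.jpg" and the prompt text are placeholders, not part of the
# deployed endpoint. Requires the model weights in the current directory.
if __name__ == "__main__":
    handler = EndpointHandler(".")
    with open("example.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")
    payload = {
        "inputs": "Describe the scene in this image: <image>",
        "image": image_b64,
    }
    print(handler(payload))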