# handler.py — optimized inference-endpoint handler for the Husky multimodal model.
import os
import torch
import base64
from PIL import Image
from io import BytesIO
from typing import Dict, Any
from transformers import LlamaTokenizer, GenerationConfig
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import tempfile
# Sentinel tokens spliced into the prompt where image/video embeddings go;
# these must match the special tokens the Husky model was trained with.
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"
class EndpointHandler:
    """Inference-endpoint handler for the Husky multimodal model.

    Accepts a text prompt plus an optional base64-encoded image or video,
    converts the visual input into normalized pixel tensors, and generates
    a text reply with HuskyForConditionalGeneration.
    """

    def __init__(self, model_path: str = "."):
        """Load the tokenizer and model from *model_path*, on GPU if available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        # fp16 halves GPU memory; CPUs generally lack fp16 kernels, so use fp32 there.
        self.model = HuskyForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        ).to(self.device).eval()
        # NOTE(review): temperature has no effect while do_sample=False (greedy decoding).
        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=False,
            temperature=0.7,
            max_new_tokens=1024,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run preprocess -> inference -> postprocess for one request.

        Errors are reported inside the "output" field instead of propagating,
        so the endpoint always returns a well-formed JSON payload.
        """
        try:
            inputs = self.preprocess(data)
            prediction = self.inference(inputs)
            return self.postprocess(prediction)
        except Exception as e:
            return {"output": f"❌ 推理失败: {str(e)}"}

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Decode optional base64 media and rewrite the media placeholder in the prompt.

        Expected request keys: "inputs" (prompt text, required); optional
        "image" / "video" (base64 strings — image takes precedence when both
        are given) and "num_segments" (frames sampled from a video, default 16).
        """
        prompt = request["inputs"]
        image_b64 = request.get("image", None)
        video_b64 = request.get("video", None)
        num_segments = request.get("num_segments", 16)
        pixel_values = None
        if image_b64:
            image_bytes = base64.b64decode(image_b64)
            # unsqueeze(0) adds the batch dimension expected by the model.
            pixel_values = self._to_model_dtype(self._load_image(image_bytes).unsqueeze(0))
            prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
        elif video_b64:
            video_bytes = base64.b64decode(video_b64)
            pixel_values = self._to_model_dtype(self._load_video(video_bytes, num_segments))
            prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
        return {"prompt": prompt, "pixel_values": pixel_values}

    def _to_model_dtype(self, tensor: torch.Tensor) -> torch.Tensor:
        """Cast *tensor* to fp16 on GPU (matching the model dtype) and move it to the device."""
        if self.device == "cuda":
            tensor = tensor.half()
        return tensor.to(self.device)

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Tokenize the prompt and generate a reply, conditioning on pixel values."""
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]
        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        # The model's generate() does not accept token_type_ids; drop them if present.
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
        print("📌 prompt token长度:", model_inputs["input_ids"].shape[1])
        if pixel_values is not None:
            print("🎞️ pixel shape:", pixel_values.shape)
        # Defensive no_grad: inference must never build an autograd graph.
        with torch.no_grad():
            output = self.model.generate(
                **model_inputs,
                pixel_values=pixel_values,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
                output_scores=True,
            )
        generated_ids = output.sequences[0]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def postprocess(self, output: str) -> Dict[str, str]:
        """Wrap the generated text in the endpoint's response schema."""
        return {"output": output.strip()}

    def _build_transform(self, resize: int):
        """Resize -> center-crop to 224 -> tensor -> ImageNet normalization."""
        return T.Compose([
            T.Resize(resize, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode *image_bytes* into a normalized (C, 224, 224) float tensor."""
        image = Image.open(BytesIO(image_bytes)).convert('RGB')
        # The original computed int(224 / (224 / 256)) == 256; state it directly.
        return self._build_transform(256)(image)

    def _load_video(self, video_bytes: bytes, num_segments: int = 16) -> torch.Tensor:
        """Decode a video, sample *num_segments* frames uniformly, and return a
        (1, C, T, H, W) normalized float tensor.
        """
        # decord's VideoReader needs a real file path, so spill the bytes to disk.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            tmpfile.write(video_bytes)
            video_path = tmpfile.name
        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            indices = self.get_index(len(vr), num_segments)
            frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
        finally:
            # BUG FIX: the temporary file was created with delete=False and
            # never removed, leaking one file on disk per video request.
            os.remove(video_path)
        transform = self._build_transform(224)
        processed = [transform(frame) for frame in frames]
        # (T, C, H, W) -> (C, T, H, W), then add the batch dimension.
        return torch.stack(processed, dim=0).permute(1, 0, 2, 3).unsqueeze(0)

    def get_index(self, num_frames: int, num_segments: int):
        """Return *num_segments* frame indices spread uniformly over *num_frames*.

        If the video has fewer frames than *num_segments*, every frame is used
        once and the last frame index is repeated as padding.
        """
        if num_frames < num_segments:
            return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
        interval = num_frames / num_segments
        return [int(interval * i) for i in range(num_segments)]
|