File size: 5,232 Bytes
4832cce
73085ab
 
6b701b8
73085ab
 
 
 
 
 
 
 
4563224
73085ab
 
 
 
 
cf932d8
ef8d70b
73085ab
 
 
 
 
 
 
 
 
f727a0f
73085ab
4832cce
73085ab
 
 
4832cce
 
 
 
 
 
73085ab
 
 
6b701b8
 
4832cce
73085ab
6b701b8
 
 
 
4832cce
 
d946509
73085ab
6b701b8
 
 
4832cce
 
d946509
73085ab
 
4832cce
73085ab
 
 
 
 
 
 
 
 
4832cce
73085ab
4832cce
 
 
 
 
 
 
 
 
 
971e468
 
4832cce
971e468
73085ab
 
 
 
 
4832cce
73085ab
 
 
 
 
 
 
 
4832cce
4563224
 
 
 
 
 
3b9ccb8
4563224
ef8d70b
73085ab
4563224
 
 
4832cce
73085ab
4832cce
 
4563224
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
### ✅ handler.py (optimized version)
import os
import torch
import base64
from PIL import Image
from io import BytesIO
from typing import Dict, Any
from transformers import LlamaTokenizer, GenerationConfig
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import tempfile

# Start/end marker pairs that are concatenated and substituted for the
# "<image>" / "<video>" placeholders found in an incoming prompt.
DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN = "<img>", "</img>"
DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN = "<vid>", "</vid>"

class EndpointHandler:
    """Callable inference endpoint wrapping a Husky conditional-generation model.

    A request is a dict holding a text prompt under ``"inputs"`` and,
    optionally, a base64-encoded ``"image"`` or ``"video"``. Calling the
    handler returns ``{"output": <generated text>}``; when any step fails the
    same key carries an error message instead.
    """

    def __init__(self, model_path: str = "."):
        """Load the tokenizer and model found at *model_path*.

        The model is placed on CUDA in float16 when a GPU is available,
        otherwise on the CPU in float32, and is switched to eval mode.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        self.model = HuskyForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device).eval()

        # NOTE(review): temperature normally only matters when sampling; with
        # do_sample=False it is presumably ignored -- confirm before relying on it.
        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=False,
            temperature=0.7,
            max_new_tokens=1024
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run preprocess -> inference -> postprocess for one request.

        Any exception raised along the way is converted into an error string
        returned under the ``"output"`` key, so this method itself never raises
        for ordinary failures.
        """
        try:
            inputs = self.preprocess(data)
            prediction = self.inference(inputs)
            return self.postprocess(prediction)
        except Exception as e:
            # Top-level request boundary: report the failure to the caller
            # in-band rather than propagating it.
            return {"output": f"❌ 推理失败: {str(e)}"}

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the optional media payload and expand the prompt placeholder.

        Args:
            request: Dict with a required ``"inputs"`` prompt and optional
                ``"image"`` / ``"video"`` base64 strings and ``"num_segments"``
                (number of video frames to sample, default 16). When both an
                image and a video are supplied the image wins.

        Returns:
            ``{"prompt": str, "pixel_values": Tensor | None}`` with the tensor
            already on ``self.device`` (float16 on CUDA).

        Raises:
            KeyError: If ``"inputs"`` is missing from *request*.
        """
        prompt = request["inputs"]
        image_b64 = request.get("image")
        video_b64 = request.get("video")
        num_segments = request.get("num_segments", 16)

        pixel_values = None
        if image_b64:
            # Add a leading batch dimension to the single image tensor.
            pixel_values = self._load_image(base64.b64decode(image_b64)).unsqueeze(0)
            prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
        elif video_b64:
            pixel_values = self._load_video(base64.b64decode(video_b64), num_segments)
            prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)

        if pixel_values is not None:
            # Match the dtype the model was loaded with (float16 on GPU only).
            if self.device == "cuda":
                pixel_values = pixel_values.half()
            pixel_values = pixel_values.to(self.device)

        return {"prompt": prompt, "pixel_values": pixel_values}

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Tokenize the prompt, run generation and decode the first sequence.

        Args:
            inputs: The dict produced by :meth:`preprocess`.

        Returns:
            The decoded text with special tokens stripped.
        """
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]

        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        # Dropped before the generate call; the key may be absent, hence the default.
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

        print("📌 prompt token长度:", model_inputs["input_ids"].shape[1])
        if pixel_values is not None:
            print("🎞️ pixel shape:", pixel_values.shape)

        # Per-step scores were previously requested (output_scores=True) but
        # never read; they are no longer collected to avoid holding them in
        # memory. Gradients are not needed for inference either.
        with torch.no_grad():
            output = self.model.generate(
                **model_inputs,
                pixel_values=pixel_values,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
            )

        generated_ids = output.sequences[0]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def postprocess(self, output: str) -> Dict[str, str]:
        """Wrap the whitespace-stripped model text in the response dict."""
        return {"output": output.strip()}

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode raw image bytes into a normalized, 224x224 center-cropped tensor.

        The image is converted to RGB, resized with bicubic interpolation,
        center-cropped to 224, converted to a tensor and normalized with the
        mean/std constants below.
        """
        image = Image.open(BytesIO(image_bytes)).convert('RGB')
        size = int(224 / (224 / 256))  # evaluates to 256: resize, then crop to 224
        transform = T.Compose([
            T.Resize(size, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        return transform(image)

    def _load_video(self, video_bytes: bytes, num_segments=16) -> torch.Tensor:
        """Sample *num_segments* frames from raw video bytes and stack them.

        The bytes are spilled to a temporary ``.mp4`` file because the video
        reader is opened from a path; the file is removed again before
        returning, whether or not decoding succeeds.

        Returns:
            A tensor built by stacking the per-frame tensors, moving the frame
            axis behind the channel axis and adding a leading batch axis.

        Raises:
            ValueError: If the video has no frames or *num_segments* is not
                positive (raised by :meth:`get_index`).
        """
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            tmpfile.write(video_bytes)
            video_path = tmpfile.name

        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            indices = self.get_index(len(vr), num_segments)
            frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
            # Release the reader before the file is deleted.
            del vr
        finally:
            try:
                os.remove(video_path)
            except OSError:
                # Best-effort cleanup; never mask an error from decoding.
                pass

        transform = T.Compose([
            T.Resize(224, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        processed = [transform(frame) for frame in frames]
        # stack -> (frames, C, H, W); permute -> (C, frames, H, W); unsqueeze -> batch of 1
        video_tensor = torch.stack(processed, dim=0).permute(1, 0, 2, 3).unsqueeze(0)
        return video_tensor

    def get_index(self, num_frames: int, num_segments: int):
        """Return *num_segments* frame indices spread evenly over the video.

        When the video has fewer frames than requested, every frame is used
        once and the last frame index is repeated to pad the list.

        Raises:
            ValueError: If *num_frames* or *num_segments* is not positive.
        """
        if num_segments <= 0:
            raise ValueError(f"num_segments must be positive, got {num_segments}")
        if num_frames <= 0:
            # Without this guard the padding branch would yield -1 indices.
            raise ValueError("video contains no frames")
        if num_frames < num_segments:
            return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
        interval = num_frames / num_segments
        return [int(interval * i) for i in range(num_segments)]