File size: 5,720 Bytes
73085ab
 
6b701b8
73085ab
 
 
 
 
 
 
 
4563224
73085ab
 
 
 
 
cf932d8
ef8d70b
73085ab
 
 
 
 
 
 
 
 
f727a0f
73085ab
516f550
73085ab
 
 
 
 
 
 
 
 
6b701b8
 
73085ab
6b701b8
 
 
 
4563224
d946509
 
 
73085ab
6b701b8
 
 
4f22f1b
d946509
 
 
73085ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
971e468
 
 
 
 
 
 
 
 
73085ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eecb9b2
4563224
 
 
 
 
 
3b9ccb8
4563224
ef8d70b
73085ab
4563224
 
 
 
73085ab
4f22f1b
4563224
4f22f1b
 
4563224
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import torch
import base64
from PIL import Image
from io import BytesIO
from typing import Dict, Any
from transformers import LlamaTokenizer, GenerationConfig
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import tempfile

# Sentinel tokens the model expects around embedded visual content.
# The preprocessor replaces the user-facing "<image>" / "<video>"
# placeholders in the prompt with these start/end pairs.
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"

class EndpointHandler:
    """Inference-endpoint handler for the Husky multimodal model.

    Follows the HuggingFace custom-handler convention: the instance is
    called with a request dict and returns ``{"output": <generated text>}``.
    Requests carry a text prompt (``"inputs"``) plus an optional
    base64-encoded ``"image"`` or ``"video"`` payload.
    """

    def __init__(self, model_path: str = "."):
        """Load tokenizer and model from *model_path*, on GPU when available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        # fp16 weights on CUDA to halve memory; fp32 on CPU where fp16
        # inference is poorly supported.
        self.model = HuskyForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device).eval()

        # NOTE(review): with do_sample=False decoding is greedy, so
        # `temperature` has no effect; kept for config parity.
        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=False,
            temperature=0.7,
            max_new_tokens=10240
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run the full request pipeline: preprocess -> generate -> postprocess."""
        inputs = self.preprocess(data)
        prediction = self.inference(inputs)
        return self.postprocess(prediction)

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the optional visual payload and rewrite modality tags.

        Returns a dict with the (possibly rewritten) ``prompt`` and
        ``pixel_values`` (``None`` for text-only requests).  If both an
        image and a video are supplied, the image wins.
        """
        prompt = request["inputs"]
        image_b64 = request.get("image", None)
        video_b64 = request.get("video", None)

        pixel_values = None

        if image_b64:
            image_bytes = base64.b64decode(image_b64)
            pixel_values = self._load_image(image_bytes).unsqueeze(0)  # [1, 3, 224, 224]
            pixel_values = self._to_model_dtype(pixel_values)
            # Swap the placeholder for the model's image sentinel tokens.
            prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)

        elif video_b64:
            video_bytes = base64.b64decode(video_b64)
            pixel_values = self._load_video(video_bytes)  # [1, 3, T, 224, 224]
            pixel_values = self._to_model_dtype(pixel_values)
            prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)

        return {
            "prompt": prompt,
            "pixel_values": pixel_values
        }

    def _to_model_dtype(self, tensor: torch.Tensor) -> torch.Tensor:
        """Cast to fp16 on CUDA (matching the model weights) and move to device."""
        if self.device == "cuda":
            tensor = tensor.half()
        return tensor.to(self.device)

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Tokenize the prompt and generate text, multimodal or text-only."""
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]

        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        # LlamaTokenizer may emit token_type_ids, which generate() rejects.
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

        if pixel_values is not None:
            # Multimodal path: full model consumes the visual features.
            output = self.model.generate(
                **model_inputs,
                pixel_values=pixel_values,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
                output_scores=True
            )
        else:
            # Text-only path: bypass the vision tower entirely.
            output = self.model.language_model.generate(
                **model_inputs,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
                output_scores=True
            )
        # Debug logging of the raw generation.
        generated_ids = output.sequences[0]
        print("📍生成的 token ids:", generated_ids.tolist())
        raw_text = self.tokenizer.decode(generated_ids, skip_special_tokens=False)
        clean_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        print("🧾 带特殊符号的输出:", raw_text)
        print("✅ 去掉特殊符号的输出:", clean_text)

        # Return the decoded text with special tokens stripped.
        return clean_text

    def postprocess(self, output: str) -> Dict[str, str]:
        """Wrap the generated text in the endpoint's response schema."""
        return {"output": output.strip()}

    @staticmethod
    def _build_transform(resize_size: int):
        """Shared resize -> 224 center-crop -> ImageNet-normalize pipeline."""
        return T.Compose([
            T.Resize(resize_size, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode image bytes to a normalized [3, 224, 224] tensor."""
        image = Image.open(BytesIO(image_bytes)).convert('RGB')
        # Resize the short side to 224 / 0.875 = 256 before the 224 crop
        # (standard ImageNet eval preprocessing).
        crop_pct = 224 / 256
        size = int(224 / crop_pct)
        return self._build_transform(size)(image)

    def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
        """Decode video bytes and sample frames into a [1, 3, T, 224, 224] tensor.

        decord's VideoReader requires a file path, so the bytes are spooled
        to a temporary file which is removed once decoding finishes (the
        original leaked one temp file per request).
        """
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            tmpfile.write(video_bytes)
            video_path = tmpfile.name

        try:
            vr = VideoReader(video_path, ctx=cpu(0))
            total_frames = len(vr)
            indices = self.get_index(total_frames, num_segments)
            frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
        finally:
            # Fix: the original never deleted the temp file.
            os.remove(video_path)

        transform = self._build_transform(224)
        processed = [transform(frame) for frame in frames]  # each: [3, 224, 224]
        video_tensor = torch.stack(processed, dim=0)  # [T, 3, 224, 224]
        video_tensor = video_tensor.permute(1, 0, 2, 3)  # [3, T, 224, 224]
        video_tensor = video_tensor.unsqueeze(0)  # [1, 3, T, 224, 224]
        return video_tensor

    def get_index(self, num_frames: int, num_segments: int):
        """Return ``num_segments`` evenly spaced frame indices.

        When the clip has fewer frames than segments, the last frame is
        repeated to pad the list to ``num_segments`` entries.
        """
        if num_frames < num_segments:
            return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
        interval = num_frames / num_segments
        return [int(interval * i) for i in range(num_segments)]