File size: 5,317 Bytes
73085ab
 
6b701b8
73085ab
 
 
 
 
 
 
 
4563224
73085ab
 
 
 
 
cf932d8
ef8d70b
73085ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b701b8
 
73085ab
6b701b8
 
 
 
4563224
d946509
 
 
73085ab
6b701b8
 
 
4f22f1b
d946509
 
 
73085ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4563224
 
 
 
 
 
3b9ccb8
4563224
ef8d70b
73085ab
4563224
 
 
 
73085ab
4f22f1b
4563224
4f22f1b
 
4563224
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import torch
import base64
from PIL import Image
from io import BytesIO
from typing import Dict, Any
from transformers import LlamaTokenizer, GenerationConfig
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from decord import VideoReader, cpu
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import tempfile

# Delimiter tokens marking where visual input belongs in the prompt text.
# EndpointHandler.preprocess replaces a literal "<image>" placeholder with the
# start/end image pair and a literal "<video>" placeholder with the start/end
# video pair.
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"

class EndpointHandler:
    """Inference endpoint wrapper around a Husky conditional-generation model.

    A request is a dict with a text prompt under ``"inputs"`` and, optionally,
    a base64-encoded image under ``"image"`` or a base64-encoded video under
    ``"video"``. Calling the handler runs preprocess -> inference ->
    postprocess and returns ``{"output": <generated text>}``.
    """

    # Side length, in pixels, of the square centre crop taken from every frame.
    CROP_SIZE = 224
    # Shorter-side length a still image is resized to before the centre crop
    # (the original code computed this as int(224 / (224 / 256)) == 256).
    IMAGE_RESIZE = 256
    # Per-channel normalisation statistics applied after ToTensor; these match
    # the widely used ImageNet mean/std values.
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_path: str = "."):
        """Load the tokenizer and model from ``model_path``.

        The model is placed on CUDA in float16 when a GPU is available,
        otherwise on CPU in float32, and switched to eval mode.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        self.model = HuskyForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device).eval()

        # Sampling is enabled, so repeated calls with the same request may
        # produce different text.
        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run the full request pipeline and return ``{"output": text}``."""
        inputs = self.preprocess(data)
        prediction = self.inference(inputs)
        return self.postprocess(prediction)

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the optional media payload and prepare the prompt.

        Args:
            request: Dict with a required ``"inputs"`` prompt string and
                optional base64 strings under ``"image"`` / ``"video"``.
                When both are present the image takes precedence and the
                video is ignored.

        Returns:
            Dict with ``"prompt"`` (placeholder replaced by the media
            delimiter tokens when media was supplied) and ``"pixel_values"``
            (a tensor on the model device, or ``None`` for text-only input).

        Raises:
            KeyError: If ``"inputs"`` is missing from ``request``.
        """
        prompt = request["inputs"]
        image_b64 = request.get("image")
        video_b64 = request.get("video")

        pixel_values = None

        if image_b64:
            image_tensor = self._load_image(base64.b64decode(image_b64))
            # Add the batch dimension: [1, 3, 224, 224].
            pixel_values = self._to_model_device(image_tensor.unsqueeze(0))
            prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
        elif video_b64:
            video_tensor = self._load_video(base64.b64decode(video_b64))
            pixel_values = self._to_model_device(video_tensor)
            prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)

        return {
            "prompt": prompt,
            "pixel_values": pixel_values
        }

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Generate text for a preprocessed request.

        Uses the multimodal ``generate`` when pixel values are present and the
        underlying language model's ``generate`` for text-only prompts.

        NOTE(review): both paths decode the full returned sequence; for a
        plain causal-LM ``generate`` that normally includes the prompt
        tokens, so text-only responses may echo the prompt - confirm whether
        that is intended.
        """
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]

        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

        # Only `sequences` is read from the result, so per-step scores are not
        # requested: keeping them would retain a logits tensor for every
        # generated token without any use.
        generate_kwargs = dict(
            generation_config=self.gen_config,
            return_dict_in_generate=True,
        )

        # Generation never needs gradients; guard explicitly in case the
        # model's generate() is not already wrapped in no_grad.
        with torch.no_grad():
            if pixel_values is not None:
                output = self.model.generate(
                    **model_inputs,
                    pixel_values=pixel_values,
                    **generate_kwargs
                )
            else:
                output = self.model.language_model.generate(
                    **model_inputs,
                    **generate_kwargs
                )

        return self.tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    def postprocess(self, output: str) -> Dict[str, str]:
        """Wrap the generated text, stripped of surrounding whitespace."""
        return {"output": output.strip()}

    def _to_model_device(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Cast to float16 when running on CUDA and move to the model device."""
        if self.device == "cuda":
            pixel_values = pixel_values.half()
        return pixel_values.to(self.device)

    def _build_transform(self, resize_to: int):
        """Return the resize -> centre-crop -> tensor -> normalise pipeline.

        Args:
            resize_to: Target length of the shorter image side before the
                ``CROP_SIZE`` centre crop is taken.
        """
        return T.Compose([
            T.Resize(resize_to, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(self.CROP_SIZE),
            T.ToTensor(),
            T.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD),
        ])

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode encoded image bytes into a normalised [3, 224, 224] tensor."""
        image = Image.open(BytesIO(image_bytes)).convert('RGB')
        return self._build_transform(self.IMAGE_RESIZE)(image)

    def _load_video(self, video_bytes: bytes, num_segments: int = 8) -> torch.Tensor:
        """Decode video bytes into a [1, 3, T, 224, 224] tensor of sampled frames.

        The bytes are written to a temporary ``.mp4`` file because the video
        reader is given a file path; the file is removed again before
        returning, including when decoding fails.

        Args:
            video_bytes: Raw contents of the video file.
            num_segments: Number of frames (T) to sample across the video.

        Raises:
            ValueError: If the video has no frames or ``num_segments`` is
                not positive (raised by ``get_index``).
        """
        # delete=False so the file can be closed and then reopened by path.
        tmpfile = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        video_path = tmpfile.name
        try:
            with tmpfile:
                tmpfile.write(video_bytes)

            vr = VideoReader(video_path, ctx=cpu(0))
            indices = self.get_index(len(vr), num_segments)
            frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
            # Drop the reader before unlinking so no handle to the file lingers.
            del vr
        finally:
            try:
                os.remove(video_path)
            except OSError:
                # Best-effort cleanup: a failed unlink must not mask the real
                # result or the original decoding error.
                pass

        transform = self._build_transform(self.CROP_SIZE)
        processed = [transform(frame) for frame in frames]  # each: [3, 224, 224]
        video_tensor = torch.stack(processed, dim=0)  # [T, 3, 224, 224]
        video_tensor = video_tensor.permute(1, 0, 2, 3)  # [3, T, 224, 224]
        return video_tensor.unsqueeze(0)  # [1, 3, T, 224, 224]

    def get_index(self, num_frames: int, num_segments: int) -> list:
        """Pick ``num_segments`` frame indices spread across ``num_frames``.

        When the video is shorter than ``num_segments`` every frame is used
        once and the last frame index is repeated to pad the list; otherwise
        indices are taken at evenly spaced offsets starting from frame 0.

        Raises:
            ValueError: If ``num_frames`` or ``num_segments`` is not positive.
        """
        if num_segments <= 0:
            raise ValueError(f"num_segments must be positive, got {num_segments}")
        if num_frames <= 0:
            # Without this guard the padding below would yield index -1.
            raise ValueError(f"video contains no frames (num_frames={num_frames})")

        if num_frames < num_segments:
            return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
        interval = num_frames / num_segments
        return [int(interval * i) for i in range(num_segments)]