File size: 4,431 Bytes
73085ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf932d8
ef8d70b
73085ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef8d70b
73085ab
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import base64
import os
import tempfile
from io import BytesIO
from typing import Dict, Any

import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import LlamaTokenizer, GenerationConfig

from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
from robohusky.video_transformers import (
    GroupNormalize, GroupScale, GroupCenterCrop,
    Stack, ToTorchFormatTensor, get_index
)

# Marker pairs substituted into the prompt in place of the "<image>" and
# "<video>" placeholders.
DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN = "<img>", "</img>"
DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN = "<vid>", "</vid>"

class EndpointHandler:
    """Inference-endpoint handler wrapping a Husky multimodal model.

    A request is a dict holding the text prompt under ``"inputs"`` and,
    optionally, an encoded image under ``"image"`` or an encoded video under
    ``"video"``. The handler answers with ``{"output": <generated text>}``.
    """

    def __init__(self, model_path: str = "."):
        """Load the tokenizer and the model stored at *model_path*.

        The model is placed on CUDA in float16 when a GPU is available,
        otherwise on the CPU in float32.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Remembered so that visual inputs can be cast to the model's dtype.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        self.model = HuskyForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=self.dtype
        ).to(self.device).eval()

        self.gen_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            temperature=0.7,
            max_new_tokens=1024
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Entry point invoked by Hugging Face; *data* is the raw request input.

        Runs preprocessing, generation and postprocessing in sequence.
        """
        inputs = self.preprocess(data)
        prediction = self.inference(inputs)
        return self.postprocess(prediction)

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a raw request into a prompt plus an optional visual tensor.

        Args:
            request: Must contain ``"inputs"`` (the prompt). May contain
                ``"image"`` or ``"video"`` as raw bytes or a base64 string;
                when both are present the image wins.

        Returns:
            A dict with ``"prompt"`` (placeholders replaced by the modality
            marker tokens) and ``"pixel_values"`` (a batched tensor on the
            handler's device, or ``None`` for text-only requests).

        Raises:
            KeyError: If ``"inputs"`` is missing from *request*.
        """
        prompt = request["inputs"]
        image = request.get("image", None)
        video = request.get("video", None)

        if image:
            pixel_values = self._to_model_input(
                self._load_image(self._as_bytes(image))
            )
            prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
        elif video:
            # NOTE(review): the stacked frame tensor is forwarded without any
            # reshape -- confirm this is the layout the model expects.
            pixel_values = self._to_model_input(
                self._load_video(self._as_bytes(video))
            )
            prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
        else:
            pixel_values = None

        return {
            "prompt": prompt,
            "pixel_values": pixel_values
        }

    def inference(self, inputs: Dict[str, Any]) -> str:
        """Generate text for the preprocessed *inputs*.

        With visual input the full model's ``generate`` is used; without it
        only the wrapped language model is called.

        Returns:
            The first generated sequence decoded with special tokens skipped.
        """
        prompt = inputs["prompt"]
        pixel_values = inputs["pixel_values"]

        model_inputs = self.tokenizer([prompt], return_tensors="pt")
        # Dropped before generation; not passed on to ``generate``.
        model_inputs.pop("token_type_ids", None)
        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}

        if pixel_values is not None:
            output = self.model.generate(
                **model_inputs,
                pixel_values=pixel_values,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
                output_scores=True
            )
        else:
            output = self.model.language_model.generate(
                **model_inputs,
                generation_config=self.gen_config,
                return_dict_in_generate=True,
                output_scores=True
            )

        return self.tokenizer.decode(output.sequences[0], skip_special_tokens=True)

    def postprocess(self, output: str) -> Dict[str, str]:
        """Wrap the whitespace-stripped text in the response dict."""
        return {"output": output.strip()}

    @staticmethod
    def _as_bytes(payload: Any) -> Any:
        """Return *payload* as binary data.

        A ``str`` is treated as base64 text (the form binary data takes in a
        JSON body) and decoded; anything else is returned unchanged.
        """
        if isinstance(payload, str):
            return base64.b64decode(payload)
        return payload

    def _to_model_input(self, tensor: torch.Tensor) -> torch.Tensor:
        """Add a batch dimension and move *tensor* to the model's device/dtype."""
        # Casting avoids feeding float32 inputs to a model loaded in float16.
        return tensor.unsqueeze(0).to(self.device, dtype=self.dtype)

    def _load_image(self, image_bytes: bytes) -> torch.Tensor:
        """Decode an encoded image and return it as a normalized tensor.

        The image is converted to RGB, resized to 256 with bicubic
        interpolation, center-cropped to 224 and normalized.
        """
        # The context manager releases the decoder once the RGB copy exists.
        with Image.open(BytesIO(image_bytes)) as raw:
            image = raw.convert('RGB')
        crop_pct = 224 / 256
        size = int(224 / crop_pct)
        transform = T.Compose([
            T.Resize(size, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        return transform(image)

    def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
        """Decode an encoded video and return its sampled frames as one tensor.

        Args:
            video_bytes: The encoded video file contents.
            num_segments: Passed to ``get_index`` to choose the frame indices.

        Returns:
            The output of the group transform pipeline applied to the
            sampled frames.
        """
        transform = T.Compose([
            GroupScale(224),
            GroupCenterCrop(224),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
        ])

        # A unique temporary file per call: a fixed path would be overwritten
        # by concurrent requests and would never be cleaned up.
        with tempfile.NamedTemporaryFile(suffix=".mp4") as tmp:
            tmp.write(video_bytes)
            tmp.flush()  # make sure the reader sees the complete file
            vr = VideoReader(tmp.name, ctx=cpu(0))
            frame_indices = get_index(len(vr), num_segments)
            frames = [Image.fromarray(vr[idx].asnumpy()) for idx in frame_indices]
            # Release the reader before the temporary file is removed.
            del vr

        return transform(frames)