ccclemenfff committed on
Commit
4563224
·
1 Parent(s): d946509
Files changed (1) hide show
  1. handler.py +24 -22
handler.py CHANGED
@@ -6,13 +6,10 @@ from io import BytesIO
6
  from typing import Dict, Any
7
  from transformers import LlamaTokenizer, GenerationConfig
8
  from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
9
- from robohusky.video_transformers import (
10
- GroupNormalize, GroupScale, GroupCenterCrop,
11
- Stack, ToTorchFormatTensor, get_index
12
- )
13
  from decord import VideoReader, cpu
14
  import torchvision.transforms as T
15
  from torchvision.transforms.functional import InterpolationMode
 
16
 
17
  DEFAULT_IMG_START_TOKEN = "<img>"
18
  DEFAULT_IMG_END_TOKEN = "</img>"
@@ -48,22 +45,17 @@ class EndpointHandler:
48
 
49
  if image_b64:
50
  image_bytes = base64.b64decode(image_b64)
51
- pixel_values = self._load_image(image_bytes).unsqueeze(0)
52
-
53
- # ⭐️ 如果模型是 float16,就把输入也变成 half
54
  if self.device == "cuda":
55
  pixel_values = pixel_values.half()
56
-
57
  pixel_values = pixel_values.to(self.device)
58
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
59
 
60
  elif video_b64:
61
  video_bytes = base64.b64decode(video_b64)
62
- pixel_values = self._load_video(video_bytes).unsqueeze(0)
63
-
64
  if self.device == "cuda":
65
  pixel_values = pixel_values.half()
66
-
67
  pixel_values = pixel_values.to(self.device)
68
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
69
 
@@ -114,17 +106,27 @@ class EndpointHandler:
114
  return transform(image)
115
 
116
  def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
117
- with open("/tmp/temp_video.mp4", "wb") as f:
118
- f.write(video_bytes)
119
- vr = VideoReader("/tmp/temp_video.mp4", ctx=cpu(0))
120
- frame_indices = get_index(len(vr), num_segments)
121
- frames = [Image.fromarray(vr[idx].asnumpy()) for idx in frame_indices]
 
 
 
122
 
123
  transform = T.Compose([
124
- GroupScale(224),
125
- GroupCenterCrop(224),
126
- Stack(),
127
- ToTorchFormatTensor(),
128
- GroupNormalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
129
  ])
130
- return transform(frames)
 
 
 
 
 
 
 
 
 
6
  from typing import Dict, Any
7
  from transformers import LlamaTokenizer, GenerationConfig
8
  from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
 
 
 
 
9
  from decord import VideoReader, cpu
10
  import torchvision.transforms as T
11
  from torchvision.transforms.functional import InterpolationMode
12
+ import tempfile
13
 
14
  DEFAULT_IMG_START_TOKEN = "<img>"
15
  DEFAULT_IMG_END_TOKEN = "</img>"
 
45
 
46
  if image_b64:
47
  image_bytes = base64.b64decode(image_b64)
48
+ pixel_values = self._load_image(image_bytes).unsqueeze(0) # [1, 3, 224, 224]
 
 
49
  if self.device == "cuda":
50
  pixel_values = pixel_values.half()
 
51
  pixel_values = pixel_values.to(self.device)
52
  prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
53
 
54
  elif video_b64:
55
  video_bytes = base64.b64decode(video_b64)
56
+ pixel_values = self._load_video(video_bytes).unsqueeze(0) # [1, T, 3, 224, 224]
 
57
  if self.device == "cuda":
58
  pixel_values = pixel_values.half()
 
59
  pixel_values = pixel_values.to(self.device)
60
  prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)
61
 
 
106
  return transform(image)
107
 
108
  def _load_video(self, video_bytes: bytes, num_segments=8) -> torch.Tensor:
109
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
110
+ tmpfile.write(video_bytes)
111
+ video_path = tmpfile.name
112
+
113
+ vr = VideoReader(video_path, ctx=cpu(0))
114
+ total_frames = len(vr)
115
+ indices = self.get_index(total_frames, num_segments)
116
+ frames = [Image.fromarray(vr[i].asnumpy()) for i in indices]
117
 
118
  transform = T.Compose([
119
+ T.Resize(224, interpolation=InterpolationMode.BICUBIC),
120
+ T.CenterCrop(224),
121
+ T.ToTensor(),
122
+ T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
 
123
  ])
124
+ processed = [transform(frame) for frame in frames] # each is [3, 224, 224]
125
+ video_tensor = torch.stack(processed, dim=0) # [T, 3, 224, 224]
126
+ return video_tensor
127
+
128
+ def get_index(self, num_frames: int, num_segments: int):
129
+ if num_frames < num_segments:
130
+ return list(range(num_frames)) + [num_frames - 1] * (num_segments - num_frames)
131
+ interval = num_frames / num_segments
132
+ return [int(interval * i) for i in range(num_segments)]