ccclemenfff committed on
Commit
4f22f1b
·
1 Parent(s): 3b9ccb8
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -53,7 +53,7 @@ class EndpointHandler:
53
 
54
  elif video_b64:
55
  video_bytes = base64.b64decode(video_b64)
56
- pixel_values = self._load_video(video_bytes).unsqueeze(0) # [1, T, 3, 224, 224]
57
  if self.device == "cuda":
58
  pixel_values = pixel_values.half()
59
  pixel_values = pixel_values.to(self.device)
@@ -121,10 +121,10 @@ class EndpointHandler:
121
  T.ToTensor(),
122
  T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
123
  ])
124
- processed = [transform(frame) for frame in frames] # each is [3, 224, 224]
125
  video_tensor = torch.stack(processed, dim=0) # [T, 3, 224, 224]
126
- video_tensor = video_tensor.unsqueeze(0) # [1, T, 3, 224, 224]
127
- video_tensor = video_tensor.permute(0, 2, 1, 3, 4) # 💥 [1, 3, T, 224, 224]
128
  return video_tensor
129
 
130
  def get_index(self, num_frames: int, num_segments: int):
 
53
 
54
  elif video_b64:
55
  video_bytes = base64.b64decode(video_b64)
56
+ pixel_values = self._load_video(video_bytes)
57
  if self.device == "cuda":
58
  pixel_values = pixel_values.half()
59
  pixel_values = pixel_values.to(self.device)
 
121
  T.ToTensor(),
122
  T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
123
  ])
124
+ processed = [transform(frame) for frame in frames] # each: [3, 224, 224]
125
  video_tensor = torch.stack(processed, dim=0) # [T, 3, 224, 224]
126
+ video_tensor = video_tensor.permute(1, 0, 2, 3) # [3, T, 224, 224]
127
+ video_tensor = video_tensor.unsqueeze(0) # [1, 3, T, 224, 224]
128
  return video_tensor
129
 
130
  def get_index(self, num_frames: int, num_segments: int):