k-l-lambda committed on
Commit
489147c
·
verified ·
1 Parent(s): 84d5763

Upload media_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. media_utils.py +368 -0
media_utils.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import math
4
+ import os
5
+ from datetime import datetime, timezone
6
+ from typing import List, Literal, Optional, TypedDict
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pydantic import BaseModel, Field
11
+
12
+ try:
13
+ from mecord import VideoReader
14
+ except ImportError:
15
+ VideoReader = None
16
+
17
+
18
class VideoSpec(BaseModel):
    """Metadata describing a decoded video stream."""

    # BUG FIX: the original `media_type: str = Literal['video']` assigned the
    # typing construct itself as the field's default value; declare the field
    # as a Literal type with the plain string default instead.
    media_type: Literal['video'] = 'video'
    height: int = Field(..., gt=0, description="video frame height")
    width: int = Field(..., gt=0, description="video frame width")
    num_frames: int = Field(..., gt=0, description="num frames")
    fps: float = Field(..., gt=0, description="average fps")

    # Optional hints that help to accelerate video reading; None when unknown.
    # (Typed Optional so the None default is actually valid for the field.)
    key_indices: Optional[list[int]] = Field(None, description="key indices")
    frame_time_info: Optional[dict] = Field(None,
                                            description="frame time info")
28
+
29
+
30
# A single still image payload; `image` holds the decoded PIL image.
ImageInput = TypedDict('ImageInput', {
    'type': Literal['image'],
    'image': Image.Image,
})
33
+
34
+
35
class _VideoChunkInputRequired(TypedDict):
    # Required keys of a video-chunk payload.
    type: Literal['video_chunk']
    video_chunk: List[Image.Image]


class VideoChunkInput(_VideoChunkInputRequired, total=False):
    """A consecutive chunk of video frames with an optional text prompt.

    BUG FIX: TypedDict fields cannot carry ``= None`` defaults (rejected at
    class creation on Python 3.13+, and silently ignored — leaving ``prompt``
    required — on earlier versions). ``prompt`` is now a genuinely optional
    key via the ``total=False`` subclass.
    """

    prompt: Optional[str]
39
+
40
+
41
# Any media payload understood by ensure_media_type(): a still image or a
# chunk of video frames.
MediaInput = ImageInput | VideoChunkInput
42
+
43
+
44
def get_video_meta(video_src: bytes | str | os.PathLike,
                   accurate: bool = True) -> "VideoSpec":
    """Probe a video source and return its metadata as a VideoSpec.

    Args:
        video_src: Raw encoded video bytes, a filesystem path, or a
            ``data:video/mp4;base64,`` data URI.
        accurate: Passed as ``auto_init`` to the reader; presumably triggers a
            full index pass for exact metadata — TODO confirm against mecord.

    Returns:
        A VideoSpec with dimensions, frame count, average fps and (when
        available) key-frame / frame-time hints.

    Raises:
        RuntimeError: If the optional ``mecord`` dependency is not installed.
        AssertionError: If the source cannot be decoded as a valid video.
    """
    # The module-level import of mecord.VideoReader is guarded; fail with a
    # clear message instead of "NoneType is not callable".
    if VideoReader is None:
        raise RuntimeError(
            "mecord is required to read video metadata but is not installed")
    if isinstance(video_src, os.PathLike):
        video_src = str(video_src)
    # A base64 data URI is decoded to raw bytes before handing to the reader.
    if isinstance(video_src,
                  str) and video_src.startswith('data:video/mp4;base64,'):
        video_src = base64.b64decode(video_src.split(',')[1])
    video = VideoReader(video_src, auto_init=accurate, num_threads=1)
    assert video.num_frames > 0, "Invalid video format."
    assert video.original_width > 0 and video.original_height > 0, (
        "Invalid video format.")
    assert video.avg_fps > 0, "Invalid video format."
    return VideoSpec(media_type='video',
                     height=video.original_height,
                     width=video.original_width,
                     num_frames=video.num_frames,
                     fps=video.avg_fps,
                     key_indices=video.key_indices,
                     frame_time_info=video.frame_time_info)
65
+
66
+
67
def timestamp_as_str(timestamp: float,
                     timestamp_mode: str = "hh:mm:ss.fff") -> str:
    """Format a timestamp in seconds as a wall-clock style string.

    Args:
        timestamp: Seconds (e.g. an offset into a video). Values of 24 hours
            or more wrap around in the hour field.
        timestamp_mode: One of "hh:mm:ss.fff", "mm:ss.fff" or "mm:ss".

    Returns:
        The formatted string, e.g. "01:02:03.456".

    Raises:
        ValueError: For an unknown ``timestamp_mode``.
    """
    dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    # BUG FIX: derive milliseconds from the same datetime object so the
    # fractional part can never disagree with the rounded seconds field; the
    # original computed it independently via `timestamp % 1`, which could be
    # off by almost a second near second boundaries (fromtimestamp rounds
    # microseconds, the modulo truncates).
    millis = dt.microsecond // 1000
    if timestamp_mode == "hh:mm:ss.fff":
        return f"{dt.strftime('%H:%M:%S')}.{millis:03d}"
    elif timestamp_mode == "mm:ss.fff":
        return f"{dt.strftime('%M:%S')}.{millis:03d}"
    elif timestamp_mode == "mm:ss":
        return dt.strftime("%M:%S")
    else:
        raise ValueError(f"Invalid timestamp mode: {timestamp_mode}")
83
+
84
+
85
+ def navit_resize_image(
86
+ width: int,
87
+ height: int,
88
+ patch_size: int,
89
+ merge_kernel_size: int,
90
+ in_patch_limit: int,
91
+ patch_limit_on_one_side: int,
92
+ fixed_output_tokens: int | None,
93
+ ):
94
+ # Apply the patch limits.
95
+ s1 = math.sqrt(
96
+ in_patch_limit /
97
+ (max(1.0, width // patch_size) * max(1.0, height // patch_size)))
98
+ s2 = patch_limit_on_one_side * patch_size / width
99
+ s3 = patch_limit_on_one_side * patch_size / height
100
+ scale = min(1.0, s1, s2, s3)
101
+ new_w, new_h = max(1, int(width * scale)), max(1, int(height * scale))
102
+ new_w = min(new_w, patch_limit_on_one_side * patch_size)
103
+ new_h = min(new_h, patch_limit_on_one_side * patch_size)
104
+
105
+ # Calculate the padding to make the height and width divisible by the merge kernel size and patch size.
106
+ factor = merge_kernel_size * patch_size
107
+
108
+ pad_height = (factor - new_h % factor) % factor
109
+ pad_width = (factor - new_w % factor) % factor
110
+
111
+ if fixed_output_tokens is not None:
112
+ num_tokens = fixed_output_tokens
113
+ else:
114
+ # Calculate new dimensions after padding and patching
115
+ token_height = (new_h + pad_height) // factor
116
+ token_width = (new_w + pad_width) // factor
117
+
118
+ assert token_height * merge_kernel_size <= patch_limit_on_one_side, (
119
+ f"token_height {token_height} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
120
+ )
121
+ assert token_width * merge_kernel_size <= patch_limit_on_one_side, (
122
+ f"token_width {token_width} * merge_kernel_size {merge_kernel_size} > patch_limit_on_one_side {patch_limit_on_one_side}"
123
+ )
124
+
125
+ num_tokens = token_height * token_width
126
+ return {
127
+ "num_tokens": num_tokens,
128
+ "new_width": new_w,
129
+ "new_height": new_h,
130
+ "pad_width": pad_width,
131
+ "pad_height": pad_height,
132
+ "sampled_nframes": 1,
133
+ }
134
+
135
+
136
def navit_resize_video(
    width: int,
    height: int,
    nframes: int,
    avg_fps: float,
    sample_fps: float,
    patch_size: int,
    merge_kernel_size: int,
    in_patch_limit_each_frame: int,
    patch_limit_on_one_side: int,
    in_patch_limit_total: int | None,
    max_num_frames_each_video: int | None,
    fixed_output_tokens_each_frame: int | None,
):
    """Plan a NaViT-style resize for a video.

    Picks the number of frames to sample (targeting ``sample_fps``, clamped
    by the source fps and ``max_num_frames_each_video``), then delegates the
    per-frame spatial plan to ``navit_resize_image``, tightening the
    per-frame patch budget so the whole video stays within
    ``in_patch_limit_total`` when that is given.
    """
    # Never sample faster than the source actually plays.
    effective_fps = min(sample_fps, avg_fps)
    sampled_nframes = max(round(nframes * effective_fps / avg_fps), 1)
    if max_num_frames_each_video is not None:
        sampled_nframes = min(sampled_nframes, max_num_frames_each_video)

    per_frame_limit = in_patch_limit_each_frame
    if in_patch_limit_total is not None:
        # Split the total patch budget evenly across the sampled frames.
        per_frame_limit = min(round(in_patch_limit_total / sampled_nframes),
                              per_frame_limit)

    plan = navit_resize_image(
        width,
        height,
        patch_size,
        merge_kernel_size,
        per_frame_limit,
        patch_limit_on_one_side,
        fixed_output_tokens_each_frame,
    )
    plan["sampled_nframes"] = sampled_nframes
    return plan
172
+
173
+
174
+ def real_sample_fps_and_max_num_frames(
175
+ type_name: Literal["video", "video_chunk"],
176
+ sample_fps: float,
177
+ max_num_frames_each_video: int | None,
178
+ ) -> tuple[int, int | None]:
179
+ if type_name == "video":
180
+ return sample_fps, max_num_frames_each_video
181
+ elif type_name == "video_chunk":
182
+ max_num_frames_each_video = None
183
+ sample_fps = math.inf
184
+ return sample_fps, max_num_frames_each_video
185
+ else:
186
+ return math.inf, None
187
+
188
+
189
def _to_pil(data: "str | bytes | Image.Image") -> "Image.Image":
    """Coerce an image source into an RGB PIL image.

    Accepts an already-open PIL image (annotation fixed: the original
    signature omitted this case although the first branch handles it), a
    ``data:`` base64 URI, a filesystem path string, or raw encoded bytes.

    Raises:
        ValueError: For any other input type.
    """
    if isinstance(data, Image.Image):
        return data.convert("RGB")
    elif isinstance(data, str):
        if data.startswith("data:"):
            # Strip the "data:<mime>;base64," header before decoding.
            raw_base64 = data.split(",")[1]
            return Image.open(io.BytesIO(
                base64.b64decode(raw_base64))).convert("RGB")
        else:
            return Image.open(data).convert("RGB")
    elif isinstance(data, bytes):
        return Image.open(io.BytesIO(data)).convert("RGB")
    else:
        raise ValueError(f"Unsupported data type: {type(data)}")
204
+
205
+
206
def ensure_media_type(media: MediaInput) -> MediaInput:
    """Normalize a media dict's payload to RGB PIL image(s), in place.

    Returns the same dict for convenience; raises ValueError for an
    unrecognized ``type`` tag.
    """
    media_type = media['type']
    if media_type == 'image':
        media['image'] = _to_pil(media['image'])
    elif media_type == 'video_chunk':
        converted = []
        for frame in media['video_chunk']:
            converted.append(_to_pil(frame))
        media['video_chunk'] = converted
    else:
        raise ValueError(f"Unsupported media type: {media['type']}")
    return media
217
+
218
+
219
def _rescale_and_pad(image, resize_to, pad_to_center,
                     raise_error_for_ill_resize):
    """Downscale `image` to fit within `resize_to` (w, h) keeping aspect
    ratio, then zero-pad to exactly `resize_to`.

    Padding is split evenly around the image when `pad_to_center` is True,
    otherwise it all goes to the right/bottom edges. Returns a uint8-like
    numpy array of shape (h, w, 3); for a degenerate rescale either raises
    or returns an all-black canvas, per `raise_error_for_ill_resize`.
    """
    # Never upscale: cap the aspect-preserving scale at 1.0.
    scale = min(resize_to[0] / image.width, resize_to[1] / image.height, 1.0)
    new_width = round(image.width * scale)
    new_height = round(image.height * scale)
    if new_width == 0 or new_height == 0:
        if raise_error_for_ill_resize:
            raise ValueError(
                f"Invalid resize to: {resize_to}, from image size: {image.size}"
            )
        return np.zeros((resize_to[1], resize_to[0], 3), dtype=np.uint8)

    image = image.resize((new_width, new_height),
                         resample=Image.Resampling.BICUBIC)
    if pad_to_center:
        padding_left = (resize_to[0] - new_width) // 2
        padding_top = (resize_to[1] - new_height) // 2
    else:
        padding_left = 0
        padding_top = 0
    padding_right = resize_to[0] - new_width - padding_left
    padding_bottom = resize_to[1] - new_height - padding_top
    arr = np.pad(
        np.asarray(image),
        ((padding_top, padding_bottom), (padding_left, padding_right),
         (0, 0)),
        mode="constant",
        constant_values=0,
    )
    # The (h, w, 3) shape assert implies a 3-channel input image.
    assert arr.shape == (resize_to[1], resize_to[0], 3)
    return arr


def image_to_np(
    image: Image.Image,
    resize_to: tuple[int, int] | None = None,
    mode: str = "resize",
    raise_error_for_ill_resize: bool = True,
) -> np.ndarray:
    """Convert an image to a numpy array, optionally resizing it first.

    The two padding modes previously duplicated ~25 lines each; they now
    share the `_rescale_and_pad` helper and differ only in pad placement.

    Args:
        image: The image to convert. The padding modes assume a 3-channel
            image (enforced by the helper's shape assert).
        resize_to: Target (width, height); None leaves the size unchanged.
        mode: "resize" (plain bicubic resize), "rescale_and_pad_to_center",
            or "rescale_and_pad_to_rightbottom" (aspect-preserving downscale
            plus zero padding).
        raise_error_for_ill_resize: Whether to raise when the rescale would
            collapse a side to zero pixels (otherwise an all-black canvas of
            the target size is returned).

    Returns:
        A numpy array of shape (height, width, channels).

    Raises:
        ValueError: For an unknown mode, or an ill-sized rescale when
            `raise_error_for_ill_resize` is set.
    """
    assert isinstance(image, Image.Image), "image must be a PIL Image"
    if resize_to is not None:
        if mode == "resize":
            image = image.resize(resize_to, resample=Image.Resampling.BICUBIC)
        elif mode == "rescale_and_pad_to_center":
            image = _rescale_and_pad(image, resize_to, True,
                                     raise_error_for_ill_resize)
        elif mode == "rescale_and_pad_to_rightbottom":
            image = _rescale_and_pad(image, resize_to, False,
                                     raise_error_for_ill_resize)
        else:
            raise ValueError(f"Invalid mode: {mode}")

    # The padding paths already produced an ndarray; the plain-resize path
    # (and the no-resize path) still hold a PIL image.
    if isinstance(image, Image.Image):
        return np.asarray(image)
    return image
305
+
306
+
307
def navit_patchify(pixel_values: np.ndarray,
                   patch_size: int) -> dict[str, np.ndarray]:
    """Split frames into non-overlapping square patches (NaViT layout).

    Args:
        pixel_values: np.ndarray, shape (t, h, w, c) with c == 3; h and w
            must be divisible by `patch_size`.
        patch_size: Side length of each square patch.

    Returns:
        dict[str, np.ndarray]
        - "pixel_values": shape (t * h//patch_size * w//patch_size, c,
          patch_size, patch_size)
        - "grid_thw": np.ndarray [t, h//patch_size, w//patch_size]
    """
    T, H, W, C = pixel_values.shape
    assert C == 3, "pixel_values must have 3 channels"
    # Explicit divisibility check: a clearer failure than the cryptic
    # ValueError the reshape below would otherwise raise.
    assert H % patch_size == 0 and W % patch_size == 0, (
        f"H ({H}) and W ({W}) must be divisible by patch_size ({patch_size})")

    patches = pixel_values.reshape(T, H // patch_size, patch_size,
                                   W // patch_size, patch_size, C)
    # -> (T, H//patch_size, W//patch_size, C, patch_size, patch_size)
    patches = patches.transpose(0, 1, 3, 5, 2, 4)
    patches = patches.reshape(-1, C, patch_size, patch_size)
    grid_thw = np.array([T, H // patch_size, W // patch_size])
    return {"pixel_values": patches, "grid_thw": grid_thw}
330
+
331
+
332
def normalize(x: np.ndarray,
              mean,
              std_inv,
              pixels_dtype: np.dtype = np.float32) -> np.ndarray:
    """Scale uint8 pixels to [0, 1], then shift and scale them.

    Args:
        x: Pixel array of shape (..., 3), dtype uint8, values in [0, 255].
        mean: Mean (scalar or broadcastable) subtracted after the /255 step.
        std_inv: Reciprocal of the std, applied multiplicatively.
        pixels_dtype: Floating dtype of the result.

    Returns:
        Normalized array of shape (..., 3) in `pixels_dtype`.
    """
    # x/255.0 always allocates a fresh float array, so the in-place ops
    # below never touch the caller's buffer.
    scaled = np.asarray(x / 255.0, dtype=pixels_dtype)
    scaled -= mean
    scaled *= std_inv
    return scaled
350
+
351
+
352
+ def _to_tensor(data, **kwargs):
353
+ import torch
354
+
355
+ if isinstance(data, np.ndarray):
356
+ return torch.from_numpy(data).to(**kwargs)
357
+ elif isinstance(data, torch.Tensor):
358
+ return data.to(**kwargs)
359
+ elif isinstance(data, list):
360
+ return [_to_tensor(item, **kwargs) for item in data]
361
+ elif isinstance(data, tuple):
362
+ return tuple(_to_tensor(item, **kwargs) for item in data)
363
+ elif isinstance(data, dict):
364
+ return {k: _to_tensor(v, **kwargs) for k, v in data.items()}
365
+ elif data is None:
366
+ return None
367
+ else:
368
+ raise ValueError(f"Unsupported data type: {type(data)}")