litwell committed (verified)
Commit 4a0ee93 · 1 parent: 6f287f0

Upload models/src/training/qwen_vl_utils.py with huggingface_hub
Files changed (1)
  1. models/src/training/qwen_vl_utils.py +399 -0
models/src/training/qwen_vl_utils.py ADDED
@@ -0,0 +1,399 @@
from __future__ import annotations

import base64
import logging
import math
import os
import sys
import time
# import cv2
import warnings
from functools import lru_cache
from io import BytesIO

import requests
import torch
import torchvision
from packaging import version
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
from typing import Optional


logger = logging.getLogger(__name__)

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768

# Set the maximum number of video tokens.
# Here, 128K is the maximum number of input tokens for the vLLM model.
# Remember to adjust it according to your own configuration.
VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
logger.info(f"set VIDEO_TOTAL_PIXELS: {VIDEO_TOTAL_PIXELS}")
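
# Illustrative override (hypothetical script name): the value above is read
# once at import time, so set the environment variable before this module is
# imported, e.g. for a ~32K-token budget:
#   VIDEO_MAX_PIXELS=$((32000 * 28 * 28)) python train.py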


# def count_frames(video_path):
#     # Open the video file
#     video = cv2.VideoCapture(video_path)

#     # Count the number of frames actually read
#     actual_frame_count = 0
#     while True:
#         ret, frame = video.read()
#         if not ret:
#             break
#         actual_frame_count += 1

#     # Release the video object
#     video.release()
#     return actual_frame_count

def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor
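
# Example: with factor=28, round_by_factor(100, 28) == 112 (100/28 ≈ 3.57
# rounds to 4 multiples), ceil_by_factor(100, 28) == 112, and
# floor_by_factor(100, 28) == 84.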


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
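
# Worked example: smart_resize(1000, 750) returns (1008, 756), since
# round(1000/28)*28 = 1008, round(750/28)*28 = 756, and 1008*756 pixels
# already falls inside [MIN_PIXELS, MAX_PIXELS], so no beta rescaling
# is applied.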


def to_rgb(pil_image: Image.Image) -> Image.Image:
    if pil_image.mode == 'RGBA':
        white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
        white_background.paste(pil_image, mask=pil_image.split()[3])  # Use alpha channel as mask
        return white_background
    else:
        return pil_image.convert("RGB")
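
# Note: RGBA inputs are composited onto white rather than passed through a
# bare convert("RGB"), which would discard the alpha channel and typically
# render transparent regions as black.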


def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    if "image" in ele:
        image = ele["image"]
    else:
        image = ele["image_url"]
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        response = requests.get(image, stream=True)
        image_obj = Image.open(BytesIO(response.content))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input; supported inputs are local path, http url, base64 and PIL.Image, got {image}")
    image = to_rgb(image_obj)
    ## resize
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    image = image.resize((resized_width, resized_height))

    return image
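
# Usage sketch (hypothetical paths): each call returns an RGB PIL image whose
# sides are multiples of IMAGE_FACTOR:
#   fetch_image({"image": "file:///tmp/cat.png"})
#   fetch_image({"image_url": "https://example.com/cat.jpg", "max_pixels": 512 * 28 * 28})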


def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Calculate the number of video frames used for model inputs.

    Args:
        ele (dict): a dict containing the video configuration.
            Supports either `fps` or `nframes`:
                - nframes: the number of frames to extract for model inputs.
                - fps: the fps at which to extract frames for model inputs.
                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
        total_frames (int): the original total number of frames of the video.
        video_fps (int | float): the original fps of the video.

    Raises:
        ValueError: if nframes is not in the interval [FRAME_FACTOR, total_frames].

    Returns:
        int: the number of video frames used for model inputs.
    """
    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        fps = ele.get("fps", FPS)
        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
        nframes = total_frames / video_fps * fps
        if nframes > total_frames:
            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
        nframes = min(max(nframes, min_frames), max_frames, total_frames)
        nframes = floor_by_factor(nframes, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes <= total_frames):
        raise ValueError(f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes
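
# Example: a 10 s clip at 30 fps (total_frames=300, video_fps=30) with the
# default fps=2.0 gives 300 / 30 * 2 = 20 candidate frames; 20 already lies in
# [FPS_MIN_FRAMES, min(FPS_MAX_FRAMES, 300)] and is divisible by FRAME_FACTOR,
# so smart_nframes returns 20.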


def _read_video_torchvision(
    ele: dict,
) -> tuple[torch.Tensor, float]:
    """Read video using torchvision.io.read_video.

    Args:
        ele (dict): a dict containing the video configuration.
            Supported keys:
                - video: the path of the video. Supports "file://", "http://", "https://" and local paths.
                - video_start: the start time of the video.
                - video_end: the end time of the video.
    Returns:
        torch.Tensor: the video tensor with shape (T, C, H, W).
        float: the effective sample fps of the returned frames.
    """
    video_path = ele["video"]
    if version.parse(torchvision.__version__) < version.parse("0.19.0"):
        if "http://" in video_path or "https://" in video_path:
            warnings.warn("torchvision < 0.19.0 does not support http/https video paths, please upgrade to 0.19.0.")
        if "file://" in video_path:
            video_path = video_path[7:]
    st = time.time()
    video, audio, info = io.read_video(
        video_path,
        start_pts=ele.get("video_start", 0.0),
        end_pts=ele.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    # actual_frame_count = count_frames(video_path)
    total_frames, video_fps = video.size(0), info["video_fps"]
    # total_frames = actual_frame_count
    logger.info(f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    video = video[idx]
    return video, sample_fps


def is_decord_available() -> bool:
    import importlib.util

    return importlib.util.find_spec("decord") is not None


def _read_video_decord(
    ele: dict,
) -> tuple[torch.Tensor, float]:
    """Read video using decord.VideoReader.

    Args:
        ele (dict): a dict containing the video configuration.
            Supported keys:
                - video: the path of the video. Supports "file://", "http://", "https://" and local paths.
                - video_start: the start time of the video.
                - video_end: the end time of the video.
    Returns:
        torch.Tensor: the video tensor with shape (T, C, H, W).
        float: the effective sample fps of the returned frames.
    """
    import decord
    video_path = ele["video"]
    st = time.time()
    vr = decord.VideoReader(video_path)
    # TODO: support start_pts and end_pts
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("start_pts and end_pts are not supported by the decord backend yet.")

    # actual_frame_count = count_frames(video_path)
    total_frames, video_fps = len(vr), vr.get_avg_fps()
    # total_frames = actual_frame_count
    logger.info(f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s")
    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video, sample_fps


VIDEO_READER_BACKENDS = {
    "decord": _read_video_decord,
    "torchvision": _read_video_torchvision,
}

FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)


@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
    if FORCE_QWENVL_VIDEO_READER is not None:
        video_reader_backend = FORCE_QWENVL_VIDEO_READER
    elif is_decord_available():
        video_reader_backend = "decord"
    else:
        video_reader_backend = "torchvision"
    print(f"qwen-vl-utils using {video_reader_backend} to read video.", file=sys.stderr)
    return video_reader_backend
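
# Illustrative: set FORCE_QWENVL_VIDEO_READER=torchvision (or decord) in the
# environment before importing this module to override the automatic choice;
# the variable is read once at import time.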


def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False) -> torch.Tensor | list[Image.Image]:
    if isinstance(ele["video"], str):
        video_reader_backend = get_video_reader_backend()
        try:
            video, sample_fps = VIDEO_READER_BACKENDS[video_reader_backend](ele)
        except Exception as e:
            logger.warning(f"video_reader_backend {video_reader_backend} error, use torchvision as default, msg: {e}")
            video, sample_fps = VIDEO_READER_BACKENDS["torchvision"](ele)

        nframes, _, height, width = video.shape
        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
        max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
        max_pixels_supposed = ele.get("max_pixels", max_pixels)
        if max_pixels_supposed > max_pixels:
            logger.warning(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
        max_pixels = min(max_pixels_supposed, max_pixels)
        if "resized_height" in ele and "resized_width" in ele:
            resized_height, resized_width = smart_resize(
                ele["resized_height"],
                ele["resized_width"],
                factor=image_factor,
            )
        else:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=image_factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
        video = transforms.functional.resize(
            video,
            [resized_height, resized_width],
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ).float()
        if return_video_sample_fps:
            return video, sample_fps
        return video
    else:
        assert isinstance(ele["video"], (list, tuple))
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = [
            fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
            for video_element in ele["video"]
        ]
        nframes = ceil_by_factor(len(images), FRAME_FACTOR)
        if len(images) < nframes:
            images.extend([images[-1]] * (nframes - len(images)))
        if return_video_sample_fps:
            return images, process_info.pop("fps", 2.0)
        return images
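
# Usage sketch (hypothetical path): returns a float tensor of shape
# (T, C, H, W) plus the effective sampling fps:
#   video, sample_fps = fetch_video(
#       {"video": "file:///tmp/clip.mp4", "fps": 2.0},
#       return_video_sample_fps=True,
#   )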


def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
    vision_infos = []
    if isinstance(conversations[0], dict):
        conversations = [conversations]
    for conversation in conversations:
        for message in conversation:
            if isinstance(message["content"], list):
                for ele in message["content"]:
                    if (
                        "image" in ele
                        or "image_url" in ele
                        or "video" in ele
                        or ele["type"] in ("image", "image_url", "video")
                    ):
                        vision_infos.append(ele)
    return vision_infos


def process_vision_info(
    conversations: list[dict] | list[list[dict]],
    return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:

    vision_infos = extract_vision_info(conversations)
    ## Read images or videos
    image_inputs = []
    video_inputs = []
    video_sample_fps_list = []
    for vision_info in vision_infos:
        if "image" in vision_info or "image_url" in vision_info:
            image_inputs.append(fetch_image(vision_info))
        elif "video" in vision_info:
            video_input, video_sample_fps = fetch_video(vision_info, return_video_sample_fps=True)
            video_sample_fps_list.append(video_sample_fps)
            video_inputs.append(video_input)
        else:
            raise ValueError("image, image_url or video should be in content.")
    if len(image_inputs) == 0:
        image_inputs = None
    if len(video_inputs) == 0:
        video_inputs = None
    if return_video_kwargs:
        return image_inputs, video_inputs, {'fps': video_sample_fps_list}
    return image_inputs, video_inputs
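

# Usage sketch (hypothetical content): collect the decoded visual inputs from
# a chat-style conversation before handing them to the processor:
#   messages = [{"role": "user", "content": [
#       {"type": "image", "image": "file:///tmp/cat.png"},
#       {"type": "text", "text": "Describe this image."},
#   ]}]
#   image_inputs, video_inputs = process_vision_info(messages)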