litwell committed on
Commit
d0932f1
·
verified ·
1 Parent(s): 37ec378

Upload models/test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/test.py +102 -0
models/test.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import cv2
4
+
5
+
6
def count_frames(video_path):
    """Count the frames of a video by reading it to the end with OpenCV.

    Frames are read one by one, so the result is the number of frames that
    `cv2.VideoCapture.read` actually returned, not a value taken from the
    container metadata.

    Args:
        video_path: the video source handed to `cv2.VideoCapture`.

    Returns:
        int: the number of frames successfully read. This is 0 when the very
        first read already fails (e.g. the source could not be opened).
    """
    video = cv2.VideoCapture(video_path)
    actual_frame_count = 0
    try:
        while True:
            ret, _ = video.read()
            if not ret:
                break
            actual_frame_count += 1
    finally:
        # Release the capture handle even if reading raises midway.
        video.release()
    return actual_frame_count
21
+
22
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Work out how many frames of a video are fed to the model.

    Args:
        ele (dict): the video configuration. It may carry either `nframes`
            or `fps`, never both:
            - nframes: requested frame count; it is rounded with
              `round_by_factor(..., FRAME_FACTOR)`.
            - fps: sampling rate used when `nframes` is absent
              (defaults to `FPS`).
            - min_frames: lower clamp for the fps-derived count
              (defaults to `FPS_MIN_FRAMES`), ignored when `nframes` is given.
            - max_frames: upper clamp for the fps-derived count
              (defaults to `min(FPS_MAX_FRAMES, total_frames)`), ignored when
              `nframes` is given.
        total_frames (int): number of frames in the source video.
        video_fps (int | float): frame rate of the source video.

    Raises:
        AssertionError: both `fps` and `nframes` are present in `ele`.
        ValueError: the resulting count is outside
            [FRAME_FACTOR, total_frames].

    Returns:
        int: the number of frames to sample.
    """
    assert "fps" not in ele or "nframes" not in ele, "Only accept either `fps` or `nframes`"
    if "nframes" not in ele:
        target_fps = ele.get("fps", FPS)
        lower_bound = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        upper_bound = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
        # Video duration (total_frames / video_fps) times the requested rate.
        nframes = total_frames / video_fps * target_fps
        if nframes > total_frames:
            logger.warning(f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
        # Clamp step by step: lower bound, upper bound, then the video length.
        nframes = max(nframes, lower_bound)
        nframes = min(nframes, upper_bound)
        nframes = min(nframes, total_frames)
        nframes = floor_by_factor(nframes, FRAME_FACTOR)
    else:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes <= total_frames):
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes
60
+
61
def _read_video_decord(
    ele: dict,
) -> tuple[torch.Tensor, float]:
    """Read a video with decord.VideoReader and sample frames evenly from it.

    Args:
        ele (dict): the video configuration. Keys read here:
            - video: the video path, passed unchanged to
              `decord.VideoReader` and `count_frames`.
            - video_start / video_end: not supported by this reader; the
              presence of either key raises `NotImplementedError`.
            The dict is also forwarded to `smart_nframes`, which reads the
            frame-sampling keys.

    Returns:
        tuple[torch.Tensor, float]: the sampled frames permuted to
        (T, C, H, W) order, and the effective frame rate of the sampled
        clip (`nframes / total_frames * video_fps`).

    Raises:
        NotImplementedError: `video_start` or `video_end` is present in `ele`.
    """
    # Reject unsupported options before opening or decoding anything.
    # TODO: support start_pts and end_pts
    if 'video_start' in ele or 'video_end' in ele:
        raise NotImplementedError("not support start_pts and end_pts in decord for now.")

    import decord
    video_path = ele["video"]
    vr = decord.VideoReader(video_path)

    video_fps = vr.get_avg_fps()
    # Use the number of frames that can really be decoded (counted with
    # OpenCV) instead of trusting the reader's reported length, but never go
    # past len(vr): the indices below are used to index `vr`.
    total_frames = min(len(vr), count_frames(video_path))

    nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
    # Evenly spaced frame indices covering the first to the last usable frame.
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    video = vr.get_batch(idx).asnumpy()
    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video, sample_fps
94
+
95
+
96
# Manual smoke test: decode one hard-coded clip and print the shape of the
# sampled tensor followed by the effective sampling rate.
# NOTE(review): this runs at import time and the path is machine-specific;
# consider moving it under `if __name__ == "__main__":`.
ele_example = dict(
    video="/home/world_model/egoexo4d/keystep_train_takes-cut/georgiatech_cooking_14_02_2/aria02_214-1_0000030.mp4",
)

video, sample_fps = _read_video_decord(ele_example)
print(video.shape, sample_fps, sep="\n")