bpiyush committed on
Commit
a0c6980
·
verified ·
1 Parent(s): 3a7d6fb

Upload tarsier/utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tarsier/utils.py +128 -0
tarsier/utils.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ import os
16
+ from PIL import Image, ImageSequence
17
+ import decord
18
+
19
+ VALID_DATA_FORMAT_STRING = "Input data must be {'.jpg', '.jpeg', '.png', '.tif'} for image; or {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv', '.gif'} for videos!"  # Printed by get_visual_type when a file extension is not recognized.
20
+
21
def sample_frame_indices(start_frame, total_frames: int, n_frames: int):
    """Pick ``n_frames`` evenly spaced frame indices from a frame window.

    Uniform sampling: whenever at least two frames are requested, the first
    and the last frame of the window are always part of the result.

    Args:
        start_frame: Index of the first frame of the window; it is added as
            an offset to every sampled position.
        total_frames: Number of frames in the window.
        n_frames: Number of indices to return. ``None`` selects every frame
            of the window.

    Returns:
        A list of frame indices. It is empty when ``n_frames`` is zero or
        negative, and contains repeated indices when ``n_frames`` exceeds
        ``total_frames``.
    """
    if n_frames is None:
        return list(range(start_frame, start_frame + total_frames))
    if n_frames == 1:
        # A single sample is the first frame of the window, not of the video.
        return [start_frame]
    last_offset = total_frames - 1
    return [
        start_frame + round(i * last_offset / (n_frames - 1))
        for i in range(n_frames)
    ]
28
+
29
def sample_video(
    video_path: str,
    n_frames: int = None,
    start_time: int = 0,
    end_time: int = -1
) -> List[Image.Image]:
    """Decode evenly spaced frames from a time window of a video file.

    Args:
        video_path: Path of the video file to open with decord.
        n_frames: Number of frames to sample. ``None`` returns every frame
            inside the window.
        start_time: Window start in seconds; values <= 0 start at the first
            frame.
        end_time: Window end in seconds; values <= 0 run to the last frame.

    Returns:
        The sampled frames as RGB PIL images.

    Raises:
        AssertionError: If ``video_path`` does not exist.
        ValueError: If the video contains no frames.
    """
    assert os.path.exists(video_path), f"File not found: {video_path}"
    vr = decord.VideoReader(video_path, num_threads=1, ctx=decord.cpu(0))
    vr.seek(0)
    total_frames = len(vr)
    if total_frames == 0:
        raise ValueError(f"Video has no frames: {video_path}")
    fps = vr.get_avg_fps()

    # Convert the requested time window to frame indices, clamped to the video.
    last_frame = total_frames - 1
    start_frame = 0
    end_frame = last_frame
    if start_time > 0:
        start_frame = min(last_frame, int(fps * start_time))
    if end_time > 0:
        # Never let the window end before it starts or past the last frame.
        end_frame = min(max(start_frame, int(fps * end_time)), last_frame)

    if n_frames is None:
        # No frame budget given: take the whole window.
        frame_indices = list(range(start_frame, end_frame + 1))
    else:
        frame_indices = sample_frame_indices(
            start_frame=start_frame,
            total_frames=end_frame - start_frame + 1,
            n_frames=n_frames,
        )

    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(f).convert('RGB') for f in frames]
58
+
59
def sample_gif(
    gif_path: str,
    n_frames:int = None,
    start_time: int = 0,
    end_time: int = -1
) -> List[Image.Image]:
    """Sample evenly spaced frames from a GIF file.

    Args:
        gif_path: Path of the GIF file.
        n_frames: Number of frames to sample. ``None`` returns every frame.
        start_time: Accepted for signature parity with ``sample_video``;
            not used, sampling always covers the whole GIF.
        end_time: Accepted for signature parity with ``sample_video``;
            not used, sampling always covers the whole GIF.

    Returns:
        The selected frames as RGB PIL images, in playback order. Each frame
        appears at most once even if its index was sampled repeatedly.

    Raises:
        AssertionError: If ``gif_path`` does not exist.
    """
    assert os.path.exists(gif_path), f"File not found: {gif_path}"

    # The context manager releases the file handle; the converted frames are
    # independent copies and stay valid after the source image is closed.
    with Image.open(gif_path) as gif:
        total_frames = gif.n_frames
        if n_frames is None:
            wanted = set(range(total_frames))
        else:
            wanted = set(sample_frame_indices(
                start_frame=0,
                total_frames=total_frames,
                n_frames=n_frames,
            ))
        return [
            frame.convert('RGB')
            for i, frame in enumerate(ImageSequence.Iterator(gif))
            if i in wanted
        ]
85
+
86
def sample_image(
    image_path: str,
    n_frames: int = None,
    start_time: int = 0,
    end_time: int = -1
) -> List[Image.Image]:
    """Load a still image as a one-element list of RGB frames.

    Args:
        image_path: Path of the image file.
        n_frames: Unused; accepted so the signature matches the video and
            GIF samplers.
        start_time: Unused; accepted for signature parity.
        end_time: Unused; accepted for signature parity.

    Returns:
        A list holding the image converted to RGB.

    Raises:
        AssertionError: If ``image_path`` does not exist.
    """
    assert os.path.exists(image_path), f"File not found: {image_path}"
    # Close the source file once the RGB copy has been made.
    with Image.open(image_path) as image:
        return [image.convert('RGB')]
95
+
96
def get_visual_type(input_file):
    """Classify a file as GIF, video or image from its extension.

    The extension is matched case-insensitively, so ``clip.MP4`` is treated
    like ``clip.mp4``.

    Args:
        input_file: File name or path whose extension is inspected.

    Returns:
        ``'gif'``, ``'video'`` or ``'image'`` for a supported extension;
        ``'unk'`` otherwise, after printing a message that lists the
        supported formats.
    """
    ext = os.path.splitext(input_file)[-1]
    normalized = ext.lower()
    if normalized == '.gif':
        return 'gif'
    if normalized in {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv'}:
        return 'video'
    if normalized in {'.jpg', '.jpeg', '.png', '.tif'}:
        return 'image'
    # Report the extension exactly as it was given by the caller.
    print(f"{VALID_DATA_FORMAT_STRING} But found {ext}!")
    return 'unk'
107
+
108
def get_benchmarks(benchmarks):
    """Expand benchmark names and group aliases into a flat, unique list.

    Names are lower-cased. A group alias (``'dream'``, ``'caption'``,
    ``'mc_qa'``, ``'oe_qa'``) expands to the benchmarks of that group, and
    any other name is kept as given. ``'all'`` adds every known benchmark
    and stops processing, so entries after it are ignored.

    Args:
        benchmarks: Iterable of benchmark or group names. A single string is
            treated as a one-element list.

    Returns:
        Benchmark names in first-seen order, without duplicates.
    """
    type2bm = {
        'dream': ['dream'],
        'caption': ['msvd-caption', 'msr-vtt-caption', 'vatex-caption'],
        'mc_qa': ['next-qa', 'egoschema', 'mvbench', 'video-mme'],
        'oe_qa': ['msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa'],
    }
    if isinstance(benchmarks, str):
        # Avoid iterating a bare string character by character.
        benchmarks = [benchmarks]

    final_benchmarks = []
    seen = set()

    def _add(names):
        # Append only names not selected yet, keeping first-seen order.
        for name in names:
            if name not in seen:
                seen.add(name)
                final_benchmarks.append(name)

    for bm in benchmarks:
        bm = bm.lower()
        if bm == 'all':
            for group in type2bm.values():
                _add(group)
            return final_benchmarks
        _add(type2bm.get(bm, [bm]))
    return final_benchmarks