Zeyue7 commited on
Commit
e82974e
·
1 Parent(s): ffa5ac7

VidMuse_CVPR

Browse files
config.json DELETED
@@ -1,8 +0,0 @@
1
- {
2
- "model_type": "simple_processor",
3
- "message": "Hello from SimpleProcessor!",
4
- "auto_map": {
5
- "AutoConfig": "processor.VidMuseConfig",
6
- "AutoProcessor": "processor.VidMuseProcessor"
7
- }
8
- }
 
 
 
 
 
 
 
 
 
modeling_vidmuse.py DELETED
@@ -1,27 +0,0 @@
1
# modeling_vidmuse.py
"""Custom Transformers config/model wrapper for VidMuse.

Defines ``VidMuseConfig`` / ``VidMuseModel`` and registers them with the
``Auto*`` factories so ``AutoConfig.from_pretrained`` / ``AutoModel.from_pretrained``
can resolve the custom ``"vidmuse"`` model_type.
"""
import os  # FIX: os.path.* is used below but os was never imported (NameError)

import torch
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class VidMuseConfig(PretrainedConfig):
    """Configuration carrying the relative path to the compression sub-model."""

    model_type = "vidmuse"

    def __init__(self, compression_model=None, **kwargs):
        super().__init__(**kwargs)
        # Path to the compression sub-model checkpoint, relative to the
        # directory containing this source file; may be None when unset.
        self.compression_model = compression_model


class VidMuseModel(PreTrainedModel):
    """Model wrapper that loads its compression sub-model from disk."""

    config_class = VidMuseConfig  # explicitly tie this model to its config class

    def __init__(self, config):
        super().__init__(config)  # PreTrainedModel initialization is mandatory
        # Resolve sub-model paths relative to this source file's directory.
        self.model_dir = os.path.dirname(os.path.abspath(__file__))
        self.compression_model = self._load_submodel(config.compression_model)

    def _load_submodel(self, relative_path):
        """Load the checkpoint at *relative_path* under ``self.model_dir``.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        checkpoints from trusted sources (consider ``weights_only=True``).
        """
        full_path = os.path.join(self.model_dir, relative_path)
        return torch.load(full_path)


# Register with the Auto* factories (must run after the class definitions).
AutoConfig.register("vidmuse", VidMuseConfig)
AutoModel.register(VidMuseConfig, VidMuseModel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_vidmuse_back.py DELETED
@@ -1,51 +0,0 @@
1
# modeling_vidmuse_back.py — backup variant of modeling_vidmuse.py that
# resolves sub-model files from the Hugging Face Hub snapshot cache instead
# of the local source directory.
"""Custom Transformers config/model wrapper for VidMuse (Hub-cache variant)."""
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
import torch
import os
from huggingface_hub import hf_hub_download  # retained: earlier manual-download path used it
from huggingface_hub import snapshot_download


class VidMuseConfig(PretrainedConfig):
    """Configuration carrying the relative path to the compression sub-model."""

    model_type = "vidmuse"

    def __init__(self, compression_model=None, **kwargs):
        super().__init__(**kwargs)
        # Path to the compression sub-model checkpoint, relative to the
        # downloaded repo snapshot; may be None when unset.
        self.compression_model = compression_model


class VidMuseModel(PreTrainedModel):
    """Model wrapper that pulls its compression sub-model from the Hub cache."""

    config_class = VidMuseConfig  # explicitly tie this model to its config class

    def __init__(self, config):
        super().__init__(config)  # PreTrainedModel initialization is mandatory
        # Download (or reuse from cache) the repo snapshot pinned to the
        # config's commit hash.
        # NOTE(review): ``config._commit_hash`` is a private transformers
        # attribute and may be absent or renamed across versions — confirm.
        self.hub_cache_dir = snapshot_download(
            repo_id="Zeyue7/VidMuse",
            revision=config._commit_hash,
        )
        self.compression_model = self._load_submodel(config.compression_model)

    def _load_submodel(self, relative_path):
        """Load the checkpoint at *relative_path* inside the snapshot dir.

        NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        checkpoints from trusted sources (consider ``weights_only=True``).
        """
        full_path = os.path.join(self.hub_cache_dir, relative_path)
        return torch.load(full_path)


# Register with the Auto* factories (must run after the class definitions).
AutoConfig.register("vidmuse", VidMuseConfig)
AutoModel.register(VidMuseConfig, VidMuseModel)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
video_processor.py CHANGED
@@ -24,13 +24,14 @@ class VideoProcessor:
24
  target_duration = duration * target_fps
25
 
26
  if current_duration > target_duration:
27
- video_tensor = video_tensor[:, :target_duration]
28
  elif current_duration < target_duration:
29
  last_frame = video_tensor[:, -1:]
30
- repeat_times = target_duration - current_duration
31
  video_tensor = torch.cat((video_tensor, last_frame.repeat(1, repeat_times, 1, 1)), dim=1)
32
  return video_tensor
33
 
 
34
  def video_read_global(self, filepath, seek_time=0., duration=-1, target_fps=2, global_mode='average', global_num_frames=32):
35
  vr = VideoReader(filepath, ctx=cpu(0))
36
  fps = vr.get_avg_fps()
@@ -40,7 +41,7 @@ class VideoProcessor:
40
  total_frames_to_read = target_fps * duration
41
  frame_interval = int(math.ceil(fps / target_fps))
42
  start_frame = int(seek_time * fps)
43
- end_frame = start_frame + frame_interval * total_frames_to_read
44
  frame_ids = list(range(start_frame, min(end_frame, frame_count), frame_interval))
45
  else:
46
  frame_ids = list(range(0, frame_count, int(math.ceil(fps / target_fps))))
@@ -53,8 +54,9 @@ class VideoProcessor:
53
  local_video_tensor = einops.rearrange(local_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
54
  local_video_tensor = self.adjust_video_duration(local_video_tensor, duration, target_fps)
55
 
56
- if global_mode == 'average':
57
  global_frame_ids = torch.linspace(0, frame_count - 1, global_num_frames).long()
 
58
  global_frames = vr.get_batch(global_frame_ids)
59
  global_frames = torch.from_numpy(global_frames.asnumpy()).permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
60
 
@@ -62,8 +64,11 @@ class VideoProcessor:
62
  global_video_tensor = torch.stack(global_frames)
63
  global_video_tensor = einops.rearrange(global_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
64
 
 
65
  return local_video_tensor, global_video_tensor
66
 
 
 
67
  def process(self, video_path, target_fps=2, global_mode='average', global_num_frames=32):
68
  duration = self.get_video_duration(video_path)
69
  if duration is None:
 
24
  target_duration = duration * target_fps
25
 
26
  if current_duration > target_duration:
27
+ video_tensor = video_tensor[:, :int(target_duration)]
28
  elif current_duration < target_duration:
29
  last_frame = video_tensor[:, -1:]
30
+ repeat_times = int(target_duration - current_duration)
31
  video_tensor = torch.cat((video_tensor, last_frame.repeat(1, repeat_times, 1, 1)), dim=1)
32
  return video_tensor
33
 
34
+
35
  def video_read_global(self, filepath, seek_time=0., duration=-1, target_fps=2, global_mode='average', global_num_frames=32):
36
  vr = VideoReader(filepath, ctx=cpu(0))
37
  fps = vr.get_avg_fps()
 
41
  total_frames_to_read = target_fps * duration
42
  frame_interval = int(math.ceil(fps / target_fps))
43
  start_frame = int(seek_time * fps)
44
+ end_frame = int(start_frame + frame_interval * total_frames_to_read)
45
  frame_ids = list(range(start_frame, min(end_frame, frame_count), frame_interval))
46
  else:
47
  frame_ids = list(range(0, frame_count, int(math.ceil(fps / target_fps))))
 
54
  local_video_tensor = einops.rearrange(local_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
55
  local_video_tensor = self.adjust_video_duration(local_video_tensor, duration, target_fps)
56
 
57
+ if global_mode=='average':
58
  global_frame_ids = torch.linspace(0, frame_count - 1, global_num_frames).long()
59
+
60
  global_frames = vr.get_batch(global_frame_ids)
61
  global_frames = torch.from_numpy(global_frames.asnumpy()).permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
62
 
 
64
  global_video_tensor = torch.stack(global_frames)
65
  global_video_tensor = einops.rearrange(global_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
66
 
67
+ assert global_video_tensor.shape[1] == global_num_frames, f"the shape of global_video_tensor is {global_video_tensor.shape}"
68
  return local_video_tensor, global_video_tensor
69
 
70
+
71
+
72
  def process(self, video_path, target_fps=2, global_mode='average', global_num_frames=32):
73
  duration = self.get_video_duration(video_path)
74
  if duration is None: