Zeyue7 committed on
Commit
6a4a4ae
·
1 Parent(s): df69be2

VidMuse_CVPR

Browse files
Files changed (3) hide show
  1. model.py +17 -0
  2. modeling_vidmuse.py +14 -13
  3. state_dict.bin +3 -0
model.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model.py
2
+ from transformers import PreTrainedModel
3
+ import torch
4
+ from audiocraft.models import VidMuse
5
+ from einops import rearrange
6
+
7
+ class VidMuseModel(PreTrainedModel):
8
+ def __init__(self, config):
9
+ super().__init__(config)
10
+ self.model = VidMuse.get_pretrained(config.model_path) # 加载你已有的预训练模型
11
+
12
+ def forward(self, video_input, **gen_kwargs):
13
+ # 获取视频的本地帧和全局帧
14
+ local_video_tensor, global_video_tensor = video_input
15
+ # 使用 VidMuse 生成音频
16
+ outputs = self.model.generate([local_video_tensor, global_video_tensor], **gen_kwargs)
17
+ return outputs
modeling_vidmuse.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  import os
5
  from huggingface_hub import hf_hub_download
6
  import os
7
-
8
 
9
  # 注册自定义配置和模型(关键步骤!)
10
  class VidMuseConfig(PretrainedConfig):
@@ -26,24 +26,25 @@ class VidMuseModel(PreTrainedModel):
26
  )
27
 
28
  self.compression_model = self._load_submodel(config.compression_model)
29
-
30
 
31
  def _load_submodel(self, relative_path):
32
  full_path = os.path.join(self.hub_cache_dir, relative_path)
33
  return torch.load(full_path)
34
 
35
- @classmethod
36
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
37
- # 主动下载附加文件
38
- hf_hub_download(
39
- repo_id=pretrained_model_name_or_path,
40
- filename="compression_state_dict.bin",
41
- force_download=True,
42
- cache_dir=kwargs.get("cache_dir", None)
43
- )
 
 
44
 
45
- # 继续正常加载流程
46
- return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
47
 
48
  # 注册到Auto框架(必须放在类定义之后!)
49
  AutoConfig.register("vidmuse", VidMuseConfig)
 
4
  import os
5
  from huggingface_hub import hf_hub_download
6
  import os
7
+ from huggingface_hub import snapshot_download
8
 
9
  # 注册自定义配置和模型(关键步骤!)
10
  class VidMuseConfig(PretrainedConfig):
 
26
  )
27
 
28
  self.compression_model = self._load_submodel(config.compression_model)
29
+ # import pdb; pdb.set_trace()
30
 
31
  def _load_submodel(self, relative_path):
32
  full_path = os.path.join(self.hub_cache_dir, relative_path)
33
  return torch.load(full_path)
34
 
35
+ # @classmethod
36
+ # def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
37
+ # # 主动下载附加文件
38
+ # hf_hub_download(
39
+ # repo_id=pretrained_model_name_or_path,
40
+ # filename="compression_state_dict.bin",
41
+ # force_download=True,
42
+ # cache_dir=kwargs.get("cache_dir", None)
43
+ # )
44
+
45
+ # # 继续正常加载流程
46
 
47
+ # return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
 
48
 
49
  # 注册到Auto框架(必须放在类定义之后!)
50
  AutoConfig.register("vidmuse", VidMuseConfig)
state_dict.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ef70e83c661434c931e6147a35402556bd79f6d0d3d8527205f5ce1ccd26262
3
+ size 7872328846