VidMuse_CVPR
Browse files
- config.json +0 -8
- modeling_vidmuse.py +0 -27
- modeling_vidmuse_back.py +0 -51
- video_processor.py +9 -4
config.json
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"model_type": "simple_processor",
|
| 3 |
-
"message": "Hello from SimpleProcessor!",
|
| 4 |
-
"auto_map": {
|
| 5 |
-
"AutoConfig": "processor.VidMuseConfig",
|
| 6 |
-
"AutoProcessor": "processor.VidMuseProcessor"
|
| 7 |
-
}
|
| 8 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modeling_vidmuse.py
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
# modeling_vidmuse.py
|
| 2 |
-
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
# 注册自定义配置和模型(关键步骤!)
|
| 6 |
-
class VidMuseConfig(PretrainedConfig):
|
| 7 |
-
model_type = "vidmuse"
|
| 8 |
-
|
| 9 |
-
def __init__(self, compression_model=None, **kwargs):
|
| 10 |
-
super().__init__(**kwargs)
|
| 11 |
-
self.compression_model = compression_model
|
| 12 |
-
|
| 13 |
-
class VidMuseModel(PreTrainedModel):
|
| 14 |
-
config_class = VidMuseConfig # 明确指定关联的配置类
|
| 15 |
-
|
| 16 |
-
def __init__(self, config):
|
| 17 |
-
super().__init__(config) # 必须调用父类初始化
|
| 18 |
-
self.model_dir = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
-
self.compression_model = self._load_submodel(config.compression_model)
|
| 20 |
-
|
| 21 |
-
def _load_submodel(self, relative_path):
|
| 22 |
-
full_path = os.path.join(self.model_dir, relative_path)
|
| 23 |
-
return torch.load(full_path)
|
| 24 |
-
|
| 25 |
-
# 注册到Auto框架(必须放在类定义之后!)
|
| 26 |
-
AutoConfig.register("vidmuse", VidMuseConfig)
|
| 27 |
-
AutoModel.register(VidMuseConfig, VidMuseModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
modeling_vidmuse_back.py
DELETED
|
@@ -1,51 +0,0 @@
|
|
| 1 |
-
# modeling_vidmuse.py
|
| 2 |
-
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
| 3 |
-
import torch
|
| 4 |
-
import os
|
| 5 |
-
from huggingface_hub import hf_hub_download
|
| 6 |
-
from huggingface_hub import snapshot_download
|
| 7 |
-
|
| 8 |
-
# 注册自定义配置和模型(关键步骤!)
|
| 9 |
-
class VidMuseConfig(PretrainedConfig):
|
| 10 |
-
model_type = "vidmuse"
|
| 11 |
-
|
| 12 |
-
def __init__(self, compression_model=None, **kwargs):
|
| 13 |
-
super().__init__(**kwargs)
|
| 14 |
-
self.compression_model = compression_model
|
| 15 |
-
|
| 16 |
-
class VidMuseModel(PreTrainedModel):
|
| 17 |
-
config_class = VidMuseConfig # 明确指定关联的配置类
|
| 18 |
-
|
| 19 |
-
def __init__(self, config):
|
| 20 |
-
super().__init__(config) # 必须调用父类初始化
|
| 21 |
-
# self.model_dir = os.path.dirname(os.path.abspath(__file__))
|
| 22 |
-
self.hub_cache_dir = snapshot_download(
|
| 23 |
-
repo_id="Zeyue7/VidMuse",
|
| 24 |
-
revision=config._commit_hash # 使用配置中的 commit hash
|
| 25 |
-
)
|
| 26 |
-
|
| 27 |
-
self.compression_model = self._load_submodel(config.compression_model)
|
| 28 |
-
# import pdb; pdb.set_trace()
|
| 29 |
-
|
| 30 |
-
def _load_submodel(self, relative_path):
|
| 31 |
-
full_path = os.path.join(self.hub_cache_dir, relative_path)
|
| 32 |
-
return torch.load(full_path)
|
| 33 |
-
|
| 34 |
-
# @classmethod
|
| 35 |
-
# def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
| 36 |
-
# # 主动下载附加文件
|
| 37 |
-
# hf_hub_download(
|
| 38 |
-
# repo_id=pretrained_model_name_or_path,
|
| 39 |
-
# filename="compression_state_dict.bin",
|
| 40 |
-
# force_download=True,
|
| 41 |
-
# cache_dir=kwargs.get("cache_dir", None)
|
| 42 |
-
# )
|
| 43 |
-
|
| 44 |
-
# # 继续正常加载流程
|
| 45 |
-
|
| 46 |
-
# return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 47 |
-
|
| 48 |
-
# 注册到Auto框架(必须放在类定义之后!)
|
| 49 |
-
AutoConfig.register("vidmuse", VidMuseConfig)
|
| 50 |
-
AutoModel.register(VidMuseConfig, VidMuseModel)
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
video_processor.py
CHANGED
|
@@ -24,13 +24,14 @@ class VideoProcessor:
|
|
| 24 |
target_duration = duration * target_fps
|
| 25 |
|
| 26 |
if current_duration > target_duration:
|
| 27 |
-
video_tensor = video_tensor[:, :target_duration]
|
| 28 |
elif current_duration < target_duration:
|
| 29 |
last_frame = video_tensor[:, -1:]
|
| 30 |
-
repeat_times = target_duration - current_duration
|
| 31 |
video_tensor = torch.cat((video_tensor, last_frame.repeat(1, repeat_times, 1, 1)), dim=1)
|
| 32 |
return video_tensor
|
| 33 |
|
|
|
|
| 34 |
def video_read_global(self, filepath, seek_time=0., duration=-1, target_fps=2, global_mode='average', global_num_frames=32):
|
| 35 |
vr = VideoReader(filepath, ctx=cpu(0))
|
| 36 |
fps = vr.get_avg_fps()
|
|
@@ -40,7 +41,7 @@ class VideoProcessor:
|
|
| 40 |
total_frames_to_read = target_fps * duration
|
| 41 |
frame_interval = int(math.ceil(fps / target_fps))
|
| 42 |
start_frame = int(seek_time * fps)
|
| 43 |
-
end_frame = start_frame + frame_interval * total_frames_to_read
|
| 44 |
frame_ids = list(range(start_frame, min(end_frame, frame_count), frame_interval))
|
| 45 |
else:
|
| 46 |
frame_ids = list(range(0, frame_count, int(math.ceil(fps / target_fps))))
|
|
@@ -53,8 +54,9 @@ class VideoProcessor:
|
|
| 53 |
local_video_tensor = einops.rearrange(local_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
|
| 54 |
local_video_tensor = self.adjust_video_duration(local_video_tensor, duration, target_fps)
|
| 55 |
|
| 56 |
-
if global_mode
|
| 57 |
global_frame_ids = torch.linspace(0, frame_count - 1, global_num_frames).long()
|
|
|
|
| 58 |
global_frames = vr.get_batch(global_frame_ids)
|
| 59 |
global_frames = torch.from_numpy(global_frames.asnumpy()).permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
|
| 60 |
|
|
@@ -62,8 +64,11 @@ class VideoProcessor:
|
|
| 62 |
global_video_tensor = torch.stack(global_frames)
|
| 63 |
global_video_tensor = einops.rearrange(global_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
|
| 64 |
|
|
|
|
| 65 |
return local_video_tensor, global_video_tensor
|
| 66 |
|
|
|
|
|
|
|
| 67 |
def process(self, video_path, target_fps=2, global_mode='average', global_num_frames=32):
|
| 68 |
duration = self.get_video_duration(video_path)
|
| 69 |
if duration is None:
|
|
|
|
| 24 |
target_duration = duration * target_fps
|
| 25 |
|
| 26 |
if current_duration > target_duration:
|
| 27 |
+
video_tensor = video_tensor[:, :int(target_duration)]
|
| 28 |
elif current_duration < target_duration:
|
| 29 |
last_frame = video_tensor[:, -1:]
|
| 30 |
+
repeat_times = int(target_duration - current_duration)
|
| 31 |
video_tensor = torch.cat((video_tensor, last_frame.repeat(1, repeat_times, 1, 1)), dim=1)
|
| 32 |
return video_tensor
|
| 33 |
|
| 34 |
+
|
| 35 |
def video_read_global(self, filepath, seek_time=0., duration=-1, target_fps=2, global_mode='average', global_num_frames=32):
|
| 36 |
vr = VideoReader(filepath, ctx=cpu(0))
|
| 37 |
fps = vr.get_avg_fps()
|
|
|
|
| 41 |
total_frames_to_read = target_fps * duration
|
| 42 |
frame_interval = int(math.ceil(fps / target_fps))
|
| 43 |
start_frame = int(seek_time * fps)
|
| 44 |
+
end_frame = int(start_frame + frame_interval * total_frames_to_read)
|
| 45 |
frame_ids = list(range(start_frame, min(end_frame, frame_count), frame_interval))
|
| 46 |
else:
|
| 47 |
frame_ids = list(range(0, frame_count, int(math.ceil(fps / target_fps))))
|
|
|
|
| 54 |
local_video_tensor = einops.rearrange(local_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
|
| 55 |
local_video_tensor = self.adjust_video_duration(local_video_tensor, duration, target_fps)
|
| 56 |
|
| 57 |
+
if global_mode=='average':
|
| 58 |
global_frame_ids = torch.linspace(0, frame_count - 1, global_num_frames).long()
|
| 59 |
+
|
| 60 |
global_frames = vr.get_batch(global_frame_ids)
|
| 61 |
global_frames = torch.from_numpy(global_frames.asnumpy()).permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
|
| 62 |
|
|
|
|
| 64 |
global_video_tensor = torch.stack(global_frames)
|
| 65 |
global_video_tensor = einops.rearrange(global_video_tensor, 't c h w -> c t h w') # [T, C, H, W] -> [C, T, H, W]
|
| 66 |
|
| 67 |
+
assert global_video_tensor.shape[1] == global_num_frames, f"the shape of global_video_tensor is {global_video_tensor.shape}"
|
| 68 |
return local_video_tensor, global_video_tensor
|
| 69 |
|
| 70 |
+
|
| 71 |
+
|
| 72 |
def process(self, video_path, target_fps=2, global_mode='average', global_num_frames=32):
|
| 73 |
duration = self.get_video_duration(video_path)
|
| 74 |
if duration is None:
|