# Copyright Alibaba Inc. All Rights Reserved.
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
class AudioProjModel(nn.Module):
    """Project raw audio features into the conditioning space.

    A two-layer MLP (Linear -> ReLU -> Linear) that widens to
    ``proj_dim * 2`` before narrowing back down to ``proj_dim``.
    """

    def __init__(self, audio_dim, proj_dim):
        super().__init__()
        self.audio_dim = audio_dim
        self.proj_dim = proj_dim
        hidden = proj_dim * 2
        layers = [
            nn.Linear(audio_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, proj_dim),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, audio_features):
        """Map ``audio_features`` (..., audio_dim) to (..., proj_dim)."""
        return self.projection(audio_features)
class WanCrossAttentionProcessor(nn.Module):
    """Cross-attention processor that injects audio conditioning.

    Single-head scaled dot-product attention: queries come from the
    hidden states, keys/values from (already projected) audio features.
    The attention output is added residually to the hidden states.
    """

    def __init__(self, hidden_size, cross_attention_dim, audio_proj_dim):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.audio_proj_dim = audio_proj_dim
        # Audio-branch query/key/value projections (bias-free, as is
        # conventional for attention projections).
        self.to_q_audio = nn.Linear(hidden_size, hidden_size, bias=False)
        self.to_k_audio = nn.Linear(audio_proj_dim, hidden_size, bias=False)
        self.to_v_audio = nn.Linear(audio_proj_dim, hidden_size, bias=False)
        self.scale = hidden_size ** -0.5

    def forward(self, hidden_states, audio_features=None, **kwargs):
        """Return ``hidden_states`` residually augmented with audio attention.

        When ``audio_features`` is None the input is returned unchanged.
        """
        if audio_features is None:
            return hidden_states
        # Unpack to assert the expected (batch, seq, hidden) layout.
        bsz, seq, _ = hidden_states.shape
        q = self.to_q_audio(hidden_states)
        k = self.to_k_audio(audio_features)
        v = self.to_v_audio(audio_features)
        # Scaled dot-product attention over the audio tokens.
        weights = F.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return hidden_states + weights @ v
class FantasyTalkingAudioConditionModel(nn.Module):
    """FantasyTalking audio-conditioning wrapper around a base model.

    Owns the audio projection MLP and the bookkeeping needed to swap
    the base model's attention processors in and out.
    """

    def __init__(self, base_model, audio_dim, proj_dim):
        super().__init__()
        self.base_model = base_model
        self.audio_dim = audio_dim
        self.proj_dim = proj_dim
        # Projects raw audio features (audio_dim) into the conditioning
        # space (proj_dim) consumed by the cross-attention processors.
        self.audio_proj = AudioProjModel(audio_dim, proj_dim)
        # Original attention processors, saved so that
        # disable_audio_condition() can restore them.
        self.original_processors = {}

    def load_audio_processor(self, checkpoint_path, base_model):
        """Load FantasyTalking weights from ``checkpoint_path``.

        Fixes the original stub, which printed a message but never
        actually loaded anything. Supports ``.safetensors`` files (read
        via safetensors' ``safe_open``, matching the file-level import)
        and regular torch checkpoints. Unknown/missing keys are
        tolerated (``strict=False``). ``base_model`` is kept for
        interface compatibility; only this module's weights are loaded.
        """
        if not os.path.exists(checkpoint_path):
            print(f"权重文件不存在: {checkpoint_path}")
            return
        print(f"加载FantasyTalking权重: {checkpoint_path}")
        if checkpoint_path.endswith(".safetensors"):
            state_dict = {}
            with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)
        else:
            # weights_only=True guards against arbitrary pickled code in
            # an untrusted checkpoint.
            state_dict = torch.load(
                checkpoint_path, map_location="cpu", weights_only=True
            )
        self.load_state_dict(state_dict, strict=False)

    def enable_audio_condition(self):
        """Install audio-aware attention processors on the base model.

        NOTE(review): left unimplemented in the original; the intended
        wiring into ``base_model`` is not visible from this file.
        """
        pass

    def disable_audio_condition(self):
        """Restore the processors saved in ``self.original_processors``.

        NOTE(review): left unimplemented in the original.
        """
        pass

    def forward(self, audio_features):
        """Project ``audio_features`` through the audio MLP."""
        return self.audio_proj(audio_features)