import os
import math
import pickle
from types import SimpleNamespace

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from models.layers.layer import BasicBlock
from models.wavlm.WavLM import WavLM, WavLMConfig
class ExactLengthAdjuster(nn.Module):
    """
    Ensures the output has exactly `target_length` frames along the time
    dimension, padding with copies of the last frame or truncating as needed.
    """
    def __init__(self, target_length=196):
        super().__init__()
        self.target_length = target_length

    def forward(self, x):
        # x is expected to be [batch, channels, time]
        current_length = x.shape[2]
        if current_length == self.target_length:
            return x
        elif current_length < self.target_length:
            # Too short: duplicate the last frame as many times as needed.
            frames_to_add = self.target_length - current_length
            last_frame = x[:, :, -1:]
            extra_frames = last_frame.repeat(1, 1, frames_to_add)
            return torch.cat([x, extra_frames], dim=2)
        else:
            # Too long: truncate to the target length.
            return x[:, :, :self.target_length]
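
# Minimal usage sketch for ExactLengthAdjuster (illustrative, not part of the
# model). A short clip is right-padded by repeating its last frame; a long
# clip is truncated:
#
#   adjuster = ExactLengthAdjuster(target_length=4)
#   short = torch.randn(1, 2, 3)               # [B, C, T] with T=3
#   long = torch.randn(1, 2, 7)                # [B, C, T] with T=7
#   assert adjuster(short).shape == (1, 2, 4)  # last frame repeated once
#   assert adjuster(long).shape == (1, 2, 4)   # truncated to 4 frames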
class WavEncoder(nn.Module):
    def __init__(self, out_dim, audio_in=2, target_length=256):
        super().__init__()
        self.out_dim = out_dim
        self.feat_extractor = nn.Sequential(
            BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1700, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True),
        )
        self.length_adjuster = ExactLengthAdjuster(target_length=target_length)
    def forward(self, wav_data):
        if wav_data.dim() == 2:
            # [B, T] mono waveform -> [B, 1, T]
            wav_data = wav_data.unsqueeze(1)
        else:
            # [B, T, C] -> [B, C, T]
            wav_data = wav_data.transpose(1, 2)
        out = self.feat_extractor(wav_data)  # [B, out_dim, T']
        out = self.length_adjuster(out)      # [B, out_dim, target_length]
        return out.transpose(1, 2)           # [B, target_length, out_dim]
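
# Shape-flow sketch for WavEncoder (illustrative; assumes BasicBlock's fourth
# positional argument acts as a temporal stride, so the stack downsamples by
# 5 * 6 * 6 * 3 = 540, i.e. 16 kHz audio -> roughly 30 feature frames per
# second, before ExactLengthAdjuster pins the length):
#
#   enc = WavEncoder(out_dim=256, audio_in=1, target_length=256)
#   wav = torch.randn(2, 16000 * 8)        # [B, samples]: 8 s of 16 kHz mono
#   feats = enc(wav)
#   assert feats.shape == (2, 256, 256)    # [B, target_length, out_dim]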
class ModalityEncoder(nn.Module):
    def __init__(self,
                 data_path,
                 t_fix_pre,
                 audio_dim,
                 audio_in=2,
                 raw_audio=False,
                 latent_dim=256,
                 audio_fps=30,
                 use_exp=False,
                 target_length=256,
                 spatial_temporal=False
                 ):
        super().__init__()
        self.raw_audio = raw_audio
        self.latent_dim = latent_dim
        self.audio_fps = audio_fps
        self.WavEncoder = WavEncoder(audio_dim, audio_in=audio_in, target_length=target_length)
        self.text_encoder_body = nn.Linear(300, audio_dim)

        vocab_path = f"{data_path}weights/vocab.pkl"
        if os.path.exists(vocab_path):
            with open(vocab_path, 'rb') as f:
                self.lang_model = pickle.load(f)
            pre_trained_embedding = self.lang_model.word_embedding_weights
        else:
            logger.warning(f"vocab.pkl not found at {vocab_path}, using zeroed fallback embedding")
            fallback_weights = np.zeros((2, 300), dtype=np.float32)
            self.lang_model = SimpleNamespace(
                PAD_token=0,
                UNK_token=1,
                word_embedding_weights=fallback_weights,
            )
            pre_trained_embedding = fallback_weights
        self.text_pre_encoder_body = nn.Embedding.from_pretrained(
            torch.FloatTensor(pre_trained_embedding), freeze=t_fix_pre)
        word_dim = pre_trained_embedding.shape[1]

        self.audio_encoder = None  # populated by load_and_freeze_wavlm()
        if self.raw_audio:
            # The pre-trained WavLM model can be loaded here if raw-audio
            # features are to be computed inside this module:
            # self.load_and_freeze_wavlm()
            self.audio_projection = nn.Linear(1024, audio_dim)

        joint_multiplier = 4 if use_exp else 3
        self.context_dim = self.latent_dim * joint_multiplier
        mix_input_dim = audio_dim * 3 if self.raw_audio else audio_dim * 2
        self.mix_audio_text = nn.Linear(mix_input_dim, self.context_dim)
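
    # Dimension bookkeeping (illustrative): with use_exp=False and
    # raw_audio=True, forward() concatenates three audio_dim-wide streams and
    # mix_audio_text projects them to latent_dim * 3. E.g. for
    # audio_dim = latent_dim = 256:
    #
    #   mix_input_dim = 256 * 3   # audio + raw WavLM + text -> 768
    #   context_dim = 256 * 3     # joint_multiplier = 3 -> 768
    #   # so mix_audio_text is nn.Linear(768, 768)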
    def forward(self, audio, word, raw_audio=None, squeeze_scale=4):
        # Per-modality feature extraction; both outputs are [B, T, audio_dim].
        audio_feat = self.WavEncoder(audio)
        text_emb = self.text_pre_encoder_body(word)
        text_feat = self.text_encoder_body(text_emb)

        audio_len = audio_feat.shape[1]
        text_len = text_feat.shape[1]
        if audio_len != text_len:
            target_len = text_len if text_len > 0 else audio_len
            if target_len == 0:
                logger.warning("Both audio and text sequences are empty; inserting single-frame zeros")
                audio_feat = audio_feat.new_zeros(audio_feat.shape[0], 1, audio_feat.shape[2])
                text_feat = text_feat.new_zeros(text_feat.shape[0], 1, text_feat.shape[2])
            else:
                if audio_len == 0:
                    audio_feat = audio_feat.new_zeros(text_feat.shape[0], target_len, audio_feat.shape[2])
                else:
                    # Audio features vary smoothly, so resample them linearly.
                    audio_feat = F.interpolate(
                        audio_feat.transpose(1, 2),
                        size=target_len,
                        mode="linear",
                        align_corners=False,
                    ).transpose(1, 2)
                if text_len == 0:
                    text_feat = text_feat.new_zeros(audio_feat.shape[0], target_len, text_feat.shape[2])
                else:
                    # Word embeddings are categorical; nearest-neighbour repeats
                    # frames instead of blending adjacent tokens.
                    text_feat = F.interpolate(
                        text_feat.transpose(1, 2),
                        size=target_len,
                        mode="nearest",
                    ).transpose(1, 2)
                logger.warning(
                    "Resampled modality features for length mismatch (audio={}, text={} -> {})",
                    audio_len,
                    text_len,
                    target_len,
                )
        if raw_audio is not None and self.raw_audio:
            # raw_audio is expected to carry pre-extracted WavLM features;
            # alternatively they can be computed here:
            # raw_feat = self.extract_wavlm_feats(raw_audio)
            raw_feat = self.audio_projection(raw_audio)
            at_feat = torch.cat([audio_feat, raw_feat, text_feat], dim=2)
        else:
            at_feat = torch.cat([audio_feat, text_feat], dim=2)  # [B, T, 2 * audio_dim]

        at_feat = self.mix_audio_text(at_feat)  # [B, T, context_dim]
        at_feat = F.avg_pool1d(at_feat.transpose(1, 2), squeeze_scale)
        at_feat = at_feat.transpose(1, 2)  # [B, T // squeeze_scale, context_dim]
        return at_feat
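
    # Standalone sketch of the length-alignment trick used above (illustrative;
    # runnable on its own). Audio features are resampled with linear
    # interpolation, while text features use nearest-neighbour so that word
    # embeddings are repeated rather than blended between tokens:
    #
    #   audio_feat = torch.randn(2, 100, 32)   # [B, T_audio, D]
    #   text_feat = torch.randn(2, 64, 32)     # [B, T_text, D]
    #   target_len = text_feat.shape[1]
    #   audio_feat = F.interpolate(
    #       audio_feat.transpose(1, 2), size=target_len,
    #       mode="linear", align_corners=False,
    #   ).transpose(1, 2)
    #   assert audio_feat.shape == (2, 64, 32)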
    def load_and_freeze_wavlm(self, wavlm_path='./dataloaders/wavlm/WavLM-Base+.pt'):
        checkpoint = torch.load(wavlm_path, map_location='cpu')
        self.wavlm_cfg = WavLMConfig(checkpoint['cfg'])
        self.audio_encoder = WavLM(self.wavlm_cfg)
        self.audio_encoder.load_state_dict(checkpoint['model'])
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
    def extract_wavlm_feats(self, wav_input_16khz):
        assert self.audio_encoder is not None, "Please load the wavlm model first"
        # Accept numpy arrays and unbatched waveforms.
        if isinstance(wav_input_16khz, np.ndarray):
            wav_input_16khz = torch.from_numpy(wav_input_16khz)
        if wav_input_16khz.dim() == 1:
            wav_input_16khz = wav_input_16khz.unsqueeze(0)
        device = next(self.audio_encoder.parameters()).device
        wav_input_16khz = wav_input_16khz.to(device)
        if self.wavlm_cfg.normalize:
            wav_input_16khz = F.layer_norm(wav_input_16khz, wav_input_16khz.shape)
        wavlm_feats = self.audio_encoder.extract_features(wav_input_16khz)[0]
        wavlm_feats = wavlm_feats.detach()  # [B, T_wavlm, 1024]
        # WavLM features arrive at ~50 Hz; resample to the motion frame rate.
        target_size = math.ceil(wavlm_feats.shape[1] / 50 * self.audio_fps)
        wavlm_feats = F.interpolate(
            wavlm_feats.transpose(1, 2),
            size=target_size,
            align_corners=True,
            mode='linear',
        ).transpose(1, 2)
        return wavlm_feats
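
# Frame-rate arithmetic behind the final interpolation in extract_wavlm_feats
# (illustrative). WavLM emits features at ~50 Hz (20 ms hop over 16 kHz
# audio), so a clip of T_wavlm frames is resampled to the motion frame rate:
#
#   audio_fps = 30
#   t_wavlm = 400                                   # 8 s of audio at 50 Hz
#   target_size = math.ceil(t_wavlm / 50 * audio_fps)
#   assert target_size == 240                       # 8 s at 30 fps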