import os
# import types
import torch
import torch.nn as nn
import torch.nn.functional as F
# from transformers import WhisperFeatureExtractor
import whisper

# Import BEATs from local beats directory
from beats_model import BEATsConfig, BEATs

BEATS_AVAILABLE = True
print("✅ BEATs imported successfully from local beats directory")

# Default Whisper checkpoint cache; overridden per-model via
# model_config.whisper_cache_dir in load_whisper below.
os.environ["WHISPER_CACHE_DIR"] = "/data1/cxy/plm-v/modeling/cache"


def replace_layer_norm(module):
    """Recursively swap whisper's custom LayerNorm for torch.nn.LayerNorm.

    whisper.model.LayerNorm casts activations to float32 internally; replacing
    it with the stock nn.LayerNorm (same shape/eps/affine flags, weights
    copied over) keeps the module in the caller's chosen dtype.

    Args:
        module: any nn.Module tree; modified in place.
    """
    from whisper.model import LayerNorm
    for name, child in module.named_children():
        if isinstance(child, LayerNorm):
            old_params = child.state_dict()
            new_layer_norm = nn.LayerNorm(
                child.normalized_shape,
                eps=child.eps,
                elementwise_affine=child.elementwise_affine,
            )
            new_layer_norm.load_state_dict(old_params)
            setattr(module, name, new_layer_norm)
        else:
            # LayerNorm is a leaf, so only non-matching children need recursion.
            replace_layer_norm(child)


class DualWrappedEncoder(nn.Module):
    """Audio encoder combining a Whisper speech encoder with an optional BEATs encoder.

    Whisper consumes log-mel features; BEATs (when configured) consumes the raw
    waveform. Their frame sequences are length-aligned by zero-padding and
    concatenated along the feature dimension.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Load the Whisper encoder named by model_config.speech_encoder on CPU.

        Cache-directory priority: model_config.whisper_cache_dir >
        WHISPER_CACHE_DIR env var > whisper's default.

        Returns:
            The Whisper encoder module with its LayerNorms replaced by
            nn.LayerNorm (see replace_layer_norm).
        """
        download_root = getattr(model_config, 'whisper_cache_dir', None)
        if not download_root:
            download_root = os.environ.get('WHISPER_CACHE_DIR', None)

        if hasattr(model_config, 'whisper_config') and model_config.whisper_config:
            print(f"Loading Whisper with custom config: {model_config.whisper_config.get('d_model', 'default')}")
            # For now, still load by name but could be extended to use custom config

        # Single load call instead of four duplicated branches; download_root
        # is only forwarded when explicitly configured so whisper's own
        # default cache location still applies otherwise.
        load_kwargs = {'name': model_config.speech_encoder, 'device': 'cpu'}
        if download_root:
            load_kwargs['download_root'] = download_root
        encoder = whisper.load_model(**load_kwargs).encoder

        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Load the BEATs encoder from model_config.beats_model_path.

        Returns None (and logs why) when BEATs is disabled, unavailable,
        unconfigured, or the checkpoint fails to load — callers must handle
        a None beats_model.
        """
        if not getattr(model_config, 'use_beats', False):
            print("BEATs model disabled in config")
            return None
        if not BEATS_AVAILABLE:
            print("BEATs not available - skipping music encoder")
            return None
        beats_path = getattr(model_config, 'beats_model_path', None)
        if not beats_path:
            print("No BEATs model path specified")
            return None
        try:
            print(f"Loading BEATs Model from {beats_path}")
            # SECURITY: torch.load unpickles arbitrary code — only load
            # checkpoints from trusted sources.
            beats_ckpt = torch.load(beats_path, map_location='cpu')
            beats_cfg = BEATsConfig(beats_ckpt['cfg'])
            beats = BEATs(beats_cfg)
            beats.load_state_dict(beats_ckpt['model'])
            print("BEATs model loaded successfully")
            return beats
        except Exception as e:
            # Best-effort: a broken checkpoint degrades to Whisper-only
            # operation rather than crashing model construction.
            print(f"Failed to load BEATs model: {e}")
            return None

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode audio into embeddings.

        Args:
            x: Whisper log-mel input.  # assumes (batch, n_mels, frames) — TODO confirm
            raw_wav: optional raw waveform for BEATs.
            audio_padding_mask: optional padding mask for BEATs extraction.

        Returns:
            bfloat16 tensor of frame embeddings; when BEATs runs, Whisper and
            BEATs features are concatenated along the last dimension.
        """
        # NOTE(review): both encoders are treated as frozen feature
        # extractors, hence the no_grad scope — confirm if fine-tuning is
        # ever intended.
        with torch.no_grad():
            speech_embeds = self.whisper_model(x)
            if self.beats_model is not None and raw_wav is not None:
                self.beats_model = self.beats_model.float()
                audio_embeds, _ = self.beats_model.extract_features(
                    raw_wav.float(),
                    padding_mask=audio_padding_mask,
                    feature_only=True,
                )
                # Align sequence lengths by zero-padding the shorter stream.
                if audio_embeds.size(1) < speech_embeds.size(1):
                    audio_embeds = F.pad(
                        audio_embeds,
                        (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)),
                    )
                elif audio_embeds.size(1) > speech_embeds.size(1):
                    speech_embeds = F.pad(
                        speech_embeds,
                        (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)),
                    )
                # Concatenate Whisper and BEATs features
                speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
                print(f"Combined Whisper + BEATs features: {speech_embeds.shape}")
        speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds