|
|
import os |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import whisper |
|
|
|
|
|
|
|
|
# Import the local BEATs implementation. The availability flag lets the rest
# of the module degrade gracefully (load_beats() returns None) instead of
# crashing at import time when the local beats directory is missing —
# load_beats already guards on BEATS_AVAILABLE, which was dead code while
# the flag was unconditionally True.
try:
    from beats_model import BEATsConfig, BEATs
    BEATS_AVAILABLE = True
    print("✅ BEATs imported successfully from local beats directory")
except ImportError as e:
    BEATsConfig = None
    BEATs = None
    BEATS_AVAILABLE = False
    print(f"BEATs import failed ({e}) - music encoder will be disabled")

# Default Whisper cache location; setdefault so a caller-provided
# WHISPER_CACHE_DIR environment variable is not clobbered.
os.environ.setdefault("WHISPER_CACHE_DIR", "/data1/cxy/plm-v/modeling/cache")
|
|
|
|
|
def replace_layer_norm(module):
    """Recursively swap whisper's custom LayerNorm children for ``nn.LayerNorm``.

    Each replacement is built with the same ``normalized_shape``, ``eps`` and
    ``elementwise_affine`` settings and receives the original layer's weights,
    so the module's parameters are preserved.
    """
    from whisper.model import LayerNorm

    for child_name, child in list(module.named_children()):
        # Non-LayerNorm children are traversed; LayerNorm leaves are replaced.
        if not isinstance(child, LayerNorm):
            replace_layer_norm(child)
            continue
        replacement = nn.LayerNorm(
            child.normalized_shape,
            eps=child.eps,
            elementwise_affine=child.elementwise_affine,
        )
        replacement.load_state_dict(child.state_dict())
        setattr(module, child_name, replacement)
|
|
|
|
|
|
|
|
class DualWrappedEncoder(nn.Module):
    """Dual audio encoder: a Whisper speech encoder plus an optional BEATs encoder.

    ``forward`` returns Whisper features, optionally concatenated with BEATs
    features along the last (channel) dimension, cast to bfloat16.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        # May be None when BEATs is disabled, unavailable, or fails to load.
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Load the Whisper encoder on CPU and return it.

        The download root comes from ``model_config.whisper_cache_dir`` when
        set, otherwise from the ``WHISPER_CACHE_DIR`` environment variable;
        when neither is set, whisper's own default cache is used.
        """
        download_root = getattr(model_config, 'whisper_cache_dir', None)
        if not download_root:
            download_root = os.environ.get('WHISPER_CACHE_DIR', None)

        # ``whisper_config`` only affects logging: the original code duplicated
        # identical load calls in both branches, so they are merged here.
        if hasattr(model_config, 'whisper_config') and model_config.whisper_config:
            print(f"Loading Whisper with custom config: {model_config.whisper_config.get('d_model', 'default')}")

        load_kwargs = {'download_root': download_root} if download_root else {}
        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu', **load_kwargs).encoder

        # Swap whisper's custom LayerNorm subclass for the standard
        # ``nn.LayerNorm``; the helper preserves the layer weights.
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Load the BEATs encoder from ``model_config.beats_model_path``.

        Returns None (after logging why) when BEATs is disabled in the config,
        the implementation is unavailable, no checkpoint path is given, or
        loading fails — construction is best-effort by design.
        """
        if not getattr(model_config, 'use_beats', False):
            print("BEATs model disabled in config")
            return None

        if not BEATS_AVAILABLE:
            print("BEATs not available - skipping music encoder")
            return None

        beats_path = getattr(model_config, 'beats_model_path', None)
        if not beats_path:
            print("No BEATs model path specified")
            return None

        try:
            print(f"Loading BEATs Model from {beats_path}")
            # NOTE(review): torch.load on an untrusted checkpoint can execute
            # arbitrary code; consider weights_only=True if the checkpoint
            # contains only tensors/config dicts — verify before changing.
            beats_ckpt = torch.load(beats_path, map_location='cpu')
            beats_cfg = BEATsConfig(beats_ckpt['cfg'])
            beats = BEATs(beats_cfg)
            beats.load_state_dict(beats_ckpt['model'])
            print("BEATs model loaded successfully")
            return beats
        except Exception as e:
            # Deliberate best-effort: a broken checkpoint disables the music
            # encoder rather than aborting construction of the whole module.
            print(f"Failed to load BEATs model: {e}")
            return None

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode audio with Whisper and, when available, BEATs.

        Args:
            x: Whisper encoder input (assumed a log-mel spectrogram batch —
               TODO confirm shape against caller).
            raw_wav: optional raw waveform batch for the BEATs encoder.
            audio_padding_mask: optional padding mask forwarded to BEATs.

        Returns:
            bfloat16 tensor of Whisper features; when BEATs runs, the two
            feature streams are length-padded and concatenated on the last dim.
        """
        # Whisper runs without gradients (frozen speech encoder).
        with torch.no_grad():
            speech_embeds = self.whisper_model(x)

        if self.beats_model is not None and raw_wav is not None:
            # Keep BEATs in fp32 for feature extraction.
            self.beats_model = self.beats_model.float()
            audio_embeds, _ = self.beats_model.extract_features(
                raw_wav.float(),
                padding_mask=audio_padding_mask,
                feature_only=True
            )

            # Zero-pad the shorter stream along the time axis (dim 1) so both
            # share the same sequence length before concatenation.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))

            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            print(f"Combined Whisper + BEATs features: {speech_embeds.shape}")

        speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds