# internvl_ola / speech_encoder.py
# Uploaded by jjw0126 ("Upload files", commit 62d115a, verified).
import os
# import types
import torch
import torch.nn as nn
import torch.nn.functional as F
# from transformers import WhisperFeatureExtractor
import whisper
# Import BEATs from the local beats directory. BEATs is an optional music
# encoder: a missing module disables it (load_beats checks BEATS_AVAILABLE)
# instead of making this whole file fail to import.
try:
    from beats_model import BEATsConfig, BEATs
except ImportError as err:
    BEATsConfig = BEATs = None
    BEATS_AVAILABLE = False
    print(f"BEATs import failed ({err}) - music encoder disabled")
else:
    BEATS_AVAILABLE = True
    print("✅ BEATs imported successfully from local beats directory")
# Default Whisper cache directory. setdefault keeps a WHISPER_CACHE_DIR the user
# already exported, so the "env var > default" priority in load_whisper holds.
os.environ.setdefault("WHISPER_CACHE_DIR", "/data1/cxy/plm-v/modeling/cache")
def replace_layer_norm(module, layer_norm_cls=None):
    """Recursively swap LayerNorm-subclass children of ``module`` for plain ``nn.LayerNorm``.

    Each matching child is replaced in place by an ``nn.LayerNorm`` built with
    the same ``normalized_shape``, ``eps`` and ``elementwise_affine`` settings,
    and the child's state dict is copied into it. Non-matching children are
    searched recursively. ``module`` itself is never replaced; nothing is returned.

    Args:
        module: Root ``nn.Module`` whose descendants are rewritten.
        layer_norm_cls: Class whose instances get replaced. Defaults to
            ``whisper.model.LayerNorm``. It is imported lazily so callers that
            pass their own class do not need ``whisper`` installed.
    """
    if layer_norm_cls is None:
        from whisper.model import LayerNorm as layer_norm_cls
    # Snapshot the children first: the loop reassigns entries of module._modules.
    for name, child in list(module.named_children()):
        if isinstance(child, layer_norm_cls):
            new_layer_norm = nn.LayerNorm(
                child.normalized_shape,
                eps=child.eps,
                elementwise_affine=child.elementwise_affine,
            )
            new_layer_norm.load_state_dict(child.state_dict())
            setattr(module, name, new_layer_norm)
        else:
            replace_layer_norm(child, layer_norm_cls)
class DualWrappedEncoder(nn.Module):
    """Audio encoder combining a Whisper encoder with an optional BEATs model.

    ``forward`` always runs the Whisper encoder. When a BEATs model was loaded
    and a raw waveform is supplied, the BEATs features are length-aligned with
    the Whisper features and concatenated onto them along the last dimension.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        # None when BEATs is disabled, unavailable, unconfigured or failed to load.
        self.beats_model = self.load_beats(config)

    @staticmethod
    def _resolve_whisper_cache_dir(model_config):
        """Return the Whisper download directory, or None for whisper's own default.

        Priority: ``model_config.whisper_cache_dir`` > ``WHISPER_CACHE_DIR``
        env var > None. Empty values count as unset.
        """
        download_root = getattr(model_config, 'whisper_cache_dir', None)
        if not download_root:
            download_root = os.environ.get('WHISPER_CACHE_DIR', None)
        return download_root or None

    def load_whisper(self, model_config):
        """Load the encoder of the Whisper model named by ``model_config.speech_encoder``.

        The model is loaded on CPU, and its LayerNorm modules are swapped for
        plain ``nn.LayerNorm`` via ``replace_layer_norm`` before the encoder is returned.
        """
        download_root = self._resolve_whisper_cache_dir(model_config)
        whisper_config = getattr(model_config, 'whisper_config', None)
        if whisper_config:
            # Only reported for now; the model is still selected by name
            # (could be extended to build from the custom config).
            print(f"Loading Whisper with custom config: {whisper_config.get('d_model', 'default')}")
        # whisper.load_model treats download_root=None as "use the default
        # cache", which is what the old no-download_root branches relied on.
        encoder = whisper.load_model(
            name=model_config.speech_encoder,
            device='cpu',
            download_root=download_root,
        ).encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Load the optional BEATs model, or return None when it cannot be used.

        Prints the reason and returns None when ``model_config.use_beats`` is
        falsy, the module-level ``BEATS_AVAILABLE`` flag is false,
        ``model_config.beats_model_path`` is unset, or loading the checkpoint
        raises. Otherwise returns a ``BEATs`` instance built from the checkpoint's
        ``'cfg'`` entry with its ``'model'`` state dict loaded.
        """
        if not getattr(model_config, 'use_beats', False):
            print("BEATs model disabled in config")
            return None
        if not BEATS_AVAILABLE:
            print("BEATs not available - skipping music encoder")
            return None
        beats_path = getattr(model_config, 'beats_model_path', None)
        if not beats_path:
            print("No BEATs model path specified")
            return None
        try:
            print(f"Loading BEATs Model from {beats_path}")
            # NOTE(review): torch.load unpickles the file; only point
            # beats_model_path at trusted checkpoints.
            beats_ckpt = torch.load(beats_path, map_location='cpu')
            beats = BEATs(BEATsConfig(beats_ckpt['cfg']))
            beats.load_state_dict(beats_ckpt['model'])
        except Exception as e:
            # Deliberately best-effort: a bad checkpoint disables BEATs
            # instead of aborting model construction.
            print(f"Failed to load BEATs model: {e}")
            return None
        print("BEATs model loaded successfully")
        return beats

    @staticmethod
    def _pad_to_same_length(a, b):
        """Zero-pad the shorter of ``a``/``b`` at the end so their dim-1 sizes match.

        Lengths are compared on dim 1, while ``F.pad`` pads dim -2. These are the
        same axis for 3-D tensors, which this code presumably receives
        (batch, time, features) — confirm with callers.
        """
        diff = a.size(1) - b.size(1)
        if diff > 0:
            b = F.pad(b, (0, 0, 0, diff))
        elif diff < 0:
            a = F.pad(a, (0, 0, 0, -diff))
        return a, b

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode audio with Whisper and, when possible, append BEATs features.

        Args:
            x: Input to the Whisper encoder (presumably log-mel features —
                confirm with the caller).
            raw_wav: Waveform given to ``BEATs.extract_features`` (cast to
                float32). BEATs is skipped when this is None.
            audio_padding_mask: Forwarded to BEATs as ``padding_mask``.

        Returns:
            The Whisper features. If BEATs ran, they are instead the Whisper and
            BEATs features concatenated on the last dim, after zero-padding the
            shorter one. The result is always cast to ``torch.bfloat16`` and
            computed under ``torch.no_grad()``.
        """
        with torch.no_grad():
            speech_embeds = self.whisper_model(x)
            # NOTE(review): if BEATs is loaded but raw_wav is None, the output
            # has only the Whisper feature width.
            if self.beats_model is not None and raw_wav is not None:
                self.beats_model = self.beats_model.float()
                audio_embeds, _ = self.beats_model.extract_features(
                    raw_wav.float(),
                    padding_mask=audio_padding_mask,
                    feature_only=True
                )
                speech_embeds, audio_embeds = self._pad_to_same_length(speech_embeds, audio_embeds)
                speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
                # Printed on every forward call that uses BEATs.
                print(f"Combined Whisper + BEATs features: {speech_embeds.shape}")
            speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds