|
|
import types |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import WhisperFeatureExtractor |
|
|
import whisper |
|
|
import torch |
|
|
# Best-effort: make newly created tensors/modules default to CPU so weights
# are materialized as real (non-meta) tensors during model construction.
# NOTE(review): the broad except presumably guards against older torch builds
# that lack `set_default_device` — confirm the minimum supported torch version.
try:
    torch.set_default_device("cpu")
except Exception:
    pass
|
|
import accelerate |
|
|
from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs |
|
|
|
|
|
class WhisperWrappedEncoder:
    """Loads the audio encoder of an OpenAI Whisper model.

    Whisper's own ``LayerNorm`` subclass is replaced with the standard
    ``torch.nn.LayerNorm`` (weights preserved when they are real tensors),
    so the encoder can run outside Whisper's fp16 casting conventions.
    """

    @classmethod
    def load(cls, model_config):
        """Return the encoder sub-module for ``model_config.speech_encoder``.

        Args:
            model_config: config object whose ``speech_encoder`` attribute is
                a Whisper model name or a checkpoint file path.

        Returns:
            The Whisper ``AudioEncoder`` with standard LayerNorm modules.

        Raises:
            RuntimeError: if a meta-tensor failure occurs and the path is not
                a local checkpoint file that can be loaded manually.
        """

        def replace_layer_norm(module):
            # Recursively swap whisper.model.LayerNorm for nn.LayerNorm,
            # copying trained weights when they are real (non-meta) tensors.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    # Meta tensors carry no data, so there is nothing to copy.
                    if not any(p.is_meta for p in child.parameters()):
                        new_layer_norm.load_state_dict(child.state_dict())
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        speech_encoder_path = model_config.speech_encoder

        try:
            encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder
        except (NotImplementedError, RuntimeError) as e:
            if "meta tensor" not in str(e):
                raise  # unrelated failure: re-raise with original traceback
            print("Detected meta tensor issue, using alternative loading approach...")

            import os
            if not os.path.isfile(speech_encoder_path):
                # Named models (e.g. "large-v3") cannot be rebuilt manually here.
                raise RuntimeError(
                    f"Cannot load model {speech_encoder_path} due to meta tensor issues"
                ) from e

            # Rebuild the model directly from the raw checkpoint on CPU,
            # bypassing whisper.load_model's device handling.
            checkpoint = torch.load(speech_encoder_path, map_location='cpu')
            from whisper.model import ModelDimensions, Whisper
            model = Whisper(ModelDimensions(**checkpoint["dims"]))
            model.load_state_dict(checkpoint["model_state_dict"])
            encoder = model.encoder

        replace_layer_norm(encoder)
        return encoder
|
|
|
|
|
class DualWrappedEncoder(nn.Module):
    """Dual audio encoder: Whisper (speech features) + BEATs (audio features).

    ``forward`` runs both encoders, pads the shorter stream along the time
    axis, concatenates the two feature streams along the channel axis, and
    returns a bfloat16 tensor. If the BEATs branch fails, zero features are
    substituted so the Whisper path still produces output.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Build the Whisper encoder from ``model_config.speech_encoder``.

        The checkpoint is loaded on CPU; ``to_empty`` materializes any meta
        tensors before the real weights are loaded. Whisper's fp16-casting
        ``LayerNorm`` is then replaced by standard ``nn.LayerNorm``.
        """

        def replace_layer_norm(module):
            # Recursively swap whisper.model.LayerNorm for nn.LayerNorm,
            # copying trained weights when they are real (non-meta) tensors.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    if not any(p.is_meta for p in child.parameters()):
                        new_layer_norm.load_state_dict(child.state_dict())
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        # BUG FIX: a hard-coded absolute checkpoint path was used here before;
        # honor the configured path instead.
        speech_encoder_path = model_config.speech_encoder

        from whisper.model import Whisper, ModelDimensions

        ckpt = torch.load(speech_encoder_path, map_location="cpu")
        dims = ModelDimensions(**ckpt["dims"])
        model = Whisper(dims)
        # Materialize any meta tensors with uninitialized storage before
        # loading the real weights over them.
        model.to_empty(device="cpu")
        missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=True)
        if missing or unexpected:
            print("missing:", missing)
            print("unexpected:", unexpected)
        model.eval()

        encoder = model.encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Build the BEATs encoder from ``model_config.music_encoder`` and
        verify that its loaded weights contain no NaN/Inf values.

        Raises:
            ValueError: if any parameter contains NaN or Inf (corrupted
                checkpoint).
        """
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        # Materialize any meta tensors before strict weight loading.
        beats = beats.to_empty(device='cpu')
        beats.load_state_dict(beats_ckpt['model'], strict=True)

        # Sanity check: corrupted checkpoints have produced NaN/Inf weights.
        corrupted = [
            name
            for name, param in beats.named_parameters()
            if not torch.isfinite(param).all()
        ]
        if corrupted:
            raise ValueError(
                f"BEATs model weights are corrupted (NaN/Inf) in "
                f"{len(corrupted)} parameters: {corrupted}"
            )
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode audio with both branches and concatenate the features.

        Args:
            x: Whisper input features (log-mel spectrogram batch).
            raw_wav: raw waveform batch for BEATs; expected in [-1, 1]
                (out-of-range values are clamped). May be ``None``, in which
                case the BEATs features fall back to zeros.
                # assumes shape (batch, samples) — TODO confirm with callers
            audio_padding_mask: optional padding mask passed to BEATs.

        Returns:
            bfloat16 tensor of concatenated Whisper+BEATs features.

        Raises:
            ValueError: if the final concatenated embeddings contain NaN.
        """
        with torch.no_grad():
            # Run BEATs in fp32; its weights may have been stored in a
            # lower-precision dtype.
            self.beats_model = self.beats_model.float()

            speech_embeds = self.whisper_model(x)

            try:
                # BUG FIX: previously raw_wav was dereferenced before any
                # None check, crashing despite None being the default.
                if raw_wav is None:
                    raise ValueError("raw_wav is required for the BEATs branch")
                raw_wav_float = raw_wav.float()
                # BEATs expects waveforms in [-1, 1]; clamp out-of-range input.
                raw_wav_float = torch.clamp(raw_wav_float, -1.0, 1.0)
                audio_embeds, _ = self.beats_model.extract_features(
                    raw_wav_float,
                    padding_mask=audio_padding_mask,
                    feature_only=True,
                )
                if torch.isnan(audio_embeds).any():
                    # Caught below: a NaN-producing BEATs pass triggers the
                    # Whisper-only fallback rather than aborting the forward.
                    raise ValueError(
                        "BEATs model produced NaN values - this indicates a bug "
                        "in the model or input data"
                    )
            except Exception as e:
                # Deliberate best-effort fallback: substitute zero features so
                # the Whisper branch alone still yields usable embeddings.
                print(f"ERROR - BEATs model failed: {e}")
                print("Falling back to Whisper-only mode")
                audio_embeds = torch.zeros(
                    speech_embeds.shape[0],
                    speech_embeds.shape[1],
                    1024,
                    device=speech_embeds.device,
                    dtype=speech_embeds.dtype,
                )

            # Pad the shorter stream along the time axis so the two streams
            # can be concatenated channel-wise.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(
                    audio_embeds,
                    (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)),
                )
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(
                    speech_embeds,
                    (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)),
                )
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)

            if torch.isnan(speech_embeds).any():
                raise ValueError(
                    "Final speech embeddings contain NaN values - this "
                    "indicates a bug in the speech encoder"
                )

        return speech_embeds