# Upload-page metadata (not Python), preserved as a comment:
#   jjw0126's picture — Upload files — 84ff315 verified
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
import whisper
import torch  # NOTE(review): duplicate of the `import torch` above — harmless but removable
# Force CPU as the default tensor device; the try/except is best-effort,
# presumably to tolerate torch versions without set_default_device (added in 2.0).
try:
    torch.set_default_device("cpu")
except Exception:
    pass
import accelerate
from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs
class WhisperWrappedEncoder:
    """Loads an OpenAI Whisper audio encoder on CPU, with a meta-tensor workaround."""

    @classmethod
    def load(cls, model_config):
        """Return the encoder of the Whisper model named by ``model_config.speech_encoder``.

        The value may be either a registered Whisper model name or a path to a
        ``.pt`` checkpoint. After loading, whisper's custom ``LayerNorm``
        subclasses are replaced with stock ``nn.LayerNorm`` modules. If loading
        fails because the weights were created as meta tensors, the model is
        rebuilt directly from the checkpoint file on CPU.

        Raises:
            RuntimeError: if a meta-tensor failure occurs and the encoder path
                is not a local checkpoint file (name-based fallback unsupported).
        """

        def replace_layer_norm(module):
            # Recursively swap whisper.model.LayerNorm for torch.nn.LayerNorm,
            # copying weights only when they are real (non-meta) tensors —
            # meta tensors carry no data, so there is nothing to copy.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_ln = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    if not any(p.is_meta for p in child.parameters()):
                        new_ln.load_state_dict(child.state_dict())
                    setattr(module, name, new_ln)
                else:
                    replace_layer_norm(child)

        speech_encoder_path = model_config.speech_encoder
        try:
            # Works for both registered model names and local file paths.
            encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder
        except (NotImplementedError, RuntimeError) as e:
            if "meta tensor" not in str(e):
                raise  # unrelated failure: re-raise with the original traceback
            print("Detected meta tensor issue, using alternative loading approach...")
            import os
            if not os.path.isfile(speech_encoder_path):
                # Name-based loading cannot avoid the meta-tensor path; give up.
                raise RuntimeError(
                    f"Cannot load model {speech_encoder_path} due to meta tensor issues"
                )
            # Rebuild the model from the checkpoint entirely on CPU so every
            # tensor is materialized with real data.
            checkpoint = torch.load(speech_encoder_path, map_location='cpu')
            from whisper.model import ModelDimensions, Whisper
            dims = ModelDimensions(**checkpoint["dims"])
            model = Whisper(dims)
            model.load_state_dict(checkpoint["model_state_dict"])
            encoder = model.encoder
        replace_layer_norm(encoder)
        return encoder
class DualWrappedEncoder(nn.Module):
    """Dual audio encoder: a Whisper speech encoder plus a BEATs music encoder.

    ``forward`` concatenates the two embeddings along the feature dimension
    (Whisper features first, then 1024-dim BEATs features) and returns the
    result in bfloat16. If BEATs fails at runtime, it degrades to a
    Whisper-only mode with zero-valued BEATs features.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Build the Whisper encoder from the checkpoint at ``model_config.speech_encoder``.

        The module skeleton is materialized with ``to_empty`` before loading
        weights, which is required when submodules were created on the meta
        device. (Previously this loaded a hard-coded absolute path, which broke
        on every machine but the author's; it now honors the configured path.)
        """

        def replace_layer_norm(module):
            # Swap whisper's custom LayerNorm for stock nn.LayerNorm, copying
            # weights only when they are real (non-meta) tensors.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_ln = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    if not any(p.is_meta for p in child.parameters()):
                        new_ln.load_state_dict(child.state_dict())
                    setattr(module, name, new_ln)
                else:
                    replace_layer_norm(child)

        from whisper.model import Whisper, ModelDimensions
        # 1) Load the checkpoint to CPU: weights are real tensors here.
        ckpt = torch.load(model_config.speech_encoder, map_location="cpu")
        dims = ModelDimensions(**ckpt["dims"])
        # 2) Build the module skeleton, then materialize tensors on CPU —
        #    crucial when meta tensors are involved.
        model = Whisper(dims)
        model.to_empty(device="cpu")
        # 3) Load weights strictly and surface any mismatch (strict=True makes
        #    a non-empty result impossible, but keep the report for safety).
        missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=True)
        if missing or unexpected:
            print("missing:", missing)
            print("unexpected:", unexpected)
        model.eval()
        encoder = model.encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Build the BEATs encoder from ``model_config.music_encoder``.

        Validates every parameter for NaN/Inf after loading — corrupted
        checkpoints have produced poisoned weights before, which silently
        break every forward pass — and raises ``ValueError`` if any are found.
        """
        print("Loading BEATs Model")
        beats_ckpt = torch.load(model_config.music_encoder, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        # Materialize any meta tensors before loading real weights.
        beats = beats.to_empty(device='cpu')
        beats.load_state_dict(beats_ckpt['model'], strict=True)
        nan_count = 0
        inf_count = 0
        for name, param in beats.named_parameters():
            nans = torch.isnan(param).sum().item()
            infs = torch.isinf(param).sum().item()
            if nans:
                print(f"ERROR - BEATs parameter {name} contains NaN values! "
                      f"(shape={tuple(param.shape)}, count={nans})")
                nan_count += 1
            if infs:
                print(f"ERROR - BEATs parameter {name} contains Inf values! "
                      f"(shape={tuple(param.shape)}, count={infs})")
                inf_count += 1
        if nan_count or inf_count:
            raise ValueError(
                f"BEATs model weights are corrupted: {nan_count} NaN parameters, "
                f"{inf_count} Inf parameters"
            )
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode mel features and raw audio into one fused embedding.

        Args:
            x: Whisper mel-spectrogram input batch.
            raw_wav: raw waveform batch for BEATs, expected in [-1, 1]
                (clamped otherwise). NOTE(review): despite the ``None``
                default, a ``None`` value would only be caught by the
                fallback path below — confirm callers always pass it.
            audio_padding_mask: optional padding mask forwarded to BEATs.

        Returns:
            bfloat16 tensor of Whisper features concatenated with 1024-dim
            BEATs features on the last dim; the shorter sequence is zero-padded
            on the time axis so both align.

        Raises:
            ValueError: if the fused embedding contains NaN values.
        """
        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            try:
                wav = raw_wav.float()
                # BEATs expects waveforms in [-1, 1]; clamp out-of-range input.
                if wav.min().item() < -1.0 or wav.max().item() > 1.0:
                    print("WARNING - BEATs input out of expected range [-1, 1]! "
                          "Clipping to valid range.")
                    wav = torch.clamp(wav, -1.0, 1.0)
                audio_embeds, _ = self.beats_model.extract_features(
                    wav, padding_mask=audio_padding_mask, feature_only=True)
                if torch.isnan(audio_embeds).any():
                    # Raised inside the try on purpose: a NaN-producing BEATs
                    # pass is treated like any other BEATs failure and falls
                    # back to Whisper-only mode below.
                    raise ValueError(
                        "BEATs model produced NaN values - this indicates a bug "
                        "in the model or input data")
            except Exception as e:
                print(f"ERROR - BEATs model failed: {e}")
                print("Falling back to Whisper-only mode")
                # Zero embeddings with BEATs' 1024 feature dim keep downstream
                # shapes consistent.
                audio_embeds = torch.zeros(
                    speech_embeds.shape[0], speech_embeds.shape[1], 1024,
                    device=speech_embeds.device, dtype=speech_embeds.dtype)
            # Zero-pad the shorter sequence so both align on the time axis.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(
                    audio_embeds,
                    (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(
                    speech_embeds,
                    (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)
            # Final guard: NaNs here mean the speech encoder itself is broken.
            if torch.isnan(speech_embeds).any():
                raise ValueError(
                    "Final speech embeddings contain NaN values - this indicates "
                    "a bug in the speech encoder")
            return speech_embeds