"""Speech encoders for Ola: a Whisper-only wrapper and a dual Whisper+BEATs encoder.

Both loaders materialize meta-tensor model skeletons on CPU with
``Module.to_empty`` before copying real weights in, which avoids the
"Cannot copy out of meta tensor" failure seen when models are built under an
empty-weights / meta default device.
"""

import os
import types

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
import whisper

try:
    # Keep model construction on CPU so parameters are not created on a meta
    # or accelerator device during loading.
    torch.set_default_device("cpu")
except Exception:
    pass

import accelerate
from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs


def _replace_layer_norm(module):
    """Recursively swap whisper's custom ``LayerNorm`` for ``torch.nn.LayerNorm``.

    ``whisper.model.LayerNorm`` up-casts activations; replacing it lets the
    encoder run in the dtype of the surrounding model. Existing weights are
    copied over unless the child holds meta tensors (no data to copy yet).
    """
    from whisper.model import LayerNorm

    for name, child in module.named_children():
        if isinstance(child, LayerNorm):
            new_ln = nn.LayerNorm(
                child.normalized_shape,
                eps=child.eps,
                elementwise_affine=child.elementwise_affine,
            )
            # Meta tensors carry no storage, so there is no state to transfer.
            if not any(p.is_meta for p in child.parameters()):
                new_ln.load_state_dict(child.state_dict())
            setattr(module, name, new_ln)
        else:
            _replace_layer_norm(child)


def _load_whisper_encoder(speech_encoder_path):
    """Load a Whisper encoder from a checkpoint file path or a model name.

    File paths are loaded manually (``torch.load`` -> ``Whisper(dims)`` ->
    ``to_empty`` -> ``load_state_dict``) so that meta-tensor skeletons are
    materialized on CPU *before* real weights are copied in. Non-path names
    fall back to ``whisper.load_model``.

    Raises:
        RuntimeError / KeyError: if the checkpoint is malformed.
    """
    from whisper.model import ModelDimensions, Whisper

    if os.path.isfile(speech_encoder_path):
        ckpt = torch.load(speech_encoder_path, map_location="cpu")
        dims = ModelDimensions(**ckpt["dims"])
        model = Whisper(dims)
        # Crucial when meta tensors are involved: allocate real CPU storage
        # for every parameter before load_state_dict.
        model.to_empty(device="cpu")
        model.load_state_dict(ckpt["model_state_dict"], strict=True)
        model.eval()
        encoder = model.encoder
    else:
        encoder = whisper.load_model(name=speech_encoder_path, device="cpu").encoder

    _replace_layer_norm(encoder)
    return encoder


class WhisperWrappedEncoder:
    """Factory for a standalone Whisper audio encoder."""

    @classmethod
    def load(cls, model_config):
        """Return the Whisper encoder named by ``model_config.speech_encoder``."""
        return _load_whisper_encoder(model_config.speech_encoder)


class DualWrappedEncoder(nn.Module):
    """Joint audio encoder: Whisper (speech) and BEATs (music) features,
    concatenated along the feature dimension."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Load the Whisper encoder from ``model_config.speech_encoder``.

        NOTE(review): the previous version loaded a hard-coded absolute
        checkpoint path; it now honors the configured path.
        """
        return _load_whisper_encoder(model_config.speech_encoder)

    def load_beats(self, model_config):
        """Load the BEATs music encoder from ``model_config.music_encoder``.

        Raises:
            ValueError: if any loaded parameter contains NaN or Inf — i.e. the
                checkpoint is corrupted. Failing fast here is deliberate; a
                corrupted encoder would otherwise surface as NaNs downstream.
        """
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        # Materialize meta tensors on CPU before loading the real weights.
        beats = beats.to_empty(device='cpu')
        beats.load_state_dict(beats_ckpt['model'], strict=True)

        nan_params = [n for n, p in beats.named_parameters() if torch.isnan(p).any()]
        inf_params = [n for n, p in beats.named_parameters() if torch.isinf(p).any()]
        if nan_params or inf_params:
            raise ValueError(
                f"BEATs model weights are corrupted: {len(nan_params)} NaN parameters, "
                f"{len(inf_params)} Inf parameters"
            )
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode audio and return concatenated Whisper+BEATs embeddings.

        Args:
            x: Whisper encoder input (log-mel features — assumed; confirm with
                caller's feature extractor).
            raw_wav: raw waveform for BEATs, expected in [-1, 1].
            audio_padding_mask: optional padding mask forwarded to BEATs.

        Returns:
            bfloat16 tensor of shape (batch, time, whisper_dim + 1024).

        Raises:
            ValueError: if the final embeddings contain NaN.
        """
        with torch.no_grad():
            # Run BEATs in float32 for numerical stability.
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            try:
                raw_wav_float = raw_wav.float()
                # BEATs expects waveforms in [-1, 1]; clamp out-of-range input.
                if raw_wav_float.min().item() < -1.0 or raw_wav_float.max().item() > 1.0:
                    print("WARNING - BEATs input out of expected range [-1, 1]! Clipping to valid range.")
                    raw_wav_float = torch.clamp(raw_wav_float, -1.0, 1.0)
                audio_embeds, _ = self.beats_model.extract_features(
                    raw_wav_float, padding_mask=audio_padding_mask, feature_only=True
                )
                if torch.isnan(audio_embeds).any():
                    # Surface the root cause rather than silently patching NaNs;
                    # the except below converts this into the zero fallback.
                    raise ValueError(
                        "BEATs model produced NaN values - this indicates a bug in the model or input data"
                    )
            except Exception as e:
                # Best-effort fallback: zero audio embeddings, Whisper-only mode.
                print(f"ERROR - BEATs model failed: {e}")
                print("Falling back to Whisper-only mode")
                audio_embeds = torch.zeros(
                    speech_embeds.shape[0],
                    speech_embeds.shape[1],
                    1024,  # BEATs feature width expected by the fusion layer
                    device=speech_embeds.device,
                    dtype=speech_embeds.dtype,
                )

            # Zero-pad the shorter stream along time so the two streams align.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(
                    audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))
                )
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(
                    speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1))
                )

            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)

            if torch.isnan(speech_embeds).any():
                raise ValueError(
                    "Final speech embeddings contain NaN values - this indicates a bug in the speech encoder"
                )
        return speech_embeds