# NOTE: export residue removed (file size 15,909 bytes, commit 84ff315, and a
# dumped 1-295 line-number index) — it was not valid Python source.
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
import whisper
import torch
try:
torch.set_default_device("cpu")
except Exception:
pass
import accelerate
from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs
class WhisperWrappedEncoder:
    """Factory for the encoder half of an OpenAI Whisper model.

    Handles both the normal ``whisper.load_model`` path and the "meta tensor"
    failure mode that occurs when parameters are created on the meta device
    (e.g. under an empty-weights initialization context).
    """

    @classmethod
    def load(cls, model_config):
        """Load the model named by ``model_config.speech_encoder`` and return its encoder.

        ``speech_encoder`` may be a registered whisper model name or a
        checkpoint file path. Returns the Whisper audio encoder with
        whisper's custom LayerNorm layers replaced by plain ``nn.LayerNorm``.
        Raises ``RuntimeError`` when a meta-tensor failure cannot be recovered
        from a non-file model reference.
        """

        def replace_layer_norm(module):
            # Recursively swap whisper's LayerNorm subclass for a vanilla
            # nn.LayerNorm with identical shape/eps/affine settings.
            # NOTE(review): whisper's LayerNorm presumably upcasts activations
            # to float32 in forward; confirm dropping that cast is intended.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    # Meta tensors carry no data, so weights can only be
                    # copied when the parameters are materialized.
                    if not any(p.is_meta for p in child.parameters()):
                        new_layer_norm.load_state_dict(child.state_dict())
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        speech_encoder_path = model_config.speech_encoder
        try:
            # Works for both registered model names and checkpoint paths.
            encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder
        except (NotImplementedError, RuntimeError) as e:
            if "meta tensor" not in str(e):
                raise  # unrelated failure: keep the original traceback
            print("Detected meta tensor issue, using alternative loading approach...")
            import os
            if not os.path.isfile(speech_encoder_path):
                # Only a concrete checkpoint file can be recovered this way.
                raise RuntimeError(
                    f"Cannot load model {speech_encoder_path} due to meta tensor issues"
                ) from e
            # Load the raw checkpoint to CPU and build the module ourselves,
            # bypassing whisper.load_model's device handling.
            checkpoint = torch.load(speech_encoder_path, map_location='cpu')
            from whisper.model import ModelDimensions, Whisper
            model = Whisper(ModelDimensions(**checkpoint["dims"]))
            # Materialize any meta parameters before loading real weights;
            # load_state_dict cannot populate meta tensors.
            model.to_empty(device="cpu")
            model.load_state_dict(checkpoint["model_state_dict"])
            encoder = model.encoder
        replace_layer_norm(encoder)
        return encoder
class DualWrappedEncoder(nn.Module):
    """Dual audio encoder combining Whisper (speech) and BEATs (music/audio).

    ``forward`` runs both encoders on the same clip and concatenates their
    embeddings along the feature dimension, returning bfloat16 features.
    """

    def __init__(self, config):
        """``config`` must expose ``speech_encoder`` (Whisper checkpoint path)
        and ``music_encoder`` (BEATs checkpoint path)."""
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Build the Whisper encoder from ``model_config.speech_encoder``.

        Loads the checkpoint to CPU, materializes the module with
        ``to_empty`` (required when parameters start on the meta device),
        loads the weights strictly, and returns the encoder submodule with
        whisper's custom LayerNorm replaced by plain ``nn.LayerNorm``.
        """

        def replace_layer_norm(module):
            # Recursively swap whisper's LayerNorm subclass for a vanilla
            # nn.LayerNorm with identical shape/eps/affine settings.
            # NOTE(review): whisper's LayerNorm presumably upcasts activations
            # to float32 in forward; confirm dropping that cast is intended.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    # Meta tensors carry no data; only copy weights when the
                    # parameters are materialized.
                    if not any(p.is_meta for p in child.parameters()):
                        new_layer_norm.load_state_dict(child.state_dict())
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        from whisper.model import ModelDimensions, Whisper

        # BUG FIX: the checkpoint path was previously hard-coded to a
        # machine-specific location; use the configured path instead.
        ckpt = torch.load(model_config.speech_encoder, map_location="cpu")
        dims = ModelDimensions(**ckpt["dims"])
        model = Whisper(dims)
        # Materialize parameters on CPU first — crucial when the module was
        # built with meta tensors (e.g. under an empty-weights init context).
        model.to_empty(device="cpu")
        model.load_state_dict(ckpt["model_state_dict"], strict=True)
        model.eval()
        encoder = model.encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Build the BEATs encoder from ``model_config.music_encoder``.

        Raises ``ValueError`` when the loaded weights contain NaN/Inf values,
        which indicates a corrupted checkpoint.
        """
        print("Loading BEATs Model")
        beats_ckpt = torch.load(model_config.music_encoder, map_location='cpu')
        beats = BEATs(BEATsConfig(beats_ckpt['cfg']))
        # Materialize (possibly meta) parameters before loading real weights.
        beats = beats.to_empty(device='cpu')
        beats.load_state_dict(beats_ckpt['model'], strict=True)

        # Sanity-check the checkpoint once at load time: NaN/Inf weights
        # would poison every subsequent forward pass.
        nan_count = sum(1 for _, p in beats.named_parameters() if torch.isnan(p).any())
        inf_count = sum(1 for _, p in beats.named_parameters() if torch.isinf(p).any())
        if nan_count or inf_count:
            raise ValueError(
                f"BEATs model weights are corrupted: {nan_count} NaN parameters, {inf_count} Inf parameters"
            )
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode a clip with both encoders and concatenate the features.

        Parameters
        ----------
        x : input batch for the Whisper encoder (mel features — assumed,
            TODO confirm against the caller).
        raw_wav : raw waveform batch for BEATs; expected in [-1, 1], clamped
            otherwise.
        audio_padding_mask : optional padding mask forwarded to BEATs.

        Returns
        -------
        torch.Tensor (bfloat16): Whisper and BEATs features concatenated
        along the last dimension; the BEATs side is zero-filled (width 1024)
        when BEATs fails. Raises ``ValueError`` if NaNs are produced.
        """
        with torch.no_grad():
            # BEATs runs in float32 regardless of the surrounding precision.
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            try:
                raw_wav_float = raw_wav.float()
                # BEATs expects waveforms in [-1, 1]; clamp out-of-range input.
                if raw_wav_float.min().item() < -1.0 or raw_wav_float.max().item() > 1.0:
                    print("WARNING - BEATs input out of expected range [-1, 1]! Clipping to valid range.")
                    raw_wav_float = torch.clamp(raw_wav_float, -1.0, 1.0)
                audio_embeds, _ = self.beats_model.extract_features(
                    raw_wav_float, padding_mask=audio_padding_mask, feature_only=True
                )
                if torch.isnan(audio_embeds).any():
                    print("ERROR - BEATs output contains NaN values!")
                    raise ValueError("BEATs model produced NaN values - this indicates a bug in the model or input data")
            except ValueError:
                # BUG FIX: this NaN diagnosis used to be swallowed by the
                # broad handler below and silently replaced with zeros,
                # contradicting its stated intent; propagate it instead.
                raise
            except Exception as e:
                print(f"ERROR - BEATs model failed: {e}")
                print("Falling back to Whisper-only mode")
                # Whisper-only fallback: zero BEATs features (BEATs width 1024).
                audio_embeds = torch.zeros(
                    speech_embeds.shape[0],
                    speech_embeds.shape[1],
                    1024,
                    device=speech_embeds.device,
                    dtype=speech_embeds.dtype,
                )
            # Align the two streams on the time axis by zero-padding the
            # shorter one, then concatenate along the feature axis.
            if audio_embeds.size(1) < speech_embeds.size(1):
                audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
            elif audio_embeds.size(1) > speech_embeds.size(1):
                speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
            speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
            speech_embeds = speech_embeds.to(torch.bfloat16)
            # Final guard: surface NaNs instead of letting them propagate
            # silently into the downstream model.
            if torch.isnan(speech_embeds).any():
                print("ERROR - Final speech embeddings contain NaN values!")
                raise ValueError("Final speech embeddings contain NaN values - this indicates a bug in the speech encoder")
        return speech_embeds