File size: 15,909 Bytes
84ff315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import WhisperFeatureExtractor
import whisper
import torch
# Pin new-tensor allocation to CPU so the checkpoint loading below does not
# end up on a meta/GPU device. Guarded because older torch releases do not
# provide set_default_device — TODO confirm minimum supported torch version.
try:
    torch.set_default_device("cpu")
except Exception:
    pass 
import accelerate
from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs

class WhisperWrappedEncoder:
    """Loader for a standalone Whisper audio encoder.

    Loads a Whisper model (by file path or registered model name) on CPU and
    returns only its encoder, with whisper's custom LayerNorm subclass swapped
    for the stock ``nn.LayerNorm``.
    """

    @classmethod
    def load(cls, model_config):
        """Return the Whisper encoder referenced by ``model_config.speech_encoder``.

        Args:
            model_config: object with a ``speech_encoder`` attribute — either a
                checkpoint file path or a whisper model name.

        Returns:
            The Whisper encoder module (CPU, LayerNorms replaced).

        Raises:
            RuntimeError: if a meta-tensor failure occurs and the path is not a
                local checkpoint file that can be loaded manually.
        """

        def replace_layer_norm(module):
            # whisper.model.LayerNorm overrides forward for fp16 casting;
            # replace it with plain nn.LayerNorm, preserving weights when the
            # child holds real (non-meta) tensors.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    # Meta tensors carry no data to copy; keep the fresh init.
                    if not any(p.is_meta for p in child.parameters()):
                        new_layer_norm.load_state_dict(child.state_dict())
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        speech_encoder_path = model_config.speech_encoder

        # First try loading directly (works for both file paths and model names).
        try:
            encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder
        except (NotImplementedError, RuntimeError) as e:
            if "meta tensor" not in str(e):
                raise  # bare raise preserves the original traceback
            # Meta-tensor failure: rebuild the model manually from a local
            # checkpoint, avoiding whisper's device-placement path entirely.
            print("Detected meta tensor issue, using alternative loading approach...")
            import os
            if not os.path.isfile(speech_encoder_path):
                # A bare model name cannot be rebuilt manually.
                raise RuntimeError(f"Cannot load model {speech_encoder_path} due to meta tensor issues")
            checkpoint = torch.load(speech_encoder_path, map_location='cpu')
            from whisper.model import ModelDimensions, Whisper
            dims = ModelDimensions(**checkpoint["dims"])
            model = Whisper(dims)
            # Materialize any meta parameters on CPU before loading weights.
            model.to_empty(device='cpu')
            model.load_state_dict(checkpoint["model_state_dict"])
            encoder = model.encoder

        replace_layer_norm(encoder)
        return encoder
    
class DualWrappedEncoder(nn.Module):
    """Dual speech encoder: Whisper (mel-spectrogram input) + BEATs (raw
    waveform input), with their outputs concatenated along the feature dim."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Load the Whisper encoder from ``model_config.speech_encoder``.

        The checkpoint is loaded to CPU, the module skeleton is materialized
        with ``to_empty`` (required when parameters were created on the meta
        device), and whisper's LayerNorm subclass is replaced by the stock
        ``nn.LayerNorm``.

        Returns:
            The Whisper encoder module in eval mode.
        """

        def replace_layer_norm(module):
            # whisper.model.LayerNorm overrides forward for fp16 casting;
            # swap in plain nn.LayerNorm, copying weights when they are real
            # (non-meta) tensors.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    # Meta tensors carry no data to copy; keep the fresh init.
                    if not any(p.is_meta for p in child.parameters()):
                        new_layer_norm.load_state_dict(child.state_dict())
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        from whisper.model import Whisper, ModelDimensions

        # Fix: use the configured path rather than a hard-coded absolute path.
        speech_encoder_path = model_config.speech_encoder

        # 1) Load checkpoint to CPU (weights are real tensors here).
        ckpt = torch.load(speech_encoder_path, map_location="cpu")
        dims = ModelDimensions(**ckpt["dims"])

        # 2) Build the module skeleton, then materialize tensors on CPU.
        #    to_empty is crucial when meta-device parameters are involved.
        model = Whisper(dims)
        model.to_empty(device="cpu")

        # 3) Load weights and freeze into eval mode.
        model.load_state_dict(ckpt["model_state_dict"], strict=True)
        model.eval()

        encoder = model.encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Load the BEATs encoder from ``model_config.music_encoder``.

        Raises:
            ValueError: if any loaded parameter contains NaN or Inf values,
                which indicates a corrupted checkpoint.
        """
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        # Materialize any meta parameters on CPU before loading weights.
        beats = beats.to_empty(device='cpu')
        beats.load_state_dict(beats_ckpt['model'], strict=True)

        # Sanity-check the checkpoint: refuse to run with NaN/Inf weights.
        corrupted = [
            name
            for name, param in beats.named_parameters()
            if torch.isnan(param).any() or torch.isinf(param).any()
        ]
        if corrupted:
            raise ValueError(
                f"BEATs model weights are corrupted (NaN/Inf) in parameters: {corrupted}"
            )
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode audio with both encoders and concatenate the features.

        Args:
            x: Whisper input features (mel spectrogram batch) — assumes
                shape (batch, n_mels, frames); TODO confirm against caller.
            raw_wav: raw waveform batch for BEATs, expected in [-1, 1];
                if None or if BEATs fails, zero embeddings are substituted.
            audio_padding_mask: optional padding mask forwarded to BEATs.

        Returns:
            bfloat16 tensor of shape (batch, time, whisper_dim + 1024).

        Raises:
            ValueError: if the final concatenated embeddings contain NaN.
        """
        with torch.no_grad():
            # BEATs is run in fp32 regardless of the surrounding precision.
            self.beats_model = self.beats_model.float()

            speech_embeds = self.whisper_model(x)

            try:
                # Fix: guard the raw_wav dereference inside the try so that a
                # missing/invalid waveform falls back to Whisper-only mode
                # instead of crashing (raw_wav defaults to None).
                raw_wav_float = raw_wav.float()
                # BEATs expects waveforms in [-1, 1]; clamp out-of-range input.
                if raw_wav_float.min().item() < -1.0 or raw_wav_float.max().item() > 1.0:
                    raw_wav_float = torch.clamp(raw_wav_float, -1.0, 1.0)

                audio_embeds, _ = self.beats_model.extract_features(
                    raw_wav_float, padding_mask=audio_padding_mask, feature_only=True
                )

                # Surface NaN outputs immediately rather than propagating them.
                if torch.isnan(audio_embeds).any():
                    raise ValueError(
                        "BEATs model produced NaN values - this indicates a bug in the model or input data"
                    )
            except Exception as e:
                print(f"ERROR - BEATs model failed: {e}")
                print("Falling back to Whisper-only mode")
                # Zero embeddings with BEATs' 1024-dim feature size so the
                # concatenation below still works.
                audio_embeds = torch.zeros(
                    speech_embeds.shape[0],
                    speech_embeds.shape[1],
                    1024,
                    device=speech_embeds.device,
                    dtype=speech_embeds.dtype,
                )

        # Pad the shorter stream along the time axis so both align, then fuse.
        if audio_embeds.size(1) < speech_embeds.size(1):
            audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
        elif audio_embeds.size(1) > speech_embeds.size(1):
            speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
        speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
        speech_embeds = speech_embeds.to(torch.bfloat16)

        # Final safety net: never hand NaN embeddings downstream.
        if torch.isnan(speech_embeds).any():
            raise ValueError(
                "Final speech embeddings contain NaN values - this indicates a bug in the speech encoder"
            )

        return speech_embeds