File size: 4,742 Bytes
62d115a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
# import types
import torch
import torch.nn as nn
import torch.nn.functional as F
# from transformers import WhisperFeatureExtractor
import whisper

# Import BEATs from local beats directory.
# The import is guarded so that a missing or broken local ``beats_model``
# module only disables the BEATs branch instead of making this whole module
# unimportable; ``BEATS_AVAILABLE`` records the outcome for later checks.
try:
    from beats_model import BEATsConfig, BEATs
    BEATS_AVAILABLE = True
    print("✅ BEATs imported successfully from local beats directory")
except ImportError as exc:
    # Keep the names defined so references elsewhere fail on a clear
    # ``None`` rather than a NameError.
    BEATsConfig = None
    BEATs = None
    BEATS_AVAILABLE = False
    print(f"BEATs import failed ({exc}) - BEATs encoder will be unavailable")

# Default Whisper cache location. ``setdefault`` is used so that a
# WHISPER_CACHE_DIR already exported by the user is respected rather than
# silently overwritten with this machine-specific path.
os.environ.setdefault("WHISPER_CACHE_DIR", "/data1/cxy/plm-v/modeling/cache")

def replace_layer_norm(module):
    """Recursively swap Whisper ``LayerNorm`` children for ``torch.nn.LayerNorm``.

    Every direct or nested child of ``module`` that is an instance of
    ``whisper.model.LayerNorm`` is replaced in place by a freshly built
    ``nn.LayerNorm`` with the same ``normalized_shape``, ``eps`` and
    ``elementwise_affine`` settings, and the old layer's state dict is loaded
    into the replacement. Children of other types are descended into.

    Args:
        module: The ``nn.Module`` whose sub-tree is rewritten in place.

    Returns:
        None. ``module`` is mutated.
    """
    from whisper.model import LayerNorm

    for child_name, submodule in module.named_children():
        if not isinstance(submodule, LayerNorm):
            # Not a Whisper LayerNorm: keep looking further down the tree.
            replace_layer_norm(submodule)
            continue

        replacement = nn.LayerNorm(
            submodule.normalized_shape,
            eps=submodule.eps,
            elementwise_affine=submodule.elementwise_affine,
        )
        # Carry the learned parameters (if any) over to the new layer.
        replacement.load_state_dict(submodule.state_dict())
        setattr(module, child_name, replacement)


class DualWrappedEncoder(nn.Module):
    """Audio encoder wrapper combining a Whisper encoder with an optional BEATs model.

    The Whisper encoder is always loaded. The BEATs model is loaded only when
    the config enables it and a checkpoint path is given; otherwise
    ``self.beats_model`` is ``None`` and ``forward`` returns Whisper features
    alone. ``forward`` runs entirely under ``torch.no_grad()``.
    """

    def __init__(self, config):
        """Build both sub-encoders from ``config``.

        Args:
            config: Object exposing the attributes read by ``load_whisper``
                and ``load_beats`` (``speech_encoder`` and, optionally,
                ``whisper_cache_dir``, ``whisper_config``, ``use_beats``,
                ``beats_model_path``).
        """
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Load the Whisper encoder named by ``model_config.speech_encoder`` on CPU.

        The download/cache directory is resolved with the priority
        ``model_config.whisper_cache_dir`` > ``WHISPER_CACHE_DIR`` environment
        variable > Whisper's own default. After loading, Whisper's custom
        ``LayerNorm`` layers are replaced via ``replace_layer_norm``.

        Args:
            model_config: Config object; must provide ``speech_encoder``.

        Returns:
            The ``.encoder`` sub-module of the loaded Whisper model.
        """
        download_root = (
            getattr(model_config, 'whisper_cache_dir', None)
            or os.environ.get('WHISPER_CACHE_DIR', None)
        )

        # A custom whisper_config is only reported for now; the model is still
        # loaded by name (this could be extended to honour the custom config).
        whisper_config = getattr(model_config, 'whisper_config', None)
        if whisper_config:
            print(f"Loading Whisper with custom config: {whisper_config.get('d_model', 'default')}")

        # Only pass download_root when one was resolved, so Whisper falls back
        # to its built-in default location otherwise.
        load_kwargs = {'name': model_config.speech_encoder, 'device': 'cpu'}
        if download_root:
            load_kwargs['download_root'] = download_root
        encoder = whisper.load_model(**load_kwargs).encoder

        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Load the optional BEATs model described by ``model_config``.

        Loading is best-effort: any failure is printed and results in ``None``
        rather than an exception.

        Args:
            model_config: Config object; ``use_beats`` (default ``False``) and
                ``beats_model_path`` (default ``None``) are read from it.

        Returns:
            A ``BEATs`` instance with checkpoint weights loaded, or ``None``
            when BEATs is disabled, unavailable, has no path configured, or
            failed to load.
        """
        # Check if BEATs should be used
        if not getattr(model_config, 'use_beats', False):
            print("BEATs model disabled in config")
            return None

        if not BEATS_AVAILABLE:
            print("BEATs not available - skipping music encoder")
            return None

        beats_path = getattr(model_config, 'beats_model_path', None)
        if not beats_path:
            print("No BEATs model path specified")
            return None

        try:
            print(f"Loading BEATs Model from {beats_path}")
            # NOTE(security): torch.load unpickles the file; only point
            # beats_model_path at checkpoints from a trusted source.
            beats_ckpt = torch.load(beats_path, map_location='cpu')
            beats_cfg = BEATsConfig(beats_ckpt['cfg'])
            beats = BEATs(beats_cfg)
            beats.load_state_dict(beats_ckpt['model'])
            print("BEATs model loaded successfully")
            return beats
        except Exception as e:
            # Deliberate best-effort: report and continue without BEATs.
            print(f"Failed to load BEATs model: {e}")
            return None

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode audio with Whisper and, when possible, fuse BEATs features.

        Args:
            x: Input passed directly to the Whisper encoder.
            raw_wav: Optional waveform tensor for BEATs; BEATs is skipped when
                this is ``None`` or no BEATs model was loaded.
            audio_padding_mask: Optional padding mask forwarded to
                ``BEATs.extract_features``.

        Returns:
            A ``torch.bfloat16`` tensor. With BEATs active, the Whisper and
            BEATs outputs are padded to a common size along dim 1 and
            concatenated on the last dimension.
            NOTE(review): the padding spec assumes 3-D
            ``(batch, time, feature)`` outputs from both encoders - confirm.
        """
        with torch.no_grad():
            speech_embeds = self.whisper_model(x)

            # Process with BEATs if available
            if self.beats_model is not None and raw_wav is not None:
                # BEATs is run in float32. nn.Module.float() casts in place and
                # returns the module itself, so no re-assignment is needed.
                self.beats_model.float()
                audio_embeds, _ = self.beats_model.extract_features(
                    raw_wav.float(),
                    padding_mask=audio_padding_mask,
                    feature_only=True,
                )

                # Align sequence lengths: pad the shorter one at the end of
                # its second-to-last dimension.
                speech_len = speech_embeds.size(1)
                audio_len = audio_embeds.size(1)
                if audio_len < speech_len:
                    audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_len - audio_len))
                elif audio_len > speech_len:
                    speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_len - speech_len))

                # Concatenate Whisper and BEATs features
                speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
                print(f"Combined Whisper + BEATs features: {speech_embeds.shape}")

            speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds