mazesmazes committed on
Commit
26f7ea5
·
verified ·
1 Parent(s): a9a2c8b

Delete residual_projector.py

Browse files
Files changed (1) hide show
  1. residual_projector.py +0 -153
residual_projector.py DELETED
@@ -1,153 +0,0 @@
1
- """Residual MLP projector for Whisper → LLM feature space translation.
2
-
3
- Philosophy: Whisper features are already information-complete. The projector
4
- learns a nonlinear correction/refinement to align them with the LLM's expected
5
- input distribution, rather than replacing them entirely.
6
- """
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F # noqa: N812
11
-
12
-
13
class ResidualMLP(nn.Module):
    """Two-layer feed-forward block wrapped in a skip connection.

    Computes ``x + dropout(fc2(dropout(gelu(fc1(x)))))``.

    The closer the output of ``fc2`` is to zero, the closer the whole block is
    to the identity function. This class keeps the default ``nn.Linear``
    initialization; any near-zero initialization of ``fc2`` has to be applied
    by the owner of the block.

    Args:
        dim: Size of the input and output feature dimension.
        hidden_dim: Size of the intermediate (expanded) feature dimension.
        dropout: Dropout probability, applied after the activation and again
            after the second linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # fc1 is created before fc2 so that seeded initialization and the
        # state-dict key order stay stable.
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``x`` plus the learned correction; the shape is unchanged."""
        expanded = self.dropout(self.act(self.fc1(x)))
        correction = self.dropout(self.fc2(expanded))
        return x + correction
37
-
38
-
39
class ResidualAudioProjector(nn.Module):
    """Residual MLP projector for audio-to-LLM feature translation.

    Pipeline:
        1. Temporal pooling: ``k`` consecutive frames are concatenated along
           the feature axis (the sequence is zero-padded at the end so its
           length is a multiple of ``k``).
        2. Linear projection to the LLM dimension, followed by an RMSNorm.
        3. ``num_layers`` residual MLP blocks, each followed by its own
           RMSNorm (so the last norm also acts as the final normalization).
        4. Dropout on the output.

    The linear projection handles dimension matching, while the residual MLPs
    learn the nonlinear corrections on top of it.

    Config attributes read:
        encoder_dim: Feature size of the incoming encoder frames (required).
        llm_dim: Feature size expected by the LLM (required).
        projector_pool_stride: Number of frames ``k`` merged per output step
            (default 4). Must be at least 1.
        projector_hidden_dim: Hidden size of each residual MLP. A missing or
            falsy value selects ``llm_dim * 4``.
        projector_num_layers: Number of residual MLP blocks (default 2).
        projector_dropout: Dropout probability (default 0.0).
        projector_init_std: Std of the normal weight init (default 0.02).

    Raises:
        ValueError: If ``projector_pool_stride`` is smaller than 1.
    """

    def __init__(self, config):
        super().__init__()

        # Temporal downsampling factor
        self.k = getattr(config, "projector_pool_stride", 4)
        if self.k < 1:
            # Fail early: a stride of 0 would otherwise surface as a
            # ZeroDivisionError deep inside forward().
            raise ValueError(f"projector_pool_stride must be >= 1, got {self.k}")

        # Dimensions
        in_dim = config.encoder_dim * self.k  # After concatenating k frames
        out_dim = config.llm_dim
        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4

        # Number of residual blocks
        self.num_layers = getattr(config, "projector_num_layers", 2)

        dropout_rate = getattr(config, "projector_dropout", 0.0)

        # Imported lazily so that merely importing this module does not
        # require transformers.
        from transformers.models.llama.modeling_llama import LlamaRMSNorm

        # Initial projection: encoder_dim * k -> llm_dim
        self.input_proj = nn.Linear(in_dim, out_dim)
        self.ln_input = LlamaRMSNorm(out_dim, eps=1e-6)

        # Residual MLP blocks for nonlinear refinement
        self.layers = nn.ModuleList(
            [ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate) for _ in range(self.num_layers)]
        )

        # Per-layer norms (applied after each residual block)
        self.layer_norms = nn.ModuleList(
            [LlamaRMSNorm(out_dim, eps=1e-6) for _ in range(self.num_layers)]
        )

        self.output_dropout = nn.Dropout(dropout_rate)

        # Initialize for stable training
        self._init_weights(config)

    def _init_weights(self, config):
        """Initialize weights for stable residual learning.

        All linear weights are drawn from a zero-mean normal distribution with
        std ``projector_init_std`` (default 0.02), except ``fc2`` of every
        residual block, which uses a 10x smaller std so that each block starts
        close to the identity function. All biases are zeroed and all norm
        weights are set to one.
        """
        std = getattr(config, "projector_init_std", 0.02)

        with torch.no_grad():
            # Input projection: standard init
            nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
            if self.input_proj.bias is not None:
                nn.init.zeros_(self.input_proj.bias)

            # Norm scales start at one (no rescaling)
            nn.init.ones_(self.ln_input.weight)
            for ln in self.layer_norms:
                nn.init.ones_(ln.weight)

            # Residual blocks: small init on the output projection
            for layer in self.layers:
                nn.init.normal_(layer.fc1.weight, mean=0.0, std=std)
                # Initialize fc2 smaller so the residual starts near identity
                nn.init.normal_(layer.fc2.weight, mean=0.0, std=std * 0.1)
                if layer.fc1.bias is not None:
                    nn.init.zeros_(layer.fc1.bias)
                if layer.fc2.bias is not None:
                    nn.init.zeros_(layer.fc2.bias)

    def forward(self, x):
        """Project encoder frames into the LLM feature space.

        Args:
            x: Tensor of shape [batch_size, seq_len, encoder_dim]. It is cast
                to the dtype of the input projection weights if necessary.

        Returns:
            Tensor of shape [batch_size, ceil(seq_len / k), llm_dim]. The
            sequence is zero-padded at the end to a multiple of ``k`` before
            pooling, so a trailing partial group still yields one output step.
            Zero-length sequences and empty batches produce an empty output.
        """
        batch_size, seq_len, dim = x.size()

        # Ensure correct dtype
        target_dtype = self.input_proj.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Pad sequence (time axis, at the end) to be divisible by k
        remainder = seq_len % self.k
        if remainder:
            x = F.pad(x, (0, 0, 0, self.k - remainder))

        # Temporal pooling: concatenate k consecutive frames
        # [B, T, D] -> [B, T/k, D*k]
        # The pooled length is passed explicitly: view() cannot infer a -1
        # dimension when the tensor has zero elements.
        pooled_len = x.size(1) // self.k
        x = x.contiguous().view(batch_size, pooled_len, dim * self.k)

        # Project to LLM dimension
        x = self.input_proj(x)
        x = self.ln_input(x)

        # Apply residual MLP blocks, each followed by its norm
        for layer, ln in zip(self.layers, self.layer_norms):
            x = layer(x)
            x = ln(x)

        return self.output_dropout(x)