mazesmazes committed
Commit f61bf72 · verified · 1 Parent(s): b50e976

Delete moe_projector.py

Files changed (1):
  moe_projector.py +0 -162
moe_projector.py DELETED
@@ -1,162 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers.models.llama.modeling_llama import LlamaRMSNorm


class SimpleAdapter(nn.Module):
    """
    MOSA Section III-B:
    "consists of two linear layers with a ReLU activation in between,
    projecting the hidden dimension from 3072 to 4096 and back to 3072."
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)

class MoEAudioProjector(nn.Module):
    """
    MOSA-style projector: Mixture of Simple Adapters.

    From the paper (arXiv:2508.18998):
    - Dense mixture (softmax over ALL experts) instead of sparse Top-K
    - Simple Linear->ReLU->Linear adapters (3072->4096->3072)
    - No auxiliary losses - just cross-entropy on transcripts
    - Conv downsampling: stride 4 total (two conv layers, stride 2 each)
    """

    def __init__(self, config):
        super().__init__()

        # Dimensions:
        #   Whisper-large-v3 encoder_dim = 1280
        #   SmolLM3-3B hidden_size = 2048
        self.encoder_dim = config.encoder_dim  # 1280
        self.llm_dim = config.llm_dim  # 2048

        # Number of experts: Base=4, Large=8
        self.num_experts = getattr(config, "num_experts", 4)

        # Adapter hidden dim: the paper uses 4096
        adapter_hidden = getattr(config, "projector_hidden_dim", None) or 4096

        # Dropout rate for experts (not applied to the router)
        self.dropout_rate = getattr(config, "projector_dropout", 0.1)

        # --- Convolutional Subsampling (Section III-B) ---
        # "two convolutional layers, each with a kernel size of 3 and a stride of 2"
        # Maps encoder_dim (1280) -> llm_dim (2048 here; 3072 in the paper), total stride=4
        self.conv = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
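        # Each conv maps T -> ceil(T / 2) (kernel=3, stride=2, padding=1), so the
        # pair maps T -> T // 4 once T has been padded to a multiple of 4.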

        # --- Router (Section III-B) ---
        # Base: "two linear layers... mapping from 1280 to 512 and finally to 4"
        router_hidden = 512
        self.router = nn.Sequential(
            nn.Linear(self.encoder_dim, router_hidden),
            nn.ReLU(),
            nn.Linear(router_hidden, self.num_experts),
        )

        # --- Experts / Adapters (Section III-B) ---
        # "projecting the hidden dimension from 3072 to 4096 and back to 3072"
        self.experts = nn.ModuleList(
            [
                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim, dropout=self.dropout_rate)
                for _ in range(self.num_experts)
            ]
        )
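        # Rough size: each adapter carries ~2 * llm_dim * adapter_hidden weights
        # (~16.8M params at 2048 -> 4096 -> 2048), so the 4-expert Base adds ~67M.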

        # Normalization for stability (not in the original MOSA, but it prevents
        # floating-point exceptions downstream)
        self.ln_post = LlamaRMSNorm(self.llm_dim, eps=1e-6)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        std = 0.02
        with torch.no_grad():
            # Conv layers
            for module in self.conv:
                if isinstance(module, nn.Conv1d):
                    nn.init.normal_(module.weight, mean=0.0, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

            # Router
            for module in self.router:
                if isinstance(module, nn.Linear):
                    nn.init.normal_(module.weight, mean=0.0, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

            # Experts
            for expert in self.experts:
                nn.init.normal_(expert.fc1.weight, mean=0.0, std=std)
                nn.init.normal_(expert.fc2.weight, mean=0.0, std=std)
                if expert.fc1.bias is not None:
                    nn.init.zeros_(expert.fc1.bias)
                if expert.fc2.bias is not None:
                    nn.init.zeros_(expert.fc2.bias)

            # LayerNorm
            self.ln_post.weight.data.fill_(1.0)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, encoder_dim] from the Whisper encoder (1280)

        Returns:
            output: [batch_size, seq_len // 4, llm_dim]
        """
        batch_size, seq_len, _ = x.shape

        # Pad to be divisible by the stride (4)
        pad_amt = (4 - (seq_len % 4)) % 4
        if pad_amt > 0:
            x = F.pad(x, (0, 0, 0, pad_amt))
            seq_len = x.shape[1]
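        # (F.pad pads dims last-to-first: (0, 0) leaves the feature dim alone and
        # (0, pad_amt) right-pads the time dim, e.g. T=1499 -> 1500.)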

        # 1. Convolutional downsampling:
        # (B, T, C) -> (B, C, T) -> conv -> (B, C, T//4) -> (B, T//4, C)
        h_conv = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)

        # 2. Router on the high-resolution input, then downsample the weights
        router_logits = self.router(x)  # [B, T, num_experts]
        # Average over each stride-4 window to match the conv output length
        router_logits = router_logits.view(
            batch_size, seq_len // 4, 4, self.num_experts
        ).mean(dim=2)
        # Dense softmax over ALL experts (no Top-K)
        routing_weights = F.softmax(router_logits, dim=-1)  # [B, T//4, num_experts]
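        # Shape walk-through for a 30 s clip (Whisper-large-v3 emits T=1500 frames):
        #   router(x): [B, 1500, E] -> view: [B, 375, 4, E] -> mean(dim=2): [B, 375, E],
        #   which lines up with h_conv: [B, 375, llm_dim].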

        # 3. Weighted sum of expert outputs (Eq. 2: y = sum_i(w_i * E_i(x)))
        # Use in-place add to reduce memory allocations
        final_out = torch.zeros_like(h_conv)
        for i, expert in enumerate(self.experts):
            expert_out = expert(h_conv)
            expert_weight = routing_weights[:, :, i : i + 1]
            final_out.add_(expert_out * expert_weight)

        return self.ln_post(final_out)

    def get_aux_loss(self) -> torch.Tensor:
        """Return auxiliary loss (none for dense MoE - all experts always used)."""
        return torch.tensor(0.0)
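
For reference, a minimal smoke-test sketch of how the deleted module was driven. The config fields mirror the ones __init__ reads above; the SimpleNamespace config and the random tensor standing in for Whisper encoder states are illustrative assumptions, not code from this repo:

from types import SimpleNamespace

import torch

config = SimpleNamespace(
    encoder_dim=1280,           # Whisper-large-v3 encoder width
    llm_dim=2048,               # SmolLM3-3B hidden size
    num_experts=4,              # "Base" variant
    projector_hidden_dim=4096,  # adapter hidden dim from the paper
    projector_dropout=0.1,
)
projector = MoEAudioProjector(config)

# Stand-in for Whisper encoder output: [batch, frames, encoder_dim]
audio_feats = torch.randn(2, 1500, config.encoder_dim)
out = projector(audio_feats)
print(out.shape)  # torch.Size([2, 375, 2048])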