mazesmazes committed on
Commit d8c53de · verified · 1 Parent(s): 134fc67

Training in progress - step 1000

Files changed (2):
  1. asr_pipeline.py  +3 -0
  2. projectors.py    +76 -30
asr_pipeline.py CHANGED
@@ -485,6 +485,8 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if not text:
             return ""
 
+        original_len = len(text.split())
+
         # 1. LOWERCASE
         text = text.lower()
 
@@ -502,6 +504,7 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             if repeat_count >= 1:
                 words = words[: idx + n]
                 text = " ".join(words)
+                print(f"[DEBUG] Truncated repetition: {original_len} -> {len(words)} words (n={n}, repeats={repeat_count})")
                 break
 
         # 3. COMBINE ACRONYMS
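For context, the hunk above only instruments an existing normalization step: when a word n-gram is found to repeat, the hypothesis is cut back to the first occurrence and the new debug line reports how many words were dropped. Below is a minimal, self-contained sketch of that kind of check; the helper name truncate_repetition and the exact detection loop are assumptions for illustration, not the repo's actual code.

def truncate_repetition(text: str, max_n: int = 5, min_repeats: int = 1) -> str:
    """Cut a transcript back to one copy of a trailing repeated n-gram (illustrative sketch only)."""
    words = text.split()
    original_len = len(words)
    for n in range(max_n, 0, -1):  # prefer longer n-grams first
        for idx in range(max(0, original_len - 2 * n + 1)):
            gram = words[idx : idx + n]
            repeat_count, j = 0, idx + n
            while words[j : j + n] == gram:  # count consecutive repeats of the n-gram
                repeat_count += 1
                j += n
            if repeat_count >= min_repeats and j >= original_len:  # repetition runs to the end
                print(f"[DEBUG] Truncated repetition: {original_len} -> {idx + n} words (n={n}, repeats={repeat_count})")
                return " ".join(words[: idx + n])
    return text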
projectors.py CHANGED
@@ -1,9 +1,10 @@
 """Audio projector modules for bridging encoder and decoder embeddings.
 
 This module contains all projector architectures:
-- MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
-- MOSAProjector: MOSA-style dense mixture of experts
-- SharedMoEAudioProjector: Shared expert + sparse routed experts
+- LinearProjector: Simple avg pool + linear (Chinese Dialects paper, best for Stage 1)
+- MLPAudioProjector: 2-layer MLP with frame stacking downsampling
+- MOSAProjector: MOSA-style dense mixture of experts (arXiv:2508.18998)
+- MoEAudioProjector: Shared expert + sparse routed experts
 - QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
 """
 
@@ -15,6 +16,51 @@ import torch.nn.functional as F  # noqa: N812
 from transformers import AutoModel, Blip2QFormerConfig
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
 
+# =============================================================================
+# Linear Projector (Chinese Dialects paper style)
+# =============================================================================
+
+
+class LinearProjector(nn.Module):
+    """Simple linear projector with average pooling downsampling.
+
+    Based on Chinese Dialects paper (arXiv:2505.21138) which found this
+    outperformed Conv1D, Transformer, and Q-Former in Stage 1 (projector-only).
+
+    Architecture: AvgPool(4x) -> Linear(encoder_dim, llm_dim)
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.encoder_dim = getattr(config, "encoder_dim", 768)
+        self.llm_dim = getattr(config, "llm_dim", 2048)
+        self.pool_stride = getattr(config, "projector_pool_stride", 4)  # 4x = 12.5Hz
+
+        # Single linear projection (no hidden layers, no activation)
+        self.linear = nn.Linear(self.encoder_dim, self.llm_dim)
+
+    def get_output_length(self, input_length: int) -> int:
+        """Calculate output sequence length given input length."""
+        return input_length // self.pool_stride
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x: [batch, seq_len, encoder_dim]
+
+        Returns:
+            [batch, seq_len // pool_stride, llm_dim]
+        """
+        # Average pooling for downsampling (better than frame stacking for linear)
+        # Transpose for avg_pool1d: [B, S, D] -> [B, D, S]
+        x = x.transpose(1, 2)
+        x = F.avg_pool1d(x, kernel_size=self.pool_stride, stride=self.pool_stride)
+        x = x.transpose(1, 2)  # [B, S//k, D]
+
+        return self.linear(x)
+
+
 # =============================================================================
 # MLP Projector
 # =============================================================================
@@ -63,12 +109,12 @@ class MLPAudioProjector(nn.Module):
 
 
 class SimpleAdapter(nn.Module):
-    """Simple 2-layer GELU adapter (from MOSA paper)."""
+    """Simple 2-layer ReLU adapter (from MOSA paper, arXiv:2508.18998)."""
 
     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
-        self.act = nn.GELU()
+        self.act = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, output_dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -93,24 +139,27 @@ class MOSAProjector(nn.Module):
 
     Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
     Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
-    Uses frame-stacking for downsampling (like MLP projector).
+    Uses conv-based downsampling (2x Conv1d stride-2) as described in the paper.
     """
 
     def __init__(self, config):
         super().__init__()
         self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
         self.llm_dim = getattr(config, "llm_dim", None) or 2048
-        self.k = getattr(config, "projector_pool_stride", 4)
         self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
         adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
 
-        # Optional pre-norm before projection
-        self.use_pre_norm = getattr(config, "projector_pre_norm", False)
-        if self.use_pre_norm:
-            self.pre_norm = LlamaRMSNorm(self.encoder_dim, eps=1e-8)
+        # --- Conv-based downsampling (paper: 2x Conv1d, kernel=3, stride=2) ---
+        # Total 4x downsampling: 50Hz -> 12.5Hz
+        self.conv_downsample = nn.Sequential(
+            nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+        )
 
-        # Frame stacking: concat k adjacent frames then project
-        in_dim = self.encoder_dim * self.k
+        # Input dim to adapters is now just encoder_dim (not encoder_dim * k)
+        in_dim = self.encoder_dim
 
         # --- 1. Simple Router (MOSA-Base: 2 layers with ReLU) ---
         # Maps encoder_dim -> 512 -> num_experts
@@ -122,7 +171,7 @@
         )
 
         # --- 2. Experts (Simple 2-layer GELU adapters) ---
-        # Each expert: in_dim (stacked frames) -> hidden -> llm_dim
+        # Each expert: encoder_dim -> hidden -> llm_dim
         self.experts = nn.ModuleList(
             [SimpleAdapter(in_dim, adapter_hidden, self.llm_dim) for _ in range(self.num_experts)]
         )
@@ -131,31 +180,27 @@
         # x: (B, S, encoder_dim)
         batch_size, seq_len, dim = x.shape
 
-        # Apply pre-norm if enabled
-        if self.use_pre_norm:
-            x = self.pre_norm(x)
-
-        # --- 1. Router Branch ---
-        # Mean pool encoder outputs for routing decisions
-        x_pooled = x.reshape(batch_size, -1, self.k, self.encoder_dim).mean(dim=2)  # (B, S//k, D)
+        # --- 1. Conv downsampling ---
+        # Conv1d expects (B, C, S), so transpose
+        x_conv = x.transpose(1, 2)  # (B, D, S)
+        x_conv = self.conv_downsample(x_conv)  # (B, D, S//4)
+        x_downsampled = x_conv.transpose(1, 2)  # (B, S//4, D)
 
-        # Router logits and softmax gating (dense MoE)
-        routing_weights = F.softmax(self.router(x_pooled), dim=-1)  # (B, S//k, num_experts)
-
-        # --- 2. Frame stacking for experts ---
-        # Reshape to combine k frames: [B, S, D] -> [B, S//k, D*k]
-        x_stacked = x.reshape(batch_size, -1, dim * self.k)
+        # --- 2. Router Branch ---
+        # Router operates on downsampled features
+        routing_weights = F.softmax(self.router(x_downsampled), dim=-1)  # (B, S//4, num_experts)
 
         # --- 3. Expert Mixture (Dense Execution) ---
         # Run all experts and compute weighted sum
         expert_outputs = torch.stack(
-            [expert(x_stacked) for expert in self.experts]
-        )  # (E, B, S//k, D)
+            [expert(x_downsampled) for expert in self.experts]
+        )  # (E, B, S//4, D)
         return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length."""
-        return input_length // self.k
+        # Two stride-2 convs = 4x downsampling
+        return input_length // 4
 
 
 # =============================================================================
@@ -452,6 +497,7 @@ class QFormerAudioProjector(nn.Module):
 # =============================================================================
 
 PROJECTOR_CLASSES = {
+    "linear": LinearProjector,
     "mlp": MLPAudioProjector,
     "mosa": MOSAProjector,
     "moe": MoEAudioProjector,