Training in progress - step 1000
Browse files- asr_modeling.py +4 -1
- chat_template.jinja +94 -89
- projectors.py +155 -136
- tokenizer.json +2 -2
- tokenizer_config.json +9 -7
asr_modeling.py
CHANGED
|
@@ -419,7 +419,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
|
|
| 419 |
# Compute per-sample encoder output lengths using conv formulas
|
| 420 |
encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
|
| 421 |
token_counts = torch.tensor(
|
| 422 |
-
[
|
|
|
|
|
|
|
|
|
|
| 423 |
device=audio_embeds.device,
|
| 424 |
)
|
| 425 |
|
|
|
|
| 419 |
# Compute per-sample encoder output lengths using conv formulas
|
| 420 |
encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
|
| 421 |
token_counts = torch.tensor(
|
| 422 |
+
[
|
| 423 |
+
self.projector.get_output_length(int(length.item()))
|
| 424 |
+
for length in encoder_lengths
|
| 425 |
+
],
|
| 426 |
device=audio_embeds.device,
|
| 427 |
)
|
| 428 |
|
chat_template.jinja
CHANGED
|
@@ -1,89 +1,94 @@
|
|
| 1 |
-
{
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
{%-
|
| 17 |
-
{%- set
|
| 18 |
-
{%-
|
| 19 |
-
{%- set
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
{%-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
{%-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
{%-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
{
|
| 77 |
-
{
|
| 78 |
-
|
| 79 |
-
{%-
|
| 80 |
-
{{
|
| 81 |
-
{%- endif
|
| 82 |
-
|
| 83 |
-
{%-
|
| 84 |
-
{
|
| 85 |
-
{
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
{%-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{# ───── defaults ───── #}
|
| 2 |
+
{%- if enable_thinking is not defined -%}
|
| 3 |
+
{%- set enable_thinking = true -%}
|
| 4 |
+
{%- endif -%}
|
| 5 |
+
|
| 6 |
+
{# ───── reasoning mode ───── #}
|
| 7 |
+
{%- if enable_thinking -%}
|
| 8 |
+
{%- set reasoning_mode = "/think" -%}
|
| 9 |
+
{%- else -%}
|
| 10 |
+
{%- set reasoning_mode = "/no_think" -%}
|
| 11 |
+
{%- endif -%}
|
| 12 |
+
|
| 13 |
+
{# ───── header (system message) ───── #}
|
| 14 |
+
{{- "<|im_start|>system\n" -}}
|
| 15 |
+
|
| 16 |
+
{%- if messages[0].role == "system" -%}
|
| 17 |
+
{%- set system_message = messages[0].content -%}
|
| 18 |
+
{%- if "/no_think" in system_message -%}
|
| 19 |
+
{%- set reasoning_mode = "/no_think" -%}
|
| 20 |
+
{%- elif "/think" in system_message -%}
|
| 21 |
+
{%- set reasoning_mode = "/think" -%}
|
| 22 |
+
{%- endif -%}
|
| 23 |
+
{%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
|
| 24 |
+
{%- endif -%}
|
| 25 |
+
|
| 26 |
+
{%- if "/system_override" in system_message -%}
|
| 27 |
+
{{- custom_instructions.replace("/system_override", "").rstrip() -}}
|
| 28 |
+
{{- "<|im_end|>\n" -}}
|
| 29 |
+
{%- else -%}
|
| 30 |
+
{{- "## Metadata\n\n" -}}
|
| 31 |
+
{{- "Knowledge Cutoff Date: June 2025\n" -}}
|
| 32 |
+
{%- set today = strftime_now("%d %B %Y") -%}
|
| 33 |
+
{{- "Today Date: " ~ today ~ "\n" -}}
|
| 34 |
+
{{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
|
| 35 |
+
|
| 36 |
+
{{- "## Custom Instructions\n\n" -}}
|
| 37 |
+
{%- if custom_instructions -%}
|
| 38 |
+
{{- custom_instructions + "\n\n" -}}
|
| 39 |
+
{%- elif reasoning_mode == "/think" -%}
|
| 40 |
+
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
|
| 41 |
+
{%- else -%}
|
| 42 |
+
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
|
| 43 |
+
{%- endif -%}
|
| 44 |
+
|
| 45 |
+
{%- if xml_tools or python_tools or tools -%}
|
| 46 |
+
{{- "### Tools\n\n" -}}
|
| 47 |
+
{%- if xml_tools or tools -%}
|
| 48 |
+
{%- if tools -%}
|
| 49 |
+
{%- set xml_tools = tools -%}
|
| 50 |
+
{%- endif -%}
|
| 51 |
+
{%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
|
| 52 |
+
{%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
|
| 53 |
+
{%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
|
| 54 |
+
{%- endfor -%}
|
| 55 |
+
{%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
|
| 56 |
+
{{- xml_tool_string -}}
|
| 57 |
+
{%- endif -%}
|
| 58 |
+
{%- if python_tools -%}
|
| 59 |
+
{%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
|
| 60 |
+
{%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
|
| 61 |
+
{%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
|
| 62 |
+
{%- endfor -%}
|
| 63 |
+
{%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
|
| 64 |
+
{{- python_tool_string -}}
|
| 65 |
+
{%- endif -%}
|
| 66 |
+
{{- "\n\n" -}}
|
| 67 |
+
{{- "<|im_end|>\n" -}}
|
| 68 |
+
{%- endif -%}
|
| 69 |
+
{%- endif -%}
|
| 70 |
+
{# ───── main loop ───── #}
|
| 71 |
+
{%- for message in messages -%}
|
| 72 |
+
{%- set content = message.content if message.content is string else "" -%}
|
| 73 |
+
{%- if message.role == "user" -%}
|
| 74 |
+
{{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
|
| 75 |
+
{%- elif message.role == "assistant" -%}
|
| 76 |
+
{% generation %}
|
| 77 |
+
{%- if reasoning_mode == "/think" -%}
|
| 78 |
+
{{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
|
| 79 |
+
{%- else -%}
|
| 80 |
+
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
|
| 81 |
+
{%- endif -%}
|
| 82 |
+
{% endgeneration %}
|
| 83 |
+
{%- elif message.role == "tool" -%}
|
| 84 |
+
{{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
|
| 85 |
+
{%- endif -%}
|
| 86 |
+
{%- endfor -%}
|
| 87 |
+
{# ───── generation prompt ───── #}
|
| 88 |
+
{%- if add_generation_prompt -%}
|
| 89 |
+
{%- if reasoning_mode == "/think" -%}
|
| 90 |
+
{{ "<|im_start|>assistant\n" }}
|
| 91 |
+
{%- else -%}
|
| 92 |
+
{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
|
| 93 |
+
{%- endif -%}
|
| 94 |
+
{%- endif -%}
|
projectors.py
CHANGED
|
@@ -87,6 +87,34 @@ class SimpleAdapter(nn.Module):
|
|
| 87 |
return self.fc2(self.act(self.fc1(x)))
|
| 88 |
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
class MOSAProjector(nn.Module):
|
| 91 |
"""MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
|
| 92 |
|
|
@@ -166,109 +194,18 @@ class MOSAProjector(nn.Module):
|
|
| 166 |
|
| 167 |
|
| 168 |
# =============================================================================
|
| 169 |
-
# MoE Projector (
|
| 170 |
# =============================================================================
|
| 171 |
|
| 172 |
|
| 173 |
-
class
|
| 174 |
-
"""MoE
|
| 175 |
-
|
| 176 |
-
def __init__(
|
| 177 |
-
self,
|
| 178 |
-
input_dim: int,
|
| 179 |
-
hidden_dim: int,
|
| 180 |
-
output_dim: int,
|
| 181 |
-
num_experts: int = 4,
|
| 182 |
-
top_k: int = 2,
|
| 183 |
-
):
|
| 184 |
-
super().__init__()
|
| 185 |
-
self.num_experts = num_experts
|
| 186 |
-
self.top_k = top_k
|
| 187 |
-
self.output_dim = output_dim
|
| 188 |
-
|
| 189 |
-
# RMSNorm before routing
|
| 190 |
-
self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
|
| 191 |
-
|
| 192 |
-
self.router = nn.Linear(input_dim, num_experts, bias=False)
|
| 193 |
-
nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
|
| 194 |
-
|
| 195 |
-
self.shared_expert = SimpleAdapter(input_dim, hidden_dim, output_dim)
|
| 196 |
-
self.experts = nn.ModuleList(
|
| 197 |
-
[SimpleAdapter(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
self.last_router_logits = None
|
| 201 |
-
self.last_router_probs = None
|
| 202 |
-
|
| 203 |
-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 204 |
-
batch_size, seq_len, dim = hidden_states.shape
|
| 205 |
-
|
| 206 |
-
# 1. Apply Shared Expert
|
| 207 |
-
normed_states = self.norm(hidden_states)
|
| 208 |
-
shared_out = self.shared_expert(normed_states)
|
| 209 |
-
|
| 210 |
-
# 2. Router Logic (Sigmoid Style)
|
| 211 |
-
flat_hidden = normed_states.view(-1, dim)
|
| 212 |
-
router_logits = self.router(flat_hidden)
|
| 213 |
-
|
| 214 |
-
# Sigmoid routing
|
| 215 |
-
router_probs = torch.sigmoid(router_logits)
|
| 216 |
-
|
| 217 |
-
self.last_router_logits = router_logits
|
| 218 |
-
self.last_router_probs = router_probs
|
| 219 |
-
|
| 220 |
-
# 3. Top-K Selection
|
| 221 |
-
top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
| 222 |
-
|
| 223 |
-
# Normalize weights
|
| 224 |
-
top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
|
| 225 |
-
top_k_weights = top_k_weights.to(hidden_states.dtype)
|
| 226 |
-
|
| 227 |
-
# 4. Dispatch
|
| 228 |
-
routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
|
| 229 |
-
routed_out = routed_out.view(batch_size, seq_len, -1)
|
| 230 |
-
|
| 231 |
-
return shared_out + routed_out
|
| 232 |
-
|
| 233 |
-
def _dispatch_experts(
|
| 234 |
-
self,
|
| 235 |
-
hidden_states: torch.Tensor,
|
| 236 |
-
top_k_indices: torch.Tensor,
|
| 237 |
-
top_k_weights: torch.Tensor,
|
| 238 |
-
) -> torch.Tensor:
|
| 239 |
-
num_tokens = hidden_states.shape[0]
|
| 240 |
-
output = torch.zeros(
|
| 241 |
-
num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
|
| 242 |
-
)
|
| 243 |
-
|
| 244 |
-
for expert_idx, expert in enumerate(self.experts):
|
| 245 |
-
expert_mask = top_k_indices == expert_idx
|
| 246 |
-
if not expert_mask.any():
|
| 247 |
-
continue
|
| 248 |
-
|
| 249 |
-
token_indices, slot_indices = torch.where(expert_mask)
|
| 250 |
-
expert_input = hidden_states[token_indices]
|
| 251 |
-
expert_output = expert(expert_input).to(output.dtype)
|
| 252 |
-
weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
|
| 253 |
-
output.index_add_(0, token_indices, expert_output * weights)
|
| 254 |
-
|
| 255 |
-
return output
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
|
| 259 |
-
"""Auxiliary loss to encourage balanced expert usage."""
|
| 260 |
-
prob_per_expert = router_probs.mean(dim=0)
|
| 261 |
-
target_mean = prob_per_expert.mean()
|
| 262 |
-
return (prob_per_expert - target_mean).square().sum() * num_experts
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
|
| 266 |
-
"""Z-loss to prevent router logits from growing too large."""
|
| 267 |
-
return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
|
| 268 |
|
|
|
|
|
|
|
| 269 |
|
| 270 |
-
|
| 271 |
-
"""
|
| 272 |
|
| 273 |
def __init__(self, config):
|
| 274 |
"""Initialize MoE projector.
|
|
@@ -279,40 +216,59 @@ class MoEAudioProjector(nn.Module):
|
|
| 279 |
super().__init__()
|
| 280 |
|
| 281 |
self.k = getattr(config, "projector_pool_stride", 4)
|
| 282 |
-
|
| 283 |
|
| 284 |
-
#
|
| 285 |
-
self.
|
| 286 |
-
|
| 287 |
-
)
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
-
in_dim = encoder_dim * self.k
|
| 290 |
out_dim = config.llm_dim
|
| 291 |
-
hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
|
| 292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
self.num_experts = getattr(config, "num_experts", 4)
|
| 294 |
self.top_k = getattr(config, "num_experts_per_tok", 2)
|
| 295 |
-
self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
|
| 296 |
-
self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
|
| 297 |
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
self._init_weights()
|
| 300 |
|
|
|
|
|
|
|
| 301 |
def _init_weights(self):
|
|
|
|
| 302 |
with torch.no_grad():
|
| 303 |
-
|
| 304 |
-
nn.init.
|
| 305 |
|
| 306 |
-
for
|
| 307 |
-
|
| 308 |
-
nn.init.
|
|
|
|
| 309 |
|
| 310 |
def get_output_length(self, input_length: int) -> int:
|
| 311 |
-
"""Calculate output sequence length given input length."""
|
| 312 |
-
|
| 313 |
-
if input_length % self.k:
|
| 314 |
-
input_length += self.k - input_length % self.k
|
| 315 |
-
return input_length // self.k
|
| 316 |
|
| 317 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 318 |
"""Project audio features using shared + sparse MoE.
|
|
@@ -323,32 +279,95 @@ class MoEAudioProjector(nn.Module):
|
|
| 323 |
Returns:
|
| 324 |
Projected features of shape [batch, out_len, llm_dim]
|
| 325 |
"""
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
-
#
|
| 333 |
-
|
| 334 |
-
x_ctx = self.temporal_conv(x_ctx)
|
| 335 |
-
x = x + x_ctx.transpose(1, 2)
|
| 336 |
|
| 337 |
-
|
| 338 |
-
|
| 339 |
|
| 340 |
-
|
|
|
|
| 341 |
|
| 342 |
-
|
|
|
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
-
|
| 349 |
-
|
|
|
|
| 350 |
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
|
| 354 |
# =============================================================================
|
|
|
|
| 87 |
return self.fc2(self.act(self.fc1(x)))
|
| 88 |
|
| 89 |
|
| 90 |
+
class SwiGLU(nn.Module):
    """Gated feed-forward block computing ``w3(silu(w1(x)) * w2(x))``.

    ``w1`` (gate) and ``w2`` (value) both project ``dim -> hidden_dim``;
    ``w3`` projects the gated product back to ``dim``, so the last dimension
    of the output equals the last dimension of the input.

    Args:
        dim: Size of the input (and output) feature dimension.
        hidden_dim: Size of the intermediate gated representation.
        bias: Whether the three linear layers carry a bias term.
    """

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
        super().__init__()
        # Creation order (w1, w2, w3) is kept so parameter registration and
        # default initialisation consume the RNG in the same sequence.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # Gate
        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)  # Value
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)  # Output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SiLU-gated projection to the last dimension of ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class AsymmetricSwiGLU(nn.Module):
    """SwiGLU block whose output width may differ from its input width.

    Computes ``w3(silu(w1(x)) * w2(x))`` where ``w1`` (gate) and ``w2``
    (value) project ``in_features -> hidden_features`` and ``w3`` projects
    ``hidden_features -> out_features``.

    Args:
        in_features: Size of the input feature dimension.
        hidden_features: Size of the intermediate gated representation.
        out_features: Size of the output feature dimension.
        bias: Whether the three linear layers carry a bias term.
    """

    def __init__(
        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
    ):
        super().__init__()
        # Creation order (w1, w2, w3) is kept so parameter registration and
        # default initialisation consume the RNG in the same sequence.
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)  # Gate
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)  # Value
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # Output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SiLU-gated projection to the last dimension of ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
class MOSAProjector(nn.Module):
|
| 119 |
"""MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
|
| 120 |
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
# =============================================================================
|
| 197 |
+
# MoE Projector (Pure PyTorch with Shared Expert)
|
| 198 |
# =============================================================================
|
| 199 |
|
| 200 |
|
| 201 |
+
class MoEAudioProjector(nn.Module):
|
| 202 |
+
"""MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
+
Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
|
| 205 |
+
No external dependencies (megablocks removed).
|
| 206 |
|
| 207 |
+
Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
|
| 208 |
+
"""
|
| 209 |
|
| 210 |
def __init__(self, config):
|
| 211 |
"""Initialize MoE projector.
|
|
|
|
| 216 |
super().__init__()
|
| 217 |
|
| 218 |
self.k = getattr(config, "projector_pool_stride", 4)
|
| 219 |
+
self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)
|
| 220 |
|
| 221 |
+
# Stability coefficients
|
| 222 |
+
self.router_z_loss_coef = getattr(
|
| 223 |
+
config, "router_z_loss_coef", 1e-4
|
| 224 |
+
) # Prevents logit explosion
|
| 225 |
+
self.router_jitter_noise = getattr(
|
| 226 |
+
config, "router_jitter_noise", 0.01
|
| 227 |
+
) # Prevents expert collapse
|
| 228 |
|
| 229 |
+
in_dim = config.encoder_dim * self.k
|
| 230 |
out_dim = config.llm_dim
|
|
|
|
| 231 |
|
| 232 |
+
# Expert hidden dim (default = output dim)
|
| 233 |
+
hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim
|
| 234 |
+
|
| 235 |
+
# Number of experts and top-k selection
|
| 236 |
self.num_experts = getattr(config, "num_experts", 4)
|
| 237 |
self.top_k = getattr(config, "num_experts_per_tok", 2)
|
|
|
|
|
|
|
| 238 |
|
| 239 |
+
# A. Normalize stacked input (like main branch SharedMoEBlock)
|
| 240 |
+
self.norm = LlamaRMSNorm(in_dim, eps=1e-6)
|
| 241 |
+
|
| 242 |
+
# B. Router (operates on stacked input)
|
| 243 |
+
self.router = nn.Linear(in_dim, self.num_experts, bias=False)
|
| 244 |
+
|
| 245 |
+
# C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
|
| 246 |
+
self.experts = nn.ModuleList(
|
| 247 |
+
[SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# D. Shared Expert (same architecture)
|
| 251 |
+
self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)
|
| 252 |
+
|
| 253 |
+
# E. Initialize weights for stable training
|
| 254 |
self._init_weights()
|
| 255 |
|
| 256 |
+
self.last_aux_loss = torch.tensor(0.0)
|
| 257 |
+
|
| 258 |
def _init_weights(self):
    """Re-initialise the router and every expert's ``fc1``/``fc2`` weights.

    The router weight is drawn from N(0, 0.02). For the shared expert and
    then each sparse expert (in list order), ``fc1.weight`` gets Xavier
    uniform init and ``fc2.weight`` is drawn from N(0, 0.01). Biases are
    not touched here.
    """
    with torch.no_grad():
        # Router: small weights so initial logits are close to each other.
        nn.init.normal_(self.router.weight, mean=0.0, std=0.02)

        # Shared expert first, then the sparse experts -- the order fixes
        # how the RNG stream is consumed.
        adapters = [self.shared_expert]
        adapters.extend(self.experts)
        for adapter in adapters:
            nn.init.xavier_uniform_(adapter.fc1.weight)
            # Small init on the output layer.
            nn.init.normal_(adapter.fc2.weight, mean=0.0, std=0.01)
|
| 268 |
|
| 269 |
def get_output_length(self, input_length: int) -> int:
    """Return the number of output frames produced for ``input_length`` input frames.

    Frames are grouped into non-overlapping windows of ``self.k``; a trailing
    partial window is dropped, so inputs shorter than ``self.k`` yield 0.
    """
    stride = self.k
    windows_after_first = (input_length - stride) // stride
    return windows_after_first + 1
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 274 |
"""Project audio features using shared + sparse MoE.
|
|
|
|
| 279 |
Returns:
|
| 280 |
Projected features of shape [batch, out_len, llm_dim]
|
| 281 |
"""
|
| 282 |
+
# 1. Frame Stacking
|
| 283 |
+
batch, seq, dim = x.shape
|
| 284 |
+
out_len = (seq - self.k) // self.k + 1
|
| 285 |
+
x = x[:, : out_len * self.k, :]
|
| 286 |
+
x = x.reshape(batch, out_len, dim * self.k)
|
| 287 |
+
|
| 288 |
+
# 2. Normalize stacked input (like main branch SharedMoEBlock)
|
| 289 |
+
x = self.norm(x)
|
| 290 |
+
flat_x = x.view(-1, x.size(-1)) # [tokens, in_dim]
|
| 291 |
+
|
| 292 |
+
# 3. Shared Expert (compute first, creates output tensor)
|
| 293 |
+
output = self.shared_expert(flat_x)
|
| 294 |
+
|
| 295 |
+
# 4. Sparse Experts (in-place add to shared output)
|
| 296 |
+
self.last_aux_loss = self._forward_sparse(flat_x, output)
|
| 297 |
+
|
| 298 |
+
return output.view(batch, out_len, -1)
|
| 299 |
+
|
| 300 |
+
def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """Dispatch tokens to their top-k sparse experts, adding results into ``output``.

    Args:
        x: Flattened input of shape [tokens, dim].
        output: Tensor of shape [tokens, out_dim]. The routing-weighted
            expert outputs are added to it in place.

    Returns:
        Scalar float32 auxiliary loss: the load-balancing term scaled by
        ``self.aux_coef`` plus the z-loss scaled by
        ``self.router_z_loss_coef``. It is zero when the module is not in
        training mode or when ``x`` contains no tokens.
    """
    # A. Router logits, optionally jittered during training.
    logits = self.router(x)

    if self.training and self.router_jitter_noise > 0:
        # Multiply by uniform noise in (1 - eps, 1 + eps).
        noise = torch.empty_like(logits).uniform_(
            1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
        )
        logits = logits * noise

    # Softmax in float32 (bf16/fp16 exponentials can overflow). The float32
    # copy is kept for the auxiliary losses; routing uses the copy cast back
    # to the input dtype.
    probs_fp32 = torch.softmax(logits, dim=-1, dtype=torch.float32)
    probs = probs_fp32.type_as(x)

    # B. Top-k selection, renormalised so the selected weights sum to ~1.
    top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
    top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

    # C. Auxiliary losses, accumulated in float32.
    aux_loss = torch.zeros((), device=x.device, dtype=torch.float32)

    # Guard the empty case: mean() over zero tokens would return NaN.
    if self.training and x.shape[0] > 0:
        # Load balancing: squared deviation of the mean routing probability
        # per expert from the uniform target.
        prob_per_expert = probs_fp32.mean(0)  # [num_experts]
        target = 1.0 / self.num_experts
        balance_loss = (
            self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
        )

        # Z-loss: penalise large logits. Computed on float32 logits for the
        # same overflow reason as the softmax above.
        z_loss = (
            self.router_z_loss_coef * torch.logsumexp(logits.float(), dim=-1).pow(2).mean()
        )

        aux_loss = balance_loss + z_loss

    # D. Dispatch loop (in-place add to output).
    for i, expert in enumerate(self.experts):
        # token_idx = which tokens picked expert i, k_idx = which top-k slot.
        token_idx, k_idx = torch.where(top_k_indices == i)
        if token_idx.numel() == 0:
            continue

        expert_output = expert(x[token_idx])

        # Apply routing weight and match the accumulator's dtype.
        weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
        weighted_output = (expert_output * weight).type_as(output)

        # top-k indices are distinct per token, so token_idx holds each token
        # at most once here and no two rows are added into the same slot.
        output.index_add_(0, token_idx, weighted_output)

    return aux_loss
|
| 367 |
+
|
| 368 |
+
def get_aux_loss(self) -> torch.Tensor:
    """Return the auxiliary loss stored in ``self.last_aux_loss``.

    The stored tensor is returned as-is (no copy, no detach).
    """
    return self.last_aux_loss
|
| 371 |
|
| 372 |
|
| 373 |
# =============================================================================
|
tokenizer.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
|
| 3 |
+
size 17209003
|
tokenizer_config.json
CHANGED
|
@@ -1,17 +1,19 @@
|
|
| 1 |
{
|
| 2 |
-
"add_prefix_space": false,
|
| 3 |
"backend": "tokenizers",
|
| 4 |
"bos_token": null,
|
| 5 |
-
"clean_up_tokenization_spaces":
|
| 6 |
"eos_token": "<|im_end|>",
|
| 7 |
-
"errors": "replace",
|
| 8 |
"extra_special_tokens": [
|
| 9 |
"<audio>"
|
| 10 |
],
|
|
|
|
| 11 |
"is_local": false,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"model_max_length": 131072,
|
| 13 |
-
"
|
| 14 |
-
"
|
| 15 |
-
"tokenizer_class": "
|
| 16 |
-
"unk_token": null
|
| 17 |
}
|
|
|
|
| 1 |
{
|
|
|
|
| 2 |
"backend": "tokenizers",
|
| 3 |
"bos_token": null,
|
| 4 |
+
"clean_up_tokenization_spaces": true,
|
| 5 |
"eos_token": "<|im_end|>",
|
|
|
|
| 6 |
"extra_special_tokens": [
|
| 7 |
"<audio>"
|
| 8 |
],
|
| 9 |
+
"fast": false,
|
| 10 |
"is_local": false,
|
| 11 |
+
"model_input_names": [
|
| 12 |
+
"input_ids",
|
| 13 |
+
"attention_mask"
|
| 14 |
+
],
|
| 15 |
"model_max_length": 131072,
|
| 16 |
+
"model_specific_special_tokens": {},
|
| 17 |
+
"pad_token": "<|finetune_right_pad_id|>",
|
| 18 |
+
"tokenizer_class": "TokenizersBackend"
|
|
|
|
| 19 |
}
|