Training in progress - step 500
Browse files- model.safetensors +2 -2
- projectors.py +101 -133
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:720402f0e8107015c77907789cc8b20307741b0008412ad5e2cffc08462ae5c9
|
| 3 |
+
size 265642600
|
projectors.py
CHANGED
|
@@ -237,8 +237,8 @@ class SwiGLU(nn.Module):
|
|
| 237 |
|
| 238 |
class SwiGLUAudioProjector(nn.Module):
|
| 239 |
"""
|
| 240 |
-
|
| 241 |
-
|
| 242 |
"""
|
| 243 |
|
| 244 |
def __init__(self, config):
|
|
@@ -247,154 +247,44 @@ class SwiGLUAudioProjector(nn.Module):
|
|
| 247 |
encoder_dim = config.encoder_dim
|
| 248 |
llm_dim = config.llm_dim
|
| 249 |
|
| 250 |
-
#
|
| 251 |
-
|
| 252 |
-
hidden_dim = int(encoder_dim * 2)
|
| 253 |
|
| 254 |
-
#
|
| 255 |
-
|
| 256 |
-
swiglu_inner = int(hidden_dim * 8 / 3)
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
out_channels=hidden_dim,
|
| 261 |
-
kernel_size=self.k,
|
| 262 |
-
stride=self.k,
|
| 263 |
-
padding=0,
|
| 264 |
-
)
|
| 265 |
|
|
|
|
| 266 |
self.norm = LlamaRMSNorm(hidden_dim, eps=1e-8)
|
| 267 |
|
|
|
|
|
|
|
| 268 |
self.proj = SwiGLU(hidden_dim, swiglu_inner, llm_dim)
|
| 269 |
|
| 270 |
-
self.apply(self._init_weights)
|
| 271 |
-
|
| 272 |
-
def _init_weights(self, m):
|
| 273 |
-
if isinstance(m, (nn.Linear, nn.Conv1d)):
|
| 274 |
-
nn.init.trunc_normal_(m.weight, std=0.02)
|
| 275 |
-
if m.bias is not None:
|
| 276 |
-
nn.init.constant_(m.bias, 0)
|
| 277 |
-
|
| 278 |
def forward(self, x):
|
| 279 |
# x: [Batch, Seq, Dim]
|
| 280 |
batch, seq, dim = x.shape
|
| 281 |
|
| 282 |
-
#
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
| 286 |
|
| 287 |
-
# [B, S, D] -> [B,
|
| 288 |
-
x = x.
|
| 289 |
|
| 290 |
-
#
|
| 291 |
-
x = self.
|
| 292 |
-
|
| 293 |
-
# [B, D, S] -> [B, S, D]
|
| 294 |
-
x = x.transpose(1, 2)
|
| 295 |
|
| 296 |
-
# Norm &
|
| 297 |
x = self.norm(x)
|
| 298 |
return self.proj(x)
|
| 299 |
|
| 300 |
def get_output_length(self, input_length: int) -> int:
|
| 301 |
-
return (input_length
|
| 302 |
-
|
| 303 |
-
# =============================================================================
|
| 304 |
-
# Residual Projector
|
| 305 |
-
# =============================================================================
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
class ResidualMLP(nn.Module):
    """MLP block with residual connection: Output = x + MLP(x)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # NOTE: attribute names fc1/fc2 are part of the public surface --
        # ResidualAudioProjector._init_weights initializes them by name.
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()

    def forward(self, x):
        # Skip connection around the two-layer GELU MLP.
        return x + self.fc2(self.act(self.fc1(x)))
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
class ResidualAudioProjector(nn.Module):
    """Residual MLP projector for audio-to-LLM feature translation.

    Downsamples the encoder sequence by frame-stacking ``k`` consecutive
    frames, projects to the LLM hidden size, then applies ``num_layers``
    residual MLP blocks, each followed by RMSNorm.
    """

    def __init__(self, config):
        # config attributes read here: encoder_dim, llm_dim, and optionally
        # projector_pool_stride / projector_hidden_dim / projector_num_layers.
        super().__init__()

        # Temporal downsampling factor (frames stacked per output step).
        self.k = getattr(config, "projector_pool_stride", 4)
        in_dim = config.encoder_dim * self.k
        out_dim = config.llm_dim
        # Fall back to 4x the LLM dim when no explicit hidden dim is given.
        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4
        self.num_layers = getattr(config, "projector_num_layers", 2)

        # Stacked frames -> LLM dim, normalized before the residual stack.
        self.input_proj = nn.Linear(in_dim, out_dim)
        self.ln_input = LlamaRMSNorm(out_dim, eps=1e-8)

        self.layers = nn.ModuleList(
            [ResidualMLP(out_dim, hidden_dim) for _ in range(self.num_layers)]
        )
        self.layer_norms = nn.ModuleList(
            [LlamaRMSNorm(out_dim, eps=1e-8) for _ in range(self.num_layers)]
        )

        self._init_weights(config)

    def _init_weights(self, config):
        """Initialize weights; residual fc2 gets a 10x smaller std so each
        block starts close to identity."""
        std = getattr(config, "projector_init_std", 0.02)

        with torch.no_grad():
            nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
            if self.input_proj.bias is not None:
                nn.init.zeros_(self.input_proj.bias)

            # RMSNorm scales start at 1 (identity scaling).
            self.ln_input.weight.data.fill_(1.0)
            for ln in self.layer_norms:
                ln.weight.data.fill_(1.0)

            for layer in self.layers:
                nn.init.normal_(layer.fc1.weight, mean=0.0, std=std)
                # Smaller std on the output projection of the residual branch.
                nn.init.normal_(layer.fc2.weight, mean=0.0, std=std * 0.1)
                if layer.fc1.bias is not None:
                    nn.init.zeros_(layer.fc1.bias)
                if layer.fc2.bias is not None:
                    nn.init.zeros_(layer.fc2.bias)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # Temporal pooling with stride k
        # Round up to a multiple of k (forward() pads the same way).
        remainder = input_length % self.k
        if remainder:
            input_length += self.k - remainder
        return input_length // self.k

    def forward(self, x):
        # x: [batch, seq, encoder_dim] -- assumed 3-D; .size() unpack enforces it.
        batch_size, seq_len, dim = x.size()

        # Cast input to the projector's parameter dtype if needed.
        target_dtype = self.input_proj.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Right-pad the time axis so seq_len is a multiple of k.
        remainder = seq_len % self.k
        if remainder:
            pad_len = self.k - remainder
            x = F.pad(x, (0, 0, 0, pad_len))

        # Frame stacking: [B, S, D] -> [B, S/k, D*k], then project + norm.
        x = x.contiguous().view(batch_size, -1, dim * self.k)
        x = self.input_proj(x)
        x = self.ln_input(x)

        # Residual MLP stack, RMSNorm after every block.
        for layer, ln in zip(self.layers, self.layer_norms):
            x = layer(x)
            x = ln(x)

        return x
|
| 398 |
|
| 399 |
|
| 400 |
# =============================================================================
|
|
@@ -688,6 +578,84 @@ class QFormerAudioProjector(nn.Module):
|
|
| 688 |
return self.linear(query_proj)
|
| 689 |
|
| 690 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 691 |
# =============================================================================
|
| 692 |
# Projector Registry
|
| 693 |
# =============================================================================
|
|
@@ -696,7 +664,7 @@ PROJECTOR_CLASSES = {
|
|
| 696 |
"mlp": MLPAudioProjector,
|
| 697 |
"mosa": MOSAProjector,
|
| 698 |
"swiglu": SwiGLUAudioProjector,
|
| 699 |
-
"residual": ResidualAudioProjector,
|
| 700 |
"shared_moe": SharedMoEAudioProjector,
|
| 701 |
"qformer": QFormerAudioProjector,
|
|
|
|
| 702 |
}
|
|
|
|
| 237 |
|
| 238 |
class SwiGLUAudioProjector(nn.Module):
|
| 239 |
"""
|
| 240 |
+
SwiGLU projector with frame stacking (FunASR-style).
|
| 241 |
+
Uses frame stacking for downsampling, linear projection, then SwiGLU.
|
| 242 |
"""
|
| 243 |
|
| 244 |
def __init__(self, config):
|
|
|
|
| 247 |
encoder_dim = config.encoder_dim
|
| 248 |
llm_dim = config.llm_dim
|
| 249 |
|
| 250 |
+
# Frame stacking input dimension
|
| 251 |
+
in_dim = encoder_dim * self.k # 1280 * 4 = 5120
|
|
|
|
| 252 |
|
| 253 |
+
# Hidden dim after initial projection (balanced compression like transformer)
|
| 254 |
+
hidden_dim = getattr(config, "projector_hidden_dim", None) or 4096
|
|
|
|
| 255 |
|
| 256 |
+
# Initial linear projection (frame stacking → hidden)
|
| 257 |
+
self.linear = nn.Linear(in_dim, hidden_dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
+
# Norm before SwiGLU
|
| 260 |
self.norm = LlamaRMSNorm(hidden_dim, eps=1e-8)
|
| 261 |
|
| 262 |
+
# SwiGLU with 8/3 expansion ratio
|
| 263 |
+
swiglu_inner = int(hidden_dim * 8 / 3)
|
| 264 |
self.proj = SwiGLU(hidden_dim, swiglu_inner, llm_dim)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
def forward(self, x):
|
| 267 |
# x: [Batch, Seq, Dim]
|
| 268 |
batch, seq, dim = x.shape
|
| 269 |
|
| 270 |
+
# Padding to multiple of k
|
| 271 |
+
chunk_num = (seq - 1) // self.k + 1
|
| 272 |
+
pad_num = chunk_num * self.k - seq
|
| 273 |
+
if pad_num > 0:
|
| 274 |
+
x = F.pad(x, (0, 0, 0, pad_num))
|
| 275 |
|
| 276 |
+
# Frame stacking: [B, S, D] -> [B, S/k, D*k]
|
| 277 |
+
x = x.contiguous().view(batch, chunk_num, dim * self.k)
|
| 278 |
|
| 279 |
+
# Linear projection
|
| 280 |
+
x = self.linear(x)
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
+
# Norm & SwiGLU
|
| 283 |
x = self.norm(x)
|
| 284 |
return self.proj(x)
|
| 285 |
|
| 286 |
def get_output_length(self, input_length: int) -> int:
|
| 287 |
+
return (input_length - 1) // self.k + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
|
| 290 |
# =============================================================================
|
|
|
|
| 578 |
return self.linear(query_proj)
|
| 579 |
|
| 580 |
|
| 581 |
+
# =============================================================================
|
| 582 |
+
# Transformer Projector
|
| 583 |
+
# =============================================================================
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class TransformerAudioProjector(nn.Module):
    """
    Transformer Projector (FunASR Style).
    Projects to LLM dim first, then applies transformer blocks for context mixing.

    Pipeline: frame-stack k frames -> linear1 -> ReLU -> linear2 (LLM dim)
    -> optional TransformerEncoder blocks.
    """

    def __init__(self, config):
        # config attributes read here: encoder_dim, llm_dim, and optionally
        # projector_pool_stride / projector_hidden_dim / projector_num_layers /
        # projector_num_heads / projector_block_ffn_dim.
        super().__init__()
        # Default stride 6: Whisper (2x) * Projector (6x) = 12x total → ~8 Hz
        # Matches FunASR's total stride (6x encoder * 2x projector = 12x)
        self.k = getattr(config, "projector_pool_stride", 6)

        encoder_dim = config.encoder_dim
        llm_dim = config.llm_dim

        # Input: stacked frames (e.g. 1280 * 6 = 7680)
        in_dim = encoder_dim * self.k

        # FFN hidden dim for initial projection (balanced compression)
        # e.g. 7680 → 4096 → 2048 distributes compression evenly (~2x each layer)
        ffn_dim = getattr(config, "projector_hidden_dim", None) or 4096

        # FunASR-style projection: linear1 -> relu -> linear2
        self.linear1 = nn.Linear(in_dim, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, llm_dim)

        # Transformer blocks operating at llm_dim
        num_layers = getattr(config, "projector_num_layers", 2)
        if num_layers > 0:
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=llm_dim,
                nhead=getattr(config, "projector_num_heads", 8),
                # Was hard-coded to 1024 (FunASR default: audio complexity is
                # LLM-independent); now configurable, same default.
                dim_feedforward=getattr(config, "projector_block_ffn_dim", 1024),
                dropout=0.0,
                activation="relu",
                batch_first=True,
                norm_first=True,
            )
            self.blocks = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layers, enable_nested_tensor=False
            )
        else:
            # num_layers == 0 degenerates to a pure MLP projector.
            self.blocks = None

    def forward(self, x):
        # x: [Batch, Seq, Dim]
        batch, seq, dim = x.shape

        # Right-pad the time axis to a multiple of k (ceil division).
        chunk_num = (seq - 1) // self.k + 1
        pad_num = chunk_num * self.k - seq
        if pad_num > 0:
            x = F.pad(x, (0, 0, 0, pad_num))

        # Frame stacking: [B, S, D] -> [B, S/k, D*k]
        x = x.contiguous().view(batch, chunk_num, dim * self.k)

        # FunASR-style projection to LLM dim
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        # Transformer context mixing
        if self.blocks is not None:
            x = self.blocks(x)

        return x

    def get_output_length(self, input_length: int) -> int:
        """Output sequence length for a given input length (ceil(len / k))."""
        return (input_length - 1) // self.k + 1
|
| 657 |
+
|
| 658 |
+
|
| 659 |
# =============================================================================
|
| 660 |
# Projector Registry
|
| 661 |
# =============================================================================
|
|
|
|
| 664 |
"mlp": MLPAudioProjector,
|
| 665 |
"mosa": MOSAProjector,
|
| 666 |
"swiglu": SwiGLUAudioProjector,
|
|
|
|
| 667 |
"shared_moe": SharedMoEAudioProjector,
|
| 668 |
"qformer": QFormerAudioProjector,
|
| 669 |
+
"transformer": TransformerAudioProjector,
|
| 670 |
}
|