mazesmazes commited on
Commit
a9a2c8b
·
verified ·
1 Parent(s): 6a72d9f

Delete mlp_projector.py

Browse files
Files changed (1) hide show
  1. mlp_projector.py +0 -42
mlp_projector.py DELETED
@@ -1,42 +0,0 @@
1
- import torch.nn as nn
2
-
3
-
4
- class MLPAudioProjector(nn.Module):
5
- """2-layer MLP projector with Qwen-style 2x temporal downsampling."""
6
-
7
- def __init__(self, config):
8
- super().__init__()
9
-
10
- encoder_dim = getattr(config, "encoder_dim", 768)
11
- llm_dim = getattr(config, "llm_dim", 2048)
12
-
13
- self.downsample = nn.Conv1d(
14
- encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1, bias=False
15
- )
16
- self.linear_1 = nn.Linear(encoder_dim, llm_dim, bias=False)
17
- self.act = nn.GELU()
18
- self.linear_2 = nn.Linear(llm_dim, llm_dim, bias=False)
19
-
20
- self.apply(self._init_weights)
21
-
22
- def _init_weights(self, module):
23
- if isinstance(module, nn.Linear):
24
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
25
- elif isinstance(module, nn.Conv1d):
26
- nn.init.normal_(module.weight, mean=0.0, std=0.02)
27
- if module.bias is not None:
28
- nn.init.zeros_(module.bias)
29
-
30
- def forward(self, x):
31
- """
32
- x: [Batch, Seq_Len, Dim]
33
- Returns: [Batch, Seq_Len // 2, llm_dim]
34
- """
35
- # Conv1d expects [Batch, Channels, Seq_Len]
36
- x = x.transpose(1, 2)
37
- x = self.downsample(x)
38
- x = x.transpose(1, 2)
39
-
40
- x = self.linear_1(x)
41
- x = self.act(x)
42
- return self.linear_2(x)