mazesmazes committed
Commit b50e976 · verified · 1 Parent(s): 26f7ea5

Delete swiglu_projector.py

Files changed (1): swiglu_projector.py (+0, -68)
swiglu_projector.py DELETED
@@ -1,68 +0,0 @@
-"""Simple SwiGLU-based audio projector."""
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F  # noqa: N812
-
-
-class SwiGLU(nn.Module):
-    def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
-        super().__init__()
-        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
-        self.act = nn.SiLU()
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        x_gate = self.act(self.w1(x))
-        x_val = self.w2(x)
-        x = x_gate * x_val
-        x = self.dropout(x)
-        return self.w3(x)
-
-
-class AudioProjector(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.k = getattr(config, "projector_pool_stride", 4)
-        in_dim = config.encoder_dim * self.k
-        out_dim = config.llm_dim
-        hidden_dim = config.projector_hidden_dim
-        if hidden_dim is None:
-            hidden_dim = config.encoder_dim * 2
-
-        dropout_rate = getattr(config, "projector_dropout", 0.0)
-
-        self.proj1 = SwiGLU(in_dim, hidden_dim, hidden_dim, dropout=dropout_rate)
-        self.proj2 = SwiGLU(hidden_dim, hidden_dim, out_dim, dropout=dropout_rate)
-        self.output_dropout = nn.Dropout(dropout_rate)
-
-        with torch.no_grad():
-            std = getattr(config, "projector_init_std", 0.02)
-            # Initialize first layer
-            nn.init.normal_(self.proj1.w1.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj1.w2.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj1.w3.weight, mean=0.0, std=std)
-            # Initialize second layer
-            nn.init.normal_(self.proj2.w1.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj2.w2.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj2.w3.weight, mean=0.0, std=std)
-
-    def forward(self, x):
-        batch_size, seq_len, dim = x.size()
-
-        target_dtype = self.proj1.w1.weight.dtype
-        if x.dtype != target_dtype:
-            x = x.to(target_dtype)
-
-        remainder = seq_len % self.k
-        if remainder:
-            pad_len = self.k - remainder
-            x = F.pad(x, (0, 0, 0, pad_len))
-
-        x = x.contiguous().view(batch_size, -1, dim * self.k)
-        x = self.proj1(x)
-        x = self.proj2(x)
-
-        return self.output_dropout(x)
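
For reference, the deleted `AudioProjector` stacked every `k` encoder frames into one vector (zero-padding the sequence to a multiple of `k`), then mapped it to the LLM embedding width through two SwiGLU blocks, each computing `w3(SiLU(w1(x)) * w2(x))`. A minimal usage sketch of the removed module is below; the config fields match the attribute names the code reads, but the dimension values are illustrative assumptions, not taken from this commit:

```python
import torch
from types import SimpleNamespace

# Hypothetical config for illustration only; the real config class and
# dimensions are not part of this commit.
config = SimpleNamespace(
    encoder_dim=1280,            # per-frame encoder feature size (assumed)
    llm_dim=4096,                # LLM embedding width (assumed)
    projector_hidden_dim=2560,   # SwiGLU hidden width (assumed)
    projector_pool_stride=4,     # frames stacked per output token
    projector_dropout=0.0,
)

projector = AudioProjector(config)

# 100 encoder frames, stride 4 -> 25 projected tokens per example.
audio_features = torch.randn(2, 100, config.encoder_dim)
tokens = projector(audio_features)
print(tokens.shape)  # torch.Size([2, 25, 4096])
```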