lhallee committed on
Commit
a9244b3
·
verified ·
1 Parent(s): 4a48dcd

Upload vb_layers_transition.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vb_layers_transition.py +78 -78
vb_layers_transition.py CHANGED
@@ -1,78 +1,78 @@
1
- from typing import Optional
2
-
3
- from torch import Tensor, nn
4
-
5
- from . import vb_layers_initialize as init
6
-
7
-
8
class Transition(nn.Module):
    """Gated feed-forward block.

    Applies ``fc3(silu(fc1(n)) * fc2(n))`` where ``n = norm(x)``: a layer
    norm, two parallel bias-free projections to ``hidden`` (one passed through
    SiLU and used as a gate), and a bias-free projection to ``out_dim``.
    """

    def __init__(
        self,
        dim: int = 128,
        hidden: int = 512,
        out_dim: Optional[int] = None,
    ) -> None:
        """Initialize the Transition module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        hidden: int
            The dimension of the hidden, default 512
        out_dim: Optional[int]
            The dimension of the output; ``None`` (the default) means ``dim``

        """
        super().__init__()
        if out_dim is None:
            out_dim = dim

        self.norm = nn.LayerNorm(dim, eps=1e-5)
        self.fc1 = nn.Linear(dim, hidden, bias=False)
        self.fc2 = nn.Linear(dim, hidden, bias=False)
        self.fc3 = nn.Linear(hidden, out_dim, bias=False)
        self.silu = nn.SiLU()
        # Kept so the chunked forward path knows how far to iterate.
        self.hidden = hidden

        # Parameter initialisation is delegated to the project's init helpers.
        init.bias_init_one_(self.norm.weight)
        init.bias_init_zero_(self.norm.bias)

        init.lecun_normal_init_(self.fc1.weight)
        init.lecun_normal_init_(self.fc2.weight)
        init.final_init_(self.fc3.weight)

    def forward(self, x: Tensor, chunk_size: Optional[int] = None) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (..., dim)
        chunk_size: Optional[int]
            When given and the module is in eval mode, the hidden dimension
            is processed ``chunk_size`` units at a time and the partial
            outputs are summed, so each step only holds a
            (..., chunk_size) slice of the hidden activation. Ignored in
            training mode. The result matches the unchunked path up to
            floating-point summation order.

        Returns
        -------
        x: torch.Tensor
            The output data of shape (..., out_dim)

        Raises
        ------
        ValueError
            If the chunked path is taken and ``chunk_size`` is not positive.

        """
        x = self.norm(x)

        if chunk_size is None or self.training:
            return self.fc3(self.silu(self.fc1(x)) * self.fc2(x))

        # Validate only when the value is actually used, so training-mode
        # callers that pass an arbitrary chunk_size keep working.
        if chunk_size <= 0:
            msg = f"chunk_size must be a positive integer, got {chunk_size}"
            raise ValueError(msg)

        x_out: Optional[Tensor] = None
        for start in range(0, self.hidden, chunk_size):
            stop = start + chunk_size
            # nn.Linear stores weight as (out_features, in_features): fc1/fc2
            # are sliced along rows (hidden units), fc3 along columns.
            fc1_slice = self.fc1.weight[start:stop, :]
            fc2_slice = self.fc2.weight[start:stop, :]
            fc3_slice = self.fc3.weight[:, start:stop]
            x_chunk = self.silu(x @ fc1_slice.T) * (x @ fc2_slice.T)
            partial = x_chunk @ fc3_slice.T
            x_out = partial if x_out is None else x_out + partial

        if x_out is None:
            # hidden == 0 leaves nothing to chunk; the dense path still
            # produces a correctly shaped result.
            return self.fc3(self.silu(self.fc1(x)) * self.fc2(x))
        return x_out
 
1
+ from typing import Optional
2
+
3
+ from torch import Tensor, nn
4
+
5
+ from . import vb_layers_initialize as init
6
+
7
+
8
class Transition(nn.Module):
    """Gated feed-forward block.

    Applies ``fc3(silu(fc1(n)) * fc2(n))`` where ``n = norm(x)``: a layer
    norm, two parallel bias-free projections to ``hidden`` (one passed through
    SiLU and used as a gate), and a bias-free projection to ``out_dim``.
    """

    def __init__(
        self,
        dim: int = 128,
        hidden: int = 512,
        out_dim: Optional[int] = None,
    ) -> None:
        """Initialize the Transition module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128
        hidden: int
            The dimension of the hidden, default 512
        out_dim: Optional[int]
            The dimension of the output; ``None`` (the default) means ``dim``

        """
        super().__init__()
        if out_dim is None:
            out_dim = dim

        self.norm = nn.LayerNorm(dim, eps=1e-5)
        self.fc1 = nn.Linear(dim, hidden, bias=False)
        self.fc2 = nn.Linear(dim, hidden, bias=False)
        self.fc3 = nn.Linear(hidden, out_dim, bias=False)
        self.silu = nn.SiLU()
        # Kept so the chunked forward path knows how far to iterate.
        self.hidden = hidden

        # Parameter initialisation is delegated to the project's init helpers.
        init.bias_init_one_(self.norm.weight)
        init.bias_init_zero_(self.norm.bias)

        init.lecun_normal_init_(self.fc1.weight)
        init.lecun_normal_init_(self.fc2.weight)
        init.final_init_(self.fc3.weight)

    def forward(self, x: Tensor, chunk_size: Optional[int] = None) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (..., dim)
        chunk_size: Optional[int]
            When given and the module is in eval mode, the hidden dimension
            is processed ``chunk_size`` units at a time and the partial
            outputs are summed, so each step only holds a
            (..., chunk_size) slice of the hidden activation. Ignored in
            training mode. The result matches the unchunked path up to
            floating-point summation order.

        Returns
        -------
        x: torch.Tensor
            The output data of shape (..., out_dim)

        Raises
        ------
        ValueError
            If the chunked path is taken and ``chunk_size`` is not positive.

        """
        x = self.norm(x)

        if chunk_size is None or self.training:
            return self.fc3(self.silu(self.fc1(x)) * self.fc2(x))

        # Validate only when the value is actually used, so training-mode
        # callers that pass an arbitrary chunk_size keep working.
        if chunk_size <= 0:
            msg = f"chunk_size must be a positive integer, got {chunk_size}"
            raise ValueError(msg)

        x_out: Optional[Tensor] = None
        for start in range(0, self.hidden, chunk_size):
            stop = start + chunk_size
            # nn.Linear stores weight as (out_features, in_features): fc1/fc2
            # are sliced along rows (hidden units), fc3 along columns.
            fc1_slice = self.fc1.weight[start:stop, :]
            fc2_slice = self.fc2.weight[start:stop, :]
            fc3_slice = self.fc3.weight[:, start:stop]
            x_chunk = self.silu(x @ fc1_slice.T) * (x @ fc2_slice.T)
            partial = x_chunk @ fc3_slice.T
            x_out = partial if x_out is None else x_out + partial

        if x_out is None:
            # hidden == 0 leaves nothing to chunk; the dense path still
            # produces a correctly shaped result.
            return self.fc3(self.silu(self.fc1(x)) * self.fc2(x))
        return x_out