Update modeling_loop_lm.py
Browse files- modeling_loop_lm.py +6 -12
modeling_loop_lm.py
CHANGED
|
@@ -61,15 +61,12 @@ class Linear(nn.Module):
|
|
| 61 |
def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
|
| 62 |
super().__init__()
|
| 63 |
|
| 64 |
-
#
|
| 65 |
-
|
| 66 |
|
| 67 |
# for muP, derive initial std deviation from given base model's std_deviation and width ratio
|
| 68 |
std_scaled = std_base / math.sqrt(width_ratio)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# assign as instance variable
|
| 72 |
-
self.weight = nn.Parameter(weights)
|
| 73 |
|
| 74 |
def forward(self, x: Tensor) -> Tensor:
|
| 75 |
# Pytorch standard: on input side of expression, d_in is last dim of x so "... d_in"
|
|
@@ -81,14 +78,11 @@ class Embedding(nn.Module):
|
|
| 81 |
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
|
| 82 |
super().__init__()
|
| 83 |
|
| 84 |
-
#
|
| 85 |
-
|
| 86 |
|
| 87 |
# normalize the embeddings to spec
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
# save and enroll as torch param
|
| 91 |
-
self.weight = nn.Parameter(embeddings)
|
| 92 |
|
| 93 |
def forward(self, token_ids: Tensor) -> Tensor:
|
| 94 |
# for every id, we need to pull the row vector associated
|
|
|
|
| 61 |
def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
    """Set up the muP-scaled weight for this linear layer.

    The (out_features, in_features) parameter is registered *before* any
    values are drawn so its shape is always recorded — required for HF
    meta-device loading. Entries are then drawn from a truncated normal
    whose std is the base model's std shrunk by sqrt(width_ratio).
    """
    super().__init__()
    # Allocate and register first: shape must exist even on a meta device.
    blank = torch.empty(out_features, in_features, dtype=dtype, device=device)
    self.weight = nn.Parameter(blank)
    # muP: shrink the base std by the square root of the width ratio.
    sigma = std_base / math.sqrt(width_ratio)
    # Truncate the draw at +/- 3 sigma of the scaled deviation.
    nn.init.trunc_normal_(self.weight, mean=0.0, std=sigma, a=-3 * sigma, b=3 * sigma)
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
def forward(self, x: Tensor) -> Tensor:
|
| 72 |
# Pytorch standard: on input side of expression, d_in is last dim of x so "... d_in"
|
|
|
|
| 78 |
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
    """Build the embedding table with unit-std truncated-normal entries.

    The parameter is registered before initialization so its shape is
    always stored — required for HF meta-device loading.
    """
    super().__init__()
    # Register the (num_embeddings, embedding_dim) table up front so the
    # shape is available even when instantiated on a meta device.
    table = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
    self.weight = nn.Parameter(table)
    # Fill from N(0, 1) truncated to [-3, 3], per the model spec.
    nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3)
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
def forward(self, token_ids: Tensor) -> Tensor:
|
| 88 |
# for every id, we need to pull the row vector associated
|