ml-ryanlee commited on
Commit
4735b6a
·
verified ·
1 Parent(s): 0a6d9d7

Update modeling_loop_lm.py

Browse files
Files changed (1) hide show
  1. modeling_loop_lm.py +6 -12
modeling_loop_lm.py CHANGED
@@ -61,15 +61,12 @@ class Linear(nn.Module):
61
  def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
62
  super().__init__()
63
 
64
- # initialize weights matrix
65
- weights = torch.empty(out_features, in_features, dtype=dtype, device=device)
66
 
67
  # for muP, derive initial std deviation from given base model's std_deviation and width ratio
68
  std_scaled = std_base / math.sqrt(width_ratio)
69
- weights = nn.init.trunc_normal_(weights, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)
70
-
71
- # assign as instance variable
72
- self.weight = nn.Parameter(weights)
73
 
74
  def forward(self, x: Tensor) -> Tensor:
75
  # Pytorch standard: on input side of expression, d_in is last dim of x so "... d_in"
@@ -81,14 +78,11 @@ class Embedding(nn.Module):
81
  def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
82
  super().__init__()
83
 
84
- # initialize a matrix of vocab_size x embedding_dim
85
- embeddings = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
86
 
87
  # normalize the embeddings to spec
88
- embeddings = nn.init.trunc_normal_(embeddings, mean=0.0, std=1.0, a=-3, b=3)
89
-
90
- # save and enroll as torch param
91
- self.weight = nn.Parameter(embeddings)
92
 
93
  def forward(self, token_ids: Tensor) -> Tensor:
94
  # for every id, we need to pull the row vector associated
 
61
def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
    """Linear layer with muP-style width-scaled truncated-normal initialization.

    The weight is registered as a Parameter *before* initialization so its
    shape is always recorded — required for HF meta-device (lazy) loading,
    where the init step may be skipped entirely.
    """
    super().__init__()

    # Allocate storage first, then enroll it as a trainable parameter so the
    # module always exposes a correctly-shaped `weight`, even on meta device.
    storage = torch.empty(out_features, in_features, dtype=dtype, device=device)
    self.weight = nn.Parameter(storage)

    # muP: shrink the base model's std by sqrt(width_ratio) so activation
    # scale stays width-invariant; fill in place, truncating at +/-3 sigma.
    sigma = std_base / math.sqrt(width_ratio)
    nn.init.trunc_normal_(self.weight, mean=0.0, std=sigma, a=-3 * sigma, b=3 * sigma)
 
 
 
70
 
71
  def forward(self, x: Tensor) -> Tensor:
72
  # Pytorch standard: on input side of expression, d_in is last dim of x so "... d_in"
 
78
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
    """Embedding table initialized from a unit-variance truncated normal.

    The parameter is registered before initialization so the shape is always
    stored — required for HF meta-device (lazy) loading, where the init call
    may be a no-op.
    """
    super().__init__()

    # Enroll the vocab_size x embedding_dim table as a parameter up front.
    table = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
    self.weight = nn.Parameter(table)

    # Fill in place: standard normal truncated to +/-3 standard deviations.
    nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3)
 
 
 
86
 
87
  def forward(self, token_ids: Tensor) -> Tensor:
88
  # for every id, we need to pull the row vector associated