Shrey Goel committed
Commit 0fa2d2b · 1 parent: 6154b48

cleaned training code

src/madsbm/wt_peptide/control_field.py CHANGED
@@ -98,13 +98,7 @@ class PeptideControlField(nn.Module):
         self.embed_model.eval()
         for param in self.embed_model.parameters():
             param.requires_grad = False
-
-        # # Unfreeze QKV in last few encoder layers
-        # encoder_layers = self.embed_model.esm.encoder.layer
-        # for layer in encoder_layers[-cfg.training.n_unfrozen:]:
-        #     for param in layer.parameters():
-        #         param.requires_grad = True
-
+
         self.time_embed = TimeEmbedding(
             hidden_dim=cfg.time_embed.time_dim,
             fourier_dim=cfg.time_embed.fourier_dim,
@@ -118,10 +112,6 @@
 
         self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)
 
-        # self.output_proj = self.embed_model.lm_head
-        # for param in self.output_proj.parameters():
-        #     param.requires_grad = False
-
         self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
         nn.init.zeros_(self.output_proj.weight)
         nn.init.zeros_(self.output_proj.bias)
@@ -150,50 +140,4 @@
             "dit": logits,
             "madsbm": u_base + logits
         }
-
-
-
-
-    # def forward(self, t, xt, attention_mask):
-    #     outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)
-    #     h = outs.hidden_states[-1]
-    #     t_emb = self.time_embed(t)  # [B, time_dim]
-
-    #     # Transformer head (key_padding_mask=True for pads)
-    #     key_padding_mask = (attention_mask == 0)  # (B, L) bool
-    #     for dit_block in self.blocks:
-    #         h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)
-
-    #     # Final norm + projection to vocab logits
-    #     h = self.final_norm(h)  # [B, L, hidden_dim]
-    #     logits = self.output_proj(h)  # [B, L, V]
-    #     return logits
-
-
-    # def forward(self, xt, attention_mask, t):
-    #     with torch.no_grad():
-    #         base_out = self.embed_model(
-    #             input_ids=xt,
-    #             attention_mask=attention_mask,
-    #             output_hidden_states=True
-    #         )
-
-    #     logits_base = base_out.logits
-    #     h_base = base_out.hidden_states[-1]
-
-    #     norm = self.token_norm_sqrd.view(1, 1, -1)  # 1, 1, V
-
-    #     log_R0 = (self.beta1 * logits_base) - (self.beta2 * norm)
-
-    #     t_emb = self.time_embed(t)  # [B, time_dim]
-    #     key_padding_mask = (attention_mask == 0)  # (B, L) bool
-
-    #     h_ctrl = h_base
-    #     for dit_block in self.blocks:
-    #         h_ctrl = dit_block(h_ctrl, t_emb, key_padding_mask=key_padding_mask)
-
-    #     h_ctrl = self.final_norm(h_ctrl)
-    #     u_theta = self.output_proj(h_ctrl)
-    #     tot_logits = log_R0 + u_theta
-
-    #     return tot_logits, u_theta
+
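
For orientation, here is a minimal, self-contained sketch of the forward path this commit keeps: frozen base hidden states pass through time-conditioned blocks into the zero-initialized head, and the "madsbm" output adds the learned correction to the frozen base logits u_base. The ToyTimeEmbedding and nn.TransformerEncoderLayer stand-ins are assumptions for illustration; the repo's actual TimeEmbedding and DiT blocks (which consume t_emb through conditioning rather than a plain add) are not shown in this diff.

import torch
import torch.nn as nn

class ToyTimeEmbedding(nn.Module):
    # Stand-in for the repo's TimeEmbedding (not shown in this diff).
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(1, dim)

    def forward(self, t):                          # t: [B] in (0, 1)
        return self.proj(t.unsqueeze(-1))          # [B, dim]

class ControlFieldSketch(nn.Module):
    def __init__(self, hidden_dim=64, vocab_size=33, n_blocks=2):
        super().__init__()
        self.time_embed = ToyTimeEmbedding(hidden_dim)
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(hidden_dim, nhead=4, batch_first=True)
            for _ in range(n_blocks)
        )
        self.final_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
        # Zero-init head: the correction starts at exactly zero, so the
        # "madsbm" logits initially equal the frozen base logits.
        self.output_proj = nn.Linear(hidden_dim, vocab_size)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, h_base, u_base, t, attention_mask):
        h = h_base + self.time_embed(t).unsqueeze(1)   # broadcast time over length
        pad = attention_mask == 0                      # (B, L) bool, True at pads
        for blk in self.blocks:
            h = blk(h, src_key_padding_mask=pad)
        logits = self.output_proj(self.final_norm(h))  # [B, L, V]
        return {"dit": logits, "madsbm": u_base + logits}

The zero-initialized output projection is the notable design choice here: training starts from the frozen ESM distribution and only gradually learns a control field on top of it.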
 
src/madsbm/wt_peptide/sbm_module.py CHANGED
@@ -31,19 +31,13 @@ class MadSBM(pl.LightningModule):
         for param in self.embed_model.parameters():
             param.requires_grad = False
 
-        self.beta = 1.0 / self.config.model.hidden_dim
-
-        # self.L = config.data.max_seq_len
-        # self.V = self.vocab_size
-        # self.log_R0 = - math.log(self.L * self.V)  # uninformed generator is constant
-
         self.time_schedule = config.time_embed.time_schedule
         self.anneal_frac = config.time_embed.anneal_frac
         self.eps = float(config.time_embed.min_time)
         self.t_max = 1.0 - self.eps
 
 
-    # -------# Forward Pass #-------- #
+    # -------# Main Training Logic #-------- #
     def forward(self, input_ids, attention_mask, t):
         return self.model(xt=input_ids, attention_mask=attention_mask, t=t)
 
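For reference, the removed lines above computed a single scalar: a uniform, uninformed generator puts rate 1/(L*V) on every (position, token) jump, so its log-rate log_R0 is a constant. A tiny sketch with hypothetical sizes (the removed code read max_seq_len from config.data and the vocab size from the tokenizer):

import math

L, V = 512, 33              # hypothetical max_seq_len and vocab size
log_R0 = -math.log(L * V)   # uniform generator: every jump has rate 1/(L*V)
print(f"log_R0 = {log_R0:.2f}")   # -9.73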
@@ -76,32 +70,8 @@
         loss = sample_loss.mean()
         ppl = torch.exp(loss)
 
-        _print(f'loss: {loss}')
-        _print(f'ppl: {ppl}')
-
         return loss, ppl, max_u_logit, max_esm_logit
 
-
-
-    # def step(self, batch):
-    #     x1 = batch['input_ids']
-    #     attn_mask = batch['attention_mask']
-    #     maskable = self.is_maskable(x1)
-
-    #     t = self.sample_t(x1)
-    #     xt = self.noise_seq(x1, t, maskable_mask=maskable)
-
-    #     u_theta = self.forward(xt, attn_mask, t)
-    #     b, l, v_target = self.compute_target(x1, xt, t, maskable_mask=maskable)
-    #     loss, ppl = self.compute_loss(u_theta, v_target, x1, b, l)
-
-    #     _print(f'loss: {loss}')
-    #     _print(f'ppl: {ppl}')
-
-    #     return loss, ppl
-
-
-    # -------# Main Training Logic #-------- #
     def noise_seq(self, x1, t, maskable_mask):
         B, L = x1.shape
         t = t.unsqueeze(1)  # B, 1
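
A side note on the metrics kept above: ppl = torch.exp(loss) is perplexity, the exponential of the mean cross-entropy, so the commit can drop the _print debug lines in favor of the values already returned. A toy check with a hypothetical loss value:

import torch

loss = torch.tensor(2.0)   # hypothetical mean cross-entropy
ppl = torch.exp(loss)      # 7.39, the effective per-token branching factor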
@@ -114,46 +84,6 @@
         xt = xt.masked_fill(masked, self.mask_id)
 
         return xt
-
-    # def compute_target(self, x1, xt, t, maskable_mask):
-    #     L = x1.size(1)
-    #     V = self.vocab_size
-    #     device = x1.device
-
-    #     mask = (xt == self.mask_id) & maskable_mask
-    #     b, l = torch.nonzero(mask, as_tuple=True)
-
-    #     if b.numel() == 0:
-    #         return b, l, torch.empty(0, device=device, dtype=torch.long)
-
-    #     log_R0 = - math.log(L * V)  # uniform generator with rates (1 / L*V)
-    #     time = - torch.log(1 - t[b])
-
-    #     v_target = time - log_R0  # log(1/1-t) - log(1/L*V)
-    #     v_target = v_target.clamp(min=-100.0, max=100.0)
-
-    #     return b, l, v_target
-
-
-    # def compute_loss(self, u_theta, v_target, x1, b, l):
-    #     if b.numel() == 0:
-    #         dummy_loss = 0.0 * u_theta.sum()
-    #         return dummy_loss, torch.tensor(0.0, device=u_theta.device)
-
-    #     true_toks = x1[b, l]
-    #     u_pred = u_theta[b, l, :]  # N_masks, V
-
-    #     tgt = torch.zeros_like(u_pred)
-    #     tgt.scatter_(1, true_toks.unsqueeze(1), v_target.unsqueeze(1))
-
-    #     sse = F.mse_loss(u_pred, tgt, reduction='sum')
-    #     loss = sse / b.numel() if b.numel != 0 else sse  # normalize by number of masks
-
-    #     with torch.no_grad():
-    #         ppl = torch.exp(F.cross_entropy(u_pred, true_toks))
-
-    #     return loss, ppl
-
 
     # -------# Time Schedules #-------- #
     def sample_t(self, x1):
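
The two commented-out methods deleted above documented the earlier MSE-on-log-rates objective, and the kept noise_seq is the mask-corruption step it relied on. Below is a self-contained sketch of that pipeline under assumed shapes, with a hypothetical mask id and vocab size; the early return also makes the removed code's b.numel guard (which was missing its call parentheses anyway) unnecessary.

import math
import torch
import torch.nn.functional as F

MASK_ID, V = 32, 33   # hypothetical mask token id and vocab size

def noise_seq(x1, t, maskable_mask, mask_id=MASK_ID):
    # Mask each maskable token independently with probability t (shape [B]).
    u = torch.rand_like(x1, dtype=torch.float)
    masked = (u < t.unsqueeze(1)) & maskable_mask
    return x1.masked_fill(masked, mask_id)

def mse_rate_loss(x1, xt, t, maskable_mask, u_theta, mask_id=MASK_ID):
    # Regress u_theta toward v_target = -log(1 - t) + log(L * V) at the true
    # token of each masked position, and toward zero everywhere else.
    L = x1.size(1)
    mask = (xt == mask_id) & maskable_mask
    b, l = torch.nonzero(mask, as_tuple=True)
    if b.numel() == 0:
        return 0.0 * u_theta.sum()                # keep the autograd graph alive
    log_R0 = -math.log(L * V)                     # uniform-generator constant
    v_target = (-torch.log1p(-t[b]) - log_R0).clamp(-100.0, 100.0)
    u_pred = u_theta[b, l, :]                     # (N_masked, V)
    tgt = torch.zeros_like(u_pred)
    tgt.scatter_(1, x1[b, l].unsqueeze(1), v_target.unsqueeze(1))
    return F.mse_loss(u_pred, tgt, reduction='sum') / b.numel()

# Toy usage with hypothetical sizes
B, Lseq = 2, 8
x1 = torch.randint(0, V - 1, (B, Lseq))
t = torch.rand(B).clamp(1e-3, 1 - 1e-3)
maskable = torch.ones(B, Lseq, dtype=torch.bool)
xt = noise_seq(x1, t, maskable)
u_theta = torch.zeros(B, Lseq, V, requires_grad=True)
loss = mse_rate_loss(x1, xt, t, maskable, u_theta)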
 