Shrey Goel committed
Commit · 0fa2d2b
Parent(s): 6154b48
cleaned training code
src/madsbm/wt_peptide/control_field.py
CHANGED
@@ -98,13 +98,7 @@ class PeptideControlField(nn.Module):
         self.embed_model.eval()
         for param in self.embed_model.parameters():
             param.requires_grad = False
-
-        # # Unfreeze QKV in last few encoder layers
-        # encoder_layers = self.embed_model.esm.encoder.layer
-        # for layer in encoder_layers[-cfg.training.n_unfrozen:]:
-        #     for param in layer.parameters():
-        #         param.requires_grad = True
-
+
         self.time_embed = TimeEmbedding(
             hidden_dim=cfg.time_embed.time_dim,
             fourier_dim=cfg.time_embed.fourier_dim,
@@ -118,10 +112,6 @@ class PeptideControlField(nn.Module):
 
         self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)
 
-        # self.output_proj = self.embed_model.lm_head
-        # for param in self.output_proj.parameters():
-        #     param.requires_grad = False
-
         self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
         nn.init.zeros_(self.output_proj.weight)
         nn.init.zeros_(self.output_proj.bias)
@@ -150,50 +140,4 @@ class PeptideControlField(nn.Module):
             "dit": logits,
             "madsbm": u_base + logits
         }
-
-
-
-
-    # def forward(self, t, xt, attention_mask):
-    #     outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)
-    #     h = outs.hidden_states[-1]
-    #     t_emb = self.time_embed(t) # [B, time_dim]
-
-    #     # Transformer head (key_padding_mask=True for pads)
-    #     key_padding_mask = (attention_mask == 0) # (B, L) bool
-    #     for dit_block in self.blocks:
-    #         h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)
-
-    #     # Final norm + projection to vocab logits
-    #     h = self.final_norm(h) # [B, L, hidden_dim]
-    #     logits = self.output_proj(h) # [B, L, V]
-    #     return logits
-
-
-    # def forward(self, xt, attention_mask, t):
-    #     with torch.no_grad():
-    #         base_out = self.embed_model(
-    #             input_ids=xt,
-    #             attention_mask=attention_mask,
-    #             output_hidden_states=True
-    #         )
-
-    #     logits_base = base_out.logits
-    #     h_base = base_out.hidden_states[-1]
-
-    #     norm = self.token_norm_sqrd.view(1,1,-1) # 1, 1, V
-
-    #     log_R0 = (self.beta1 * logits_base) - (self.beta2 * norm)
-
-    #     t_emb = self.time_embed(t) # [B, time_dim]
-    #     key_padding_mask = (attention_mask == 0) # (B, L) bool
-
-    #     h_ctrl = h_base
-    #     for dit_block in self.blocks:
-    #         h_ctrl = dit_block(h_ctrl, t_emb, key_padding_mask=key_padding_mask)
-
-    #     h_ctrl = self.final_norm(h_ctrl)
-    #     u_theta = self.output_proj(h_ctrl)
-    #     tot_logits = log_R0 + u_theta
-
-    #     return tot_logits, u_theta
+
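A note on what survives the cleanup: the control field keeps the ESM embedding model frozen (`eval()` plus `requires_grad = False`) and zero-initializes its output projection, so the learned correction starts as an exact no-op and the "madsbm" head initially returns just the base rates `u_base`. Below is a minimal sketch of that zero-init head pattern; `hidden_dim` and `vocab_size` stand in for `cfg.model.hidden_dim` and `self.tokenizer.vocab_size`, and the standalone module is illustrative, not the repo's class:

import torch
import torch.nn as nn

class ZeroInitHead(nn.Module):
    """Projection head that is an exact no-op at initialization."""
    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, vocab_size)
        # Zero weights and bias: at step 0 the control contributes nothing,
        # so u_base + logits == u_base and training starts at the base model.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.proj(h)  # [B, L, V] additive correction to base rates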
src/madsbm/wt_peptide/sbm_module.py
CHANGED
@@ -31,19 +31,13 @@ class MadSBM(pl.LightningModule):
         for param in self.embed_model.parameters():
             param.requires_grad = False
 
-        self.beta = 1.0 / self.config.model.hidden_dim
-
-        # self.L = config.data.max_seq_len
-        # self.V = self.vocab_size
-        # self.log_R0 = - math.log(self.L * self.V) # uninformed generator is constant
-
         self.time_schedule = config.time_embed.time_schedule
         self.anneal_frac = config.time_embed.anneal_frac
         self.eps = float(config.time_embed.min_time)
         self.t_max = 1.0 - self.eps
 
 
-    # -------#
+    # -------# Main Training Logic #-------- #
     def forward(self, input_ids, attention_mask, t):
         return self.model(xt=input_ids, attention_mask=attention_mask, t=t)
 
@@ -76,32 +70,8 @@ class MadSBM(pl.LightningModule):
         loss = sample_loss.mean()
         ppl = torch.exp(loss)
 
-        _print(f'loss: {loss}')
-        _print(f'ppl: {ppl}')
-
         return loss, ppl, max_u_logit, max_esm_logit
 
-
-
-    # def step(self, batch):
-    #     x1 = batch['input_ids']
-    #     attn_mask = batch['attention_mask']
-    #     maskable = self.is_maskable(x1)
-
-    #     t = self.sample_t(x1)
-    #     xt = self.noise_seq(x1, t, maskable_mask=maskable)
-
-    #     u_theta = self.forward(xt, attn_mask, t)
-    #     b, l, v_target = self.compute_target(x1, xt, t, maskable_mask=maskable)
-    #     loss, ppl = self.compute_loss(u_theta, v_target, x1, b, l)
-
-    #     _print(f'loss: {loss}')
-    #     _print(f'ppl: {ppl}')
-
-    #     return loss, ppl
-
-
-    # -------# Main Training Logic #-------- #
     def noise_seq(self, x1, t, maskable_mask):
         B, L = x1.shape
         t = t.unsqueeze(1) # B, 1
@@ -114,46 +84,6 @@ class MadSBM(pl.LightningModule):
         xt = xt.masked_fill(masked, self.mask_id)
 
         return xt
-
-    # def compute_target(self, x1, xt, t, maskable_mask):
-    #     L = x1.size(1)
-    #     V = self.vocab_size
-    #     device = x1.device
-
-    #     mask = (xt == self.mask_id) & maskable_mask
-    #     b, l = torch.nonzero(mask, as_tuple=True)
-
-    #     if b.numel() == 0:
-    #         return b, l, torch.empty(0, device=device, dtype=torch.long)
-
-    #     log_R0 = - math.log(L * V) # uniform generator with rates (1 / L*V)
-    #     time = - torch.log(1 - t[b])
-
-    #     v_target = time - log_R0 # log(1/1-t) - log(1/L*V)
-    #     v_target = v_target.clamp(min=-100.0, max=100.0)
-
-    #     return b, l, v_target
-
-
-    # def compute_loss(self, u_theta, v_target, x1, b, l):
-    #     if b.numel() == 0:
-    #         dummy_loss = 0.0 * u_theta.sum()
-    #         return dummy_loss, torch.tensor(0.0, device=u_theta.device)
-
-    #     true_toks = x1[b, l]
-    #     u_pred = u_theta[b, l, :] # N_masks, V
-
-    #     tgt = torch.zeros_like(u_pred)
-    #     tgt.scatter_(1, true_toks.unsqueeze(1), v_target.unsqueeze(1))
-
-    #     sse = F.mse_loss(u_pred, tgt, reduction='sum')
-    #     loss = sse / b.numel() if b.numel != 0 else sse # normalize by number of masks
-
-    #     with torch.no_grad():
-    #         ppl = torch.exp(F.cross_entropy(u_pred, true_toks))
-
-    #     return loss, ppl
-
 
     # -------# Time Schedules #-------- #
     def sample_t(self, x1):