import itertools
import math
import os
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as L
import torchmetrics
from dataclasses import dataclass
from esm_utils import load_esm2_model
from transformers import AutoModel, AutoTokenizer
import dit, ema
import sys
import config
import wandb
import noise_schedule

# W&B setup: read the API key from the environment rather than hard-coding a secret.
wandb_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=wandb_key)
wandb.init(project=config.Wandb.PROJECT, group=config.Wandb.GROUP)
|
LOG2 = math.log(2)

class WrapESM(nn.Module):
    """Wraps a pretrained ESM-2 model and fine-tunes only the last few attention projections."""

    def __init__(self, esm_model_path):
        super().__init__()
        self.esm_tokenizer, self.esm_model, _ = load_esm2_model(esm_model_path)

        model_layers = len(self.esm_model.esm.encoder.layer)

        # Freeze everything, then unfreeze the query/key/value projections of the
        # last config.ESM_LAYERS encoder layers.
        for param in self.esm_model.parameters():
            param.requires_grad = False

        for i, layer in enumerate(self.esm_model.esm.encoder.layer):
            if i >= model_layers - config.ESM_LAYERS:
                attn = layer.attention.self
                for proj in (attn.key, attn.query, attn.value):
                    for param in proj.parameters():
                        param.requires_grad = True

    def forward(self, latents, sigma):
        # Placeholder backbone: currently returns the latents unchanged.
        return latents

@dataclass
class Loss:
    loss: torch.FloatTensor
    nlls: torch.FloatTensor
    token_mask: torch.FloatTensor

class NLL(torchmetrics.MeanMetric):
    pass


class BPD(NLL):
    def compute(self) -> torch.Tensor:
        """Computes the bits per dimension.

        Returns:
            bpd
        """
        return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
    def compute(self) -> torch.Tensor:
        """Computes the perplexity.

        Returns:
            Perplexity
        """
        return torch.exp(self.mean_value / self.weight)

class Diffusion(L.LightningModule):
    def __init__(self, config, latent_dim, tokenizer):
        super().__init__()
        self.config = config
        self.latent_dim = latent_dim
        self.tokenizer = tokenizer

        self.softplus = torch.nn.Softplus()
        metrics = torchmetrics.MetricCollection({
            'nll': NLL(),
            'bpd': BPD(),
            'ppl': Perplexity(),
        })
        metrics.set_dtype(torch.float64)
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        self.T = self.config.T
        self.lr = self.config.Optim.LR
        self.backbone = WrapESM(self.config.MODEL_NAME)
        self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)
        self.time_conditioning = self.config.TIME_CONDITIONING
        self.subs_masking = self.config.SUBS_MASKING
        self.mask_index = self.tokenizer.mask_token_id
        self.antithetic_sampling = self.config.Training.ANTITHETIC_SAMPLING
        self.sampling_eps = self.config.Training.SAMPLING_EPS
        self.neg_infinity = -1000000.0

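    # SUBS parameterization (zero masking probability + carry-over unmasking): the
    # backbone logits are normalized into log-probabilities, the probability of
    # emitting [MASK] is forced to zero, and positions that are already unmasked
    # are constrained to reproduce their input token with probability one.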
    def subs_parameterization(self, logits, noised_latents):
        """Apply the SUBS parameterization to the backbone logits."""
        logits = logits.float()
        # Zero masking probability: the model never predicts the [MASK] token.
        logits[:, :, self.mask_index] += self.neg_infinity

        # Normalize so the outputs are log-probabilities.
        logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

        # Carry-over unmasking: unmasked positions keep their input token with probability one.
        unmasked_indices = (noised_latents != self.mask_index)
        logits[unmasked_indices] = self.neg_infinity
        logits[unmasked_indices, noised_latents[unmasked_indices]] = 0

        return logits

    def forward(self, latents, sigma):
        latents = latents.long()
        logits = self.backbone(latents, sigma)
        optimized_logits = self.subs_parameterization(logits, latents)
        return optimized_logits

    def q_xt(self, latents, move_chance):
        """Computes the noisy sample xt.

        Args:
            latents: torch.Tensor with shape (batch_size, diffusion_model_input_length, latent_dim), input.
            move_chance: float torch.Tensor with shape (batch_size, 1).
        """
        # Reduce the latent dimension so each position carries a single value to mask.
        latents = torch.mean(latents, dim=2)
        move_indices = torch.rand(*latents.shape, device=latents.device) < move_chance
        noised_latents = torch.where(move_indices, self.mask_index, latents)
        return noised_latents

    def sample_timestep(self, n, device):
        _eps_t = torch.rand(n, device=device)
        if self.antithetic_sampling:
            # Stratify the timesteps across [0, 1) to reduce variance.
            offset = torch.arange(n, device=device) / n
            _eps_t = (_eps_t / n + offset) % 1
        t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
        return t
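    # The per-token NLL computed below is weighted by dsigma / (exp(sigma) - 1), the
    # continuous-time ELBO weight for an absorbing-state (masking) noise schedule.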
|
    def forward_diffusion(self, x0):
        """Forward diffusion process: add noise to the latents and compute the weighted NLL."""
        t = self.sample_timestep(x0.shape[0], x0.device)
        sigma, dsigma = self.noise(t)
        unet_conditioning = sigma[:, None]
        # Probability that a position is masked at this noise level; shape (batch_size, 1)
        # so it broadcasts against the (batch_size, seq_len) output of q_xt.
        move_chance = 1 - torch.exp(-sigma[:, None])

        xt = self.q_xt(x0, move_chance)
        model_output = self.forward(xt, unet_conditioning)

        # Ground-truth token indices, recovered the same way q_xt collapses the latents.
        idx = torch.mean(x0, dim=2).long()[:, :, None]

        log_p_theta = torch.gather(input=model_output, dim=-1, index=idx).squeeze(-1)
        scale = (dsigma / torch.expm1(sigma))[:, None]
        return -log_p_theta * scale

    def compute_loss(self, latents, attention_mask):
        """Average of MLM losses to stabilize training."""
        loss = self.forward_diffusion(latents)

        nlls = loss * attention_mask
        count = attention_mask.sum()
        batch_nll = nlls.sum()
        token_nll = batch_nll / count

        return Loss(loss=token_nll, nlls=nlls, token_mask=attention_mask)

    def training_step(self, batch):
        latents, attention_mask = batch
        loss = self.compute_loss(latents, attention_mask)
        wandb.log({"train_loss": loss.loss.item()})
        return loss.loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def validation_step(self, batch):
        latents, attention_mask = batch
        loss = self.compute_loss(latents, attention_mask)
        wandb.log({"val_loss": loss.loss.item()})
        return loss.loss

    def sample_prior(self, *batch_dims):
        return self.mask_index * torch.ones(*batch_dims, dtype=torch.int64)

    def sample_categorical(self, categorical_probs):
        # Gumbel-max trick for sampling from a categorical distribution.
        gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
        return (categorical_probs / gumbel_norm).argmax(dim=-1)
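    # Cached DDPM-style reverse update: the denoiser output p_x0 is reused across steps
    # as long as no token changed (and the model is not explicitly time-conditioned),
    # which avoids redundant forward passes during sampling.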
|
    def ddpm_caching_update(self, x, t, dt, p_x0=None):
        assert self.config.noise.type == 'loglinear'
        sigma_t, _ = self.noise(t)
        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1
        move_chance_t = t[:, None, None]
        move_chance_s = (t - dt)[:, None, None]
        assert move_chance_t.ndim == 3, move_chance_t.shape
        if p_x0 is None:
            p_x0 = self.forward(x, sigma_t).exp()

        assert move_chance_t.ndim == p_x0.ndim
        q_xs = p_x0 * (move_chance_t - move_chance_s)
        q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
        _x = self.sample_categorical(q_xs)

        # Carry over tokens that were already unmasked.
        copy_flag = (x != self.mask_index).to(x.dtype)
        return p_x0, copy_flag * x + (1 - copy_flag) * _x
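    # Semi-autoregressive (strided) sampling: each stride reuses the previously generated
    # tokens as a fixed prefix of the window, denoises only the freshly masked tail block,
    # then slides the window by `stride_length` tokens.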
|
    @torch.no_grad()
    def sample_subs_guidance(self, n_samples, stride_length, num_strides, dt=0.001):
        ones = torch.ones(n_samples, dtype=self.dtype, device=self.device)
        num_steps = int(1 / dt)
        sampling_steps = 0
        intermediate_tokens = []
        target = None

        for _ in range(num_strides + 1):
            p_x0_cache = None
            x = self.sample_prior(n_samples, self.config.model.length).to(self.device)

            if target is not None:
                x[:, :-stride_length] = target

            for i in range(num_steps + 1):
                p_x0_cache, x_next = self.ddpm_caching_update(x=x, t=(1 - i * dt) * ones, dt=dt, p_x0=p_x0_cache)
                if not torch.allclose(x_next, x) or self.time_conditioning:
                    # Model output changed, so the cached p_x0 is stale.
                    p_x0_cache = None
                    sampling_steps += 1
                x = x_next
            x = self.forward(x, 0 * ones).argmax(dim=-1)
            intermediate_tokens.append(x[:, :stride_length].cpu().numpy())
            target = x[:, stride_length:]

        intermediate_tokens.append(target.cpu().numpy())
        intermediate_text_samples = []
        sequence_lengths = ((np.concatenate(intermediate_tokens, axis=1)[:, 1:]
                             == self.tokenizer.eos_token_id).cumsum(-1) == 0).sum(-1)

        for i in range(2, len(intermediate_tokens) + 1):
            intermediate_text_samples.append(self.tokenizer.batch_decode(np.concatenate(intermediate_tokens[:i], axis=1)))

        return sampling_steps, intermediate_text_samples, sequence_lengths

    def restore_model_and_semi_ar_sample(self, stride_length, num_strides, dt=0.001):
        """Generate samples from the model."""
        self.backbone.eval()
        self.noise.eval()

        sampling_steps, samples, sequence_lengths = self.sample_subs_guidance(
            n_samples=self.config.Loader.BATCH_SIZE,
            stride_length=stride_length,
            num_strides=num_strides,
            dt=dt,
        )

        self.backbone.train()
        self.noise.train()
        return sampling_steps, samples, sequence_lengths
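

# --- Hypothetical usage sketch -------------------------------------------------
# A minimal sketch of how this module might be driven with a Lightning Trainer,
# assuming a dataloader that yields (latents, attention_mask) batches. The names
# `train_loader`, `LATENT_DIM`, and `MAX_EPOCHS` are illustrative placeholders,
# not attributes defined in this file or guaranteed to exist in `config`.
#
# if __name__ == "__main__":
#     esm_tokenizer, _, _ = load_esm2_model(config.MODEL_NAME)
#     model = Diffusion(config, latent_dim=config.LATENT_DIM, tokenizer=esm_tokenizer)
#     trainer = L.Trainer(max_epochs=config.MAX_EPOCHS)
#     trainer.fit(model, train_loader)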