| import copy | |
| import math | |
| from collections import defaultdict | |
| import PIL | |
| import numpy as np | |
| import pandas as pd | |
| import torch, time, os | |
| import wandb | |
| import seaborn as sns | |
| import yaml | |
| sns.set_style('whitegrid') | |
| from matplotlib import pyplot as plt | |
| from torch import optim | |
| from models.dna_models import MLPModel, CNNModel, TransformerModel, DeepFlyBrainModel | |
| from utils.flow_utils import DirichletConditionalFlow, expand_simplex, sample_cond_prob_path, simplex_proj, \ | |
| get_wasserstein_dist, update_ema, load_flybrain_designed_seqs | |
| from modules.general_module import GeneralModule | |
| from utils.log import get_logger | |
| from flow_matching.path import MixtureDiscreteProbPath | |
| from flow_matching.path.scheduler import PolynomialConvexScheduler | |
| from flow_matching.solver import MixtureDiscreteEulerSolver | |
| from flow_matching.utils import ModelWrapper | |
| from flow_matching.loss import MixturePathGeneralizedKL | |
| import pdb | |
| logger = get_logger(__name__) | |
class DNAModule(GeneralModule):
    """Discrete flow-matching training module for token sequences.

    The module builds a sequence model (see ``load_model``), a
    ``MixtureDiscreteProbPath`` driven by a ``PolynomialConvexScheduler`` and a
    ``MixturePathGeneralizedKL`` loss, and exposes the usual
    ``training_step`` / ``validation_step`` / ``on_*_epoch_*`` hooks.

    NOTE(review): ``self.args``, ``self._log``, ``self.iter_step``,
    ``self.gather_log`` and ``self.get_log_mean`` are not defined in this
    class; they are presumably provided by ``GeneralModule`` -- confirm.
    """

    def __init__(self, args, alphabet_size, num_cls, source_distribution="uniform"):
        """Build the model, probability path and loss.

        Args:
            args: Run configuration; this class reads ``scheduler_n``,
                ``taskiran_seq_path`` and the model options used by
                ``load_model``.
            alphabet_size: Number of data tokens (before any mask token).
            num_cls: Number of classes, forwarded to ``load_model``.
            source_distribution: ``"uniform"`` or ``"mask"``. With ``"mask"``
                one extra token (index ``alphabet_size``) is appended to the
                vocabulary and used as the source state.

        Raises:
            NotImplementedError: For any other ``source_distribution``.
        """
        super().__init__(args)
        self.alphabet_size = alphabet_size
        self.source_distribution = source_distribution
        # Upper bound of the sampled time is 1 - epsilon (see general_step).
        self.epsilon = 1e-3

        if source_distribution == "uniform":
            added_token = 0
        elif source_distribution == "mask":
            # Tokens start from zero, so the first free index is the mask.
            self.mask_token = alphabet_size
            added_token = 1
        else:
            raise NotImplementedError(
                f"Unsupported source_distribution: {source_distribution!r}")
        self.alphabet_size += added_token

        self.load_model(self.alphabet_size, num_cls)

        self.scheduler = PolynomialConvexScheduler(n=args.scheduler_n)
        self.path = MixtureDiscreteProbPath(scheduler=self.scheduler)
        self.loss_fn = MixturePathGeneralizedKL(path=self.path)

        self.val_outputs = defaultdict(list)
        self.train_outputs = defaultdict(list)
        self.train_out_initialized = False
        self.mean_log_ema = {}
        if self.args.taskiran_seq_path is not None:
            self.taskiran_fly_seqs = load_flybrain_designed_seqs(
                self.args.taskiran_seq_path).to(self.device)

    def on_load_checkpoint(self, checkpoint):
        """Drop ``cls_model`` / ``distill_model`` weights from ``checkpoint``.

        The ``state_dict`` entry of ``checkpoint`` is replaced in place by a
        copy without any key containing either substring.
        """
        checkpoint['state_dict'] = {
            k: v for k, v in checkpoint['state_dict'].items()
            if 'cls_model' not in k and 'distill_model' not in k
        }

    def training_step(self, batch, batch_idx):
        """Run one training step and return its loss.

        When ``args.ckpt_iterations`` is set and contains the current global
        step, an extra checkpoint is written under ``$MODEL_DIR``.
        """
        self.stage = 'train'
        loss = self.general_step(batch, batch_idx)
        if self.args.ckpt_iterations is not None and self.trainer.global_step in self.args.ckpt_iterations:
            self.trainer.save_checkpoint(os.path.join(os.environ["MODEL_DIR"],f"epoch={self.trainer.current_epoch}-step={self.trainer.global_step}.ckpt"))
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step; metrics are recorded through ``lg``."""
        self.stage = 'val'
        self.general_step(batch, batch_idx)

    def general_step(self, batch, batch_idx=None):
        """Compute the flow-matching loss for one batch.

        Args:
            batch: Pair ``(x_1, cls)``; ``x_1`` must be a 2-D tensor of token
                indices. The second element is not used here.
            batch_idx: Unused; accepted for hook compatibility.

        Returns:
            The loss returned by ``self.loss_fn``.

        Raises:
            NotImplementedError: If ``self.source_distribution`` is neither
                ``"uniform"`` nor ``"mask"``.
        """
        self.iter_step += 1
        x_1, _cls = batch
        # Unpacking two values keeps the original requirement that x_1 is 2-D.
        batch_size, _seq_len = x_1.shape
        x_1 = x_1.to(self.device)

        if self.source_distribution == "uniform":
            x_0 = torch.randint_like(x_1, high=self.alphabet_size)
        elif self.source_distribution == "mask":
            x_0 = torch.zeros_like(x_1) + self.mask_token
        else:
            raise NotImplementedError

        # One time per sequence, drawn from [0, 1 - epsilon).
        t = torch.rand(batch_size) * (1 - self.epsilon)
        t = t.to(x_1.device)

        path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1)
        logits = self.model(x_t=path_sample.x_t, t=path_sample.t)
        loss = self.loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=path_sample.t)

        self.lg('loss', loss)
        if self.stage == "val":
            predicted = logits.argmax(dim=-1)
            accuracy = (predicted == x_1).float().mean()
            self.lg('acc', accuracy)
        self.last_log_time = time.time()
        return loss

    @staticmethod
    def _project_onto_simplex_if_needed(values, name):
        """Return ``values``, projected with ``simplex_proj`` if off-simplex.

        ``values`` counts as on the simplex when every entry is non-negative
        and the sums over dim 2 are within ``1e-4`` of one. ``name`` is only
        used in the printed warning.
        """
        sums = values.sum(2)
        # ones_like keeps device and dtype aligned with the tensor under test.
        on_simplex = torch.allclose(sums, torch.ones_like(sums), atol=1e-4) and bool((values >= 0).all())
        if on_simplex:
            return values
        print(f'WARNING: {name}.min(): {values.min()}. Some values of {name} do not lie on the simplex. There are we are {(values<0).sum()} negative values in {name} of shape {values.shape} that are negative. We are projecting them onto the simplex.')
        return simplex_proj(values)

    def dirichlet_flow_inference(self, seq, cls, model, args):
        """Integrate the Dirichlet flow with Euler steps.

        Args:
            seq: 2-D tensor; only its shape and device are used.
            cls: Unused.
            model: Callable as ``model(x, t=...)`` exposing ``alphabet_size``.
            args: Provides ``alpha_max``, ``prior_pseudocount`` and
                ``flow_temp``. The number of steps is read from
                ``self.args.num_integration_steps``.

        Returns:
            ``(logits, x0)``: the logits of the last step and the initial
            Dirichlet sample.

        Raises:
            ValueError: If fewer than two integration points are configured
                (no step would run and no logits would exist).

        NOTE(review): ``self.condflow`` is read here but never assigned in
        this class -- it must be provided elsewhere; confirm before use.
        """
        num_points = self.args.num_integration_steps
        if num_points < 2:
            raise ValueError(
                f"num_integration_steps must be >= 2, got {num_points}")

        B, L = seq.shape
        K = model.alphabet_size
        x0 = torch.distributions.Dirichlet(torch.ones(B, L, model.alphabet_size, device=seq.device)).sample()
        eye = torch.eye(K).to(x0)
        xt = x0.clone()

        t_span = torch.linspace(1, args.alpha_max, num_points, device=self.device)
        for s, t in zip(t_span[:-1], t_span[1:]):
            step_time = s[None].expand(B)
            xt_expanded, _prior_weights = expand_simplex(xt, step_time, args.prior_pseudocount)
            logits = model(xt_expanded, t=step_time)
            flow_probs = torch.nn.functional.softmax(logits / args.flow_temp, -1)  # [B, L, K]
            flow_probs = self._project_onto_simplex_if_needed(flow_probs, 'flow_probs')

            c_factor = self.condflow.c_factor(xt.cpu().numpy(), s.item())
            c_factor = torch.from_numpy(c_factor).to(xt)

            self.inf_counter += 1
            if not (flow_probs >= 0).all(): print(f'flow_probs.min(): {flow_probs.min()}')
            cond_flows = (eye - xt.unsqueeze(-1)) * c_factor.unsqueeze(-2)
            flow = (flow_probs.unsqueeze(-2) * cond_flows).sum(-1)

            xt = xt + flow * (t - s)
            xt = self._project_onto_simplex_if_needed(xt, 'xt')
        return logits, x0

    def on_validation_epoch_start(self):
        """Reset the inference step counters for the new validation epoch."""
        # inf_counter starts at 1 so the fraction computed at epoch end never
        # divides by zero.
        self.inf_counter = 1
        self.nan_inf_counter = 0

    def on_validation_epoch_end(self):
        """Aggregate, report and clear the ``val_`` entries of ``self._log``.

        On the global-zero rank the mean log (plus its EMA) is logged, sent to
        wandb when enabled, and the gathered log is written to
        ``$MODEL_DIR/val_<global_step>.csv``.
        """
        self.generator = np.random.default_rng()
        log = self._log
        log = {key: log[key] for key in log if "val_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({'val_nan_inf_step_fraction': self.nan_inf_counter / self.inf_counter})
        mean_log.update({'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})

        self.mean_log_ema = update_ema(current_dict=mean_log, prev_ema=self.mean_log_ema, gamma=0.9)
        mean_log.update(self.mean_log_ema)

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

            path = os.path.join(os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv")
            pd.DataFrame(log).to_csv(path)

        for key in list(log.keys()):
            if "val_" in key:
                del self._log[key]
        self.val_outputs = defaultdict(list)

    def on_train_epoch_start(self) -> None:
        """Reset the inference step counters for the new training epoch."""
        self.inf_counter = 1
        self.nan_inf_counter = 0

    def on_train_epoch_end(self):
        """Aggregate, report and clear the ``train_`` entries of ``self._log``."""
        self.train_out_initialized = True
        log = self._log
        log = {key: log[key] for key in log if "train_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update(
            {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

        for key in list(log.keys()):
            if "train_" in key:
                del self._log[key]

    def lg(self, key, data):
        """Append ``data`` to ``self._log`` under the current stage.

        Tensors are detached and converted to NumPy first. The value is stored
        under ``"<stage>_<key>"`` and, during training or when
        ``args.validate`` is set, also under ``"iter_<key>"``.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        log = self._log
        if self.args.validate or self.stage == 'train':
            log["iter_" + key].append(data)
        log[self.stage + "_" + key].append(data)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``args.lr``."""
        optimizer = optim.Adam(self.parameters(), lr=self.args.lr)
        return optimizer

    @staticmethod
    def _figure_to_pil(fig):
        """Render ``fig`` to an RGB PIL image and close that figure.

        ``canvas.buffer_rgba()`` is used instead of ``canvas.tostring_rgb()``
        because the latter is deprecated and absent from newer Matplotlib
        releases. The figure is closed even if rendering fails so figures do
        not accumulate.
        """
        try:
            fig.canvas.draw()
            rgba = np.asarray(fig.canvas.buffer_rgba())
            return PIL.Image.fromarray(rgba).convert('RGB')
        finally:
            plt.close(fig)

    def plot_empirical_and_true(self, empirical_dist, true_dist):
        """Plot empirical bars against true densities for up to four rows.

        Args:
            empirical_dist: Indexable by row; ``empirical_dist.shape[0]`` rows.
            true_dist: Same layout as ``empirical_dist``.

        Returns:
            A PIL RGB image of the figure.
        """
        num_datasets_to_plot = min(4, empirical_dist.shape[0])
        width = 1

        # squeeze=False keeps ``axes`` 2-D even with a single row, so the
        # [row, col] indexing below also works for one or two datasets; at
        # least one row is requested so an empty input still yields a figure.
        n_rows = max(1, math.ceil(num_datasets_to_plot / 2))
        fig, axes = plt.subplots(n_rows, 2, figsize=(10, 8), squeeze=False)
        for i in range(num_datasets_to_plot):
            ax = axes[i // 2, i % 2]
            x = np.arange(len(empirical_dist[i]))
            ax.bar(x, empirical_dist[i], width, label=f'empirical')
            ax.plot(x, true_dist[i], label=f'true density', color='orange')
            ax.legend()
            ax.set_title(f'Sequence position {i + 1}')
            ax.set_xlabel('Category')
            ax.set_ylabel('Density')
        plt.tight_layout()
        return self._figure_to_pil(fig)

    def load_model(self, alphabet_size, num_cls):
        """Instantiate ``self.model`` according to ``args.model``.

        Supported values: ``'cnn'``, ``'mlp'``, ``'transformer'``,
        ``'deepflybrain'``.

        Raises:
            NotImplementedError: For any other ``args.model``.
        """
        if self.args.model == 'cnn':
            self.model = CNNModel(self.args, alphabet_size=alphabet_size)
        elif self.args.model == 'mlp':
            self.model = MLPModel(input_dim=alphabet_size, time_dim=1, hidden_dim=self.args.hidden_dim, length=self.args.length)
        elif self.args.model == 'transformer':
            self.model = TransformerModel(alphabet_size=alphabet_size, seq_length=self.args.length, embed_dim=self.args.hidden_dim,
                                          num_layers=self.args.num_layers, num_heads=self.args.num_heads, dropout=self.args.dropout)
        elif self.args.model == 'deepflybrain':
            self.model = DeepFlyBrainModel(self.args, alphabet_size=alphabet_size,num_cls=num_cls)
        else:
            raise NotImplementedError()

    @staticmethod
    def _binned_stats(values, indices, bin_centers):
        """Summarise ``values`` per bin as a DataFrame.

        For bin ids ``1..len(bin_centers)`` the mean, the std, and the std of
        the entries above / below the bin mean are computed. Empty selections
        yield NaN, as ``np.mean`` / ``np.std`` do on empty input.
        """
        means, stds, pos_stds, neg_stds = [], [], [], []
        for bin_id in range(1, len(bin_centers) + 1):
            in_bin = values[indices == bin_id]
            bin_mean = np.mean(in_bin)
            means.append(bin_mean)
            stds.append(np.std(in_bin))
            pos_stds.append(np.std(in_bin[in_bin > bin_mean]))
            neg_stds.append(np.std(in_bin[in_bin < bin_mean]))
        return pd.DataFrame({'Alphas': bin_centers, 'Means': means, 'Std': stds, 'Pos_Std': pos_stds, 'Neg_Std': neg_stds})

    def _plot_binned_stats(self, plot_data, ylabel):
        """Draw the mean line with an asymmetric std band; return a PIL image."""
        fig = plt.figure(figsize=(10, 6))
        sns.lineplot(x='Alphas', y='Means', data=plot_data)
        plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'], plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3)
        plt.xlabel('Binned alphas values')
        plt.ylabel(ylabel)
        return self._figure_to_pil(fig)

    def plot_score_and_probs(self):
        """Plot true-class probability and score norm against binned alphas.

        Reads the ``*_noisycls`` lists stored in ``self.val_outputs``.

        Returns:
            ``(pil_probs, pil_score_norms)``: two PIL RGB images. Both figures
            are closed after rendering.
        """
        clss = torch.cat(self.val_outputs['clss_noisycls'])
        probs = torch.softmax(torch.cat(self.val_outputs['logits_noisycls']), dim=-1)
        scores = torch.cat(self.val_outputs['scores_noisycls']).cpu().numpy()
        score_norms = np.linalg.norm(scores, axis=-1)
        alphas = torch.cat(self.val_outputs['alphas_noisycls']).cpu().numpy()
        true_probs = probs[torch.arange(len(probs)), clss].cpu().numpy()

        # 20 edges from the smallest alpha up to the fixed upper bound 12;
        # samples at or above the last edge fall outside every bin.
        bins = np.linspace(min(alphas), 12, 20)
        indices = np.digitize(alphas, bins)
        bin_centers = 0.5 * (bins[:-1] + bins[1:])

        pil_probs = self._plot_binned_stats(
            self._binned_stats(true_probs, indices, bin_centers),
            'Mean of predicted probs for true class')
        pil_score_norms = self._plot_binned_stats(
            self._binned_stats(score_norms, indices, bin_centers),
            'Mean of norm of the scores')
        return pil_probs, pil_score_norms

    def log_data_similarities(self, seq_pred):
        """Log per-sample match rates between ``seq_pred`` and stored data.

        For each predicted sequence the fraction of matching positions against
        every reference sequence is computed; the maximum and the mean over
        the references are logged for class 1, class 2 and both combined.

        NOTE(review): ``self.toy_data`` is read here but never assigned in
        this class -- confirm where it is set.
        """
        pred = seq_pred.cpu()[:, None, :]
        data1 = self.toy_data.data_class1[None, :, :]
        data2 = self.toy_data.data_class2[None, :, :]

        # Each comparison has shape: batchsize, dataset_size, seq_len
        similarities1 = pred.eq(data1)
        similarities2 = pred.eq(data2)
        similarities = pred.eq(torch.cat([data2, data1], dim=1))

        self.lg('data1_sim', similarities1.float().mean(-1).max(-1)[0])
        self.lg('data2_sim', similarities2.float().mean(-1).max(-1)[0])
        self.lg('data_sim', similarities.float().mean(-1).max(-1)[0])
        self.lg('mean_data1_sim', similarities1.float().mean(-1).mean(-1))
        self.lg('mean_data2_sim', similarities2.float().mean(-1).mean(-1))
        self.lg('mean_data_sim', similarities.float().mean(-1).mean(-1))
Xet Storage Details
- Size:
- 14.6 kB
- Xet hash:
- b9520d2422c6ac4447971662b9ee6efea479ae2715f6ec2d6feda0d60e9435fe
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.