import copy
import math
import os
import pdb
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import PIL
import seaborn as sns
import torch
import wandb
import yaml
from matplotlib import pyplot as plt
from torch import optim

from flow_matching.loss import MixturePathGeneralizedKL
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler
from flow_matching.solver import MixtureDiscreteEulerSolver
from flow_matching.utils import ModelWrapper
from models.dna_models import MLPModel, CNNModel, TransformerModel, DeepFlyBrainModel
from modules.general_module import GeneralModule
from utils.flow_utils import (
    DirichletConditionalFlow,
    expand_simplex,
    sample_cond_prob_path,
    simplex_proj,
    get_wasserstein_dist,
    update_ema,
    load_flybrain_designed_seqs,
)
from utils.log import get_logger

sns.set_style('whitegrid')

logger = get_logger(__name__)


def _figure_to_pil(fig):
    """Render ``fig`` and return its canvas contents as an RGB PIL image.

    ``tostring_rgb`` is used when the canvas still provides it, which keeps
    the historical code path unchanged. It was deprecated in Matplotlib 3.8
    and is absent from newer releases, so the canvas' RGBA buffer is used as
    a fallback.
    """
    canvas = fig.canvas
    canvas.draw()
    if hasattr(canvas, 'tostring_rgb'):
        return PIL.Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb())
    # The RGBA buffer carries its own (height, width, 4) shape, so no
    # separate size query is needed; the alpha channel is simply dropped.
    rgba = np.asarray(canvas.buffer_rgba())
    return PIL.Image.fromarray(rgba).convert('RGB')


def _plot_binned_stat(values, indices, bins, ylabel):
    """Plot the per-bin mean of ``values`` with an asymmetric std band.

    Args:
        values: 1-D numpy array of the quantity to summarise.
        indices: bin index of every entry of ``values`` (as returned by
            ``np.digitize``); assumed to have the same length as ``values``.
        bins: bin edges; bin ``i`` (for ``i`` in ``1 .. len(bins) - 1``) is
            plotted at the midpoint of ``bins[i - 1]`` and ``bins[i]``.
        ylabel: label of the y axis.

    Returns:
        The figure the plot was drawn on (still open; the caller closes it).
    """
    grouped = [values[indices == i] for i in range(1, len(bins))]
    bin_means = [np.mean(group) for group in grouped]
    bin_std = [np.std(group) for group in grouped]
    # Spread of the samples above / below the bin mean, used as the upper /
    # lower extent of the shaded band.
    bin_pos_std = [np.std(group[group > mean]) for group, mean in zip(grouped, bin_means)]
    bin_neg_std = [np.std(group[group < mean]) for group, mean in zip(grouped, bin_means)]
    bin_centers = 0.5 * (bins[:-1] + bins[1:])

    plot_data = pd.DataFrame({'Alphas': bin_centers, 'Means': bin_means, 'Std': bin_std,
                              'Pos_Std': bin_pos_std, 'Neg_Std': bin_neg_std})
    plt.figure(figsize=(10, 6))
    sns.lineplot(x='Alphas', y='Means', data=plot_data)
    plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'],
                     plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3)
    plt.xlabel('Binned alphas values')
    plt.ylabel(ylabel)
    return plt.gcf()


class DNAModule(GeneralModule):
    """Training/validation module for a discrete flow-matching sequence model.

    The module builds one of several backbone models (see ``load_model``),
    a ``MixtureDiscreteProbPath`` driven by a ``PolynomialConvexScheduler``
    and a ``MixturePathGeneralizedKL`` loss, and logs per-stage metrics
    through ``lg``.

    NOTE(review): ``self.args``, ``self._log``, ``self.iter_step``,
    ``self.gather_log`` and ``self.get_log_mean`` are not defined in this
    file; they are presumably provided by ``GeneralModule`` -- confirm.
    """

    def __init__(self, args, alphabet_size, num_cls, source_distribution="uniform"):
        """Build the model, probability path, loss and logging containers.

        Args:
            args: run configuration; forwarded to ``GeneralModule.__init__``.
            alphabet_size: number of data tokens (before any mask token).
            num_cls: number of classes, forwarded to ``load_model``.
            source_distribution: ``"uniform"`` or ``"mask"``. With ``"mask"``
                one extra token (index ``alphabet_size``) is appended to the
                alphabet and used as the source state.

        Raises:
            NotImplementedError: for any other ``source_distribution``.
        """
        super().__init__(args)
        self.alphabet_size = alphabet_size
        self.source_distribution = source_distribution
        self.epsilon = 1e-3  # keeps sampled times strictly below 1 (see general_step)

        if source_distribution == "uniform":
            added_token = 0
        elif source_distribution == "mask":
            self.mask_token = alphabet_size  # tokens starting from zero
            added_token = 1
        else:
            raise NotImplementedError
        self.alphabet_size += added_token

        self.load_model(self.alphabet_size, num_cls)
        self.scheduler = PolynomialConvexScheduler(n=args.scheduler_n)
        self.path = MixtureDiscreteProbPath(scheduler=self.scheduler)
        self.loss_fn = MixturePathGeneralizedKL(path=self.path)

        self.val_outputs = defaultdict(list)
        self.train_outputs = defaultdict(list)
        self.train_out_initialized = False
        self.mean_log_ema = {}
        if self.args.taskiran_seq_path is not None:
            self.taskiran_fly_seqs = load_flybrain_designed_seqs(self.args.taskiran_seq_path).to(self.device)

    def on_load_checkpoint(self, checkpoint):
        """Drop classifier / distillation weights from a checkpoint in place.

        Every ``state_dict`` entry whose key contains ``'cls_model'`` or
        ``'distill_model'`` is removed before the weights are loaded.
        """
        checkpoint['state_dict'] = {
            k: v for k, v in checkpoint['state_dict'].items()
            if 'cls_model' not in k and 'distill_model' not in k
        }

    def training_step(self, batch, batch_idx):
        """Run one training step and optionally save a checkpoint.

        A checkpoint is written to ``$MODEL_DIR`` when the trainer's global
        step is listed in ``args.ckpt_iterations``.

        Returns:
            The loss computed by ``general_step``.
        """
        self.stage = 'train'
        loss = self.general_step(batch, batch_idx)
        if self.args.ckpt_iterations is not None and self.trainer.global_step in self.args.ckpt_iterations:
            self.trainer.save_checkpoint(os.path.join(os.environ["MODEL_DIR"],f"epoch={self.trainer.current_epoch}-step={self.trainer.global_step}.ckpt"))
        # self.try_print_log()
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step; metrics are recorded via ``lg``."""
        self.stage = 'val'
        self.general_step(batch, batch_idx)
        # if self.args.validate:
        #     self.try_print_log()

    def general_step(self, batch, batch_idx=None):
        """Compute and log the flow-matching loss for one batch.

        Args:
            batch: pair ``(x_1, cls)``; ``x_1`` must be a 2-D tensor of token
                indices. ``cls`` is unpacked but not used here.
            batch_idx: unused.

        Returns:
            The loss returned by ``self.loss_fn``.

        Raises:
            NotImplementedError: if ``source_distribution`` is unsupported.
        """
        self.iter_step += 1
        x_1, cls = batch
        B, L = x_1.shape  # also enforces that x_1 is 2-D
        x_1 = x_1.to(self.device)

        # Draw the source sample x_0 of the probability path.
        if self.source_distribution == "uniform":
            x_0 = torch.randint_like(x_1, high=self.alphabet_size)
        elif self.source_distribution == "mask":
            x_0 = torch.zeros_like(x_1) + self.mask_token
        else:
            raise NotImplementedError

        # One time per sequence, uniform in [0, 1 - epsilon).
        t = torch.rand(x_1.shape[0]) * (1 - self.epsilon)
        t = t.to(x_1.device)

        path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1)
        logits = self.model(x_t=path_sample.x_t, t=path_sample.t)
        loss = self.loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=path_sample.t)

        self.lg('loss', loss)
        if self.stage == "val":
            # Fraction of positions whose arg-max prediction equals the target.
            predicted = logits.argmax(dim=-1)
            accuracy = (predicted == x_1).float().mean()
            self.lg('acc', accuracy)
        self.last_log_time = time.time()
        return loss

    @torch.no_grad()
    def dirichlet_flow_inference(self, seq, cls, model, args):
        """Integrate a Dirichlet flow on the simplex with explicit Euler steps.

        Args:
            seq: 2-D tensor; only its shape ``(B, L)`` and device are used.
            cls: unused.
            model: callable as ``model(xt_expanded, t=...)`` that exposes
                ``alphabet_size``.
            args: provides ``alpha_max``, ``prior_pseudocount`` and
                ``flow_temp``. The number of steps is read from ``self.args``.

        Returns:
            ``(logits, x0)``: the model output of the last integration step
            and the initial Dirichlet sample.

        NOTE(review): ``self.condflow`` is not assigned anywhere in this
        class as visible here; unless the base class provides it this method
        raises ``AttributeError`` -- confirm. ``logits`` is also unbound if
        ``num_integration_steps`` is below 2.
        """
        B, L = seq.shape
        K = model.alphabet_size
        x0 = torch.distributions.Dirichlet(torch.ones(B, L, model.alphabet_size, device=seq.device)).sample()
        eye = torch.eye(K).to(x0)
        xt = x0.clone()

        t_span = torch.linspace(1, args.alpha_max, self.args.num_integration_steps, device=self.device)
        for i, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])):
            xt_expanded, prior_weights = expand_simplex(xt, s[None].expand(B), args.prior_pseudocount)

            logits = model(xt_expanded, t=s[None].expand(B))
            flow_probs = torch.nn.functional.softmax(logits / args.flow_temp, -1)  # [B, L, K]

            if not torch.allclose(flow_probs.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (flow_probs >= 0).all():
                print(f'WARNING: flow_probs.min(): {flow_probs.min()}. Some values of flow_probs do not lie on the simplex. There are we are {(flow_probs<0).sum()} negative values in flow_probs of shape {flow_probs.shape} that are negative. We are projecting them onto the simplex.')
                flow_probs = simplex_proj(flow_probs)

            c_factor = self.condflow.c_factor(xt.cpu().numpy(), s.item())
            c_factor = torch.from_numpy(c_factor).to(xt)

            self.inf_counter += 1

            if not (flow_probs >= 0).all():
                print(f'flow_probs.min(): {flow_probs.min()}')
            cond_flows = (eye - xt.unsqueeze(-1)) * c_factor.unsqueeze(-2)
            flow = (flow_probs.unsqueeze(-2) * cond_flows).sum(-1)

            # Explicit Euler update from time s to time t.
            xt = xt + flow * (t - s)

            if not torch.allclose(xt.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (xt >= 0).all():
                print(f'WARNING: xt.min(): {xt.min()}. Some values of xt do not lie on the simplex. There are we are {(xt<0).sum()} negative values in xt of shape {xt.shape} that are negative. We are projecting them onto the simplex.')
                xt = simplex_proj(xt)
        return logits, x0

    def on_validation_epoch_start(self):
        """Reset the inference counters used for the NaN/inf fraction."""
        # Starts at 1 so the fraction computed at epoch end never divides by zero.
        self.inf_counter = 1
        self.nan_inf_counter = 0

    def on_validation_epoch_end(self):
        """Aggregate, log and clear all ``val_`` metrics of the epoch.

        On the global-zero rank the mean metrics (plus their exponential
        moving averages) are logged, optionally sent to wandb, and the raw
        gathered log is written to ``$MODEL_DIR/val_<global_step>.csv``.
        """
        self.generator = np.random.default_rng()
        log = self._log
        log = {key: log[key] for key in log if "val_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({'val_nan_inf_step_fraction': self.nan_inf_counter / self.inf_counter})
        mean_log.update({'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})

        self.mean_log_ema = update_ema(current_dict=mean_log, prev_ema=self.mean_log_ema, gamma=0.9)
        mean_log.update(self.mean_log_ema)
        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

            path = os.path.join(os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv")
            pd.DataFrame(log).to_csv(path)

        # Iterate over a copy of the keys because entries are deleted from self._log.
        for key in list(log.keys()):
            if "val_" in key:
                del self._log[key]
        self.val_outputs = defaultdict(list)

    def on_train_epoch_start(self) -> None:
        """Reset the inference counters at the start of a training epoch."""
        self.inf_counter = 1
        self.nan_inf_counter = 0
        # if not self.loaded_distill_model and self.args.distill_ckpt is not None:
        #     self.load_distill_model()
        #     self.loaded_distill_model = True
        # if not self.loaded_classifiers:
        #     self.load_classifiers(load_cls=self.args.cls_ckpt is not None, load_clean_cls=self.args.clean_cls_ckpt is not None)
        #     self.loaded_classifiers = True

    def on_train_epoch_end(self):
        """Aggregate, log and clear all ``train_`` metrics of the epoch."""
        self.train_out_initialized = True
        log = self._log
        log = {key: log[key] for key in log if "train_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update(
            {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

        # Iterate over a copy of the keys because entries are deleted from self._log.
        for key in list(log.keys()):
            if "train_" in key:
                del self._log[key]

    def lg(self, key, data):
        """Append ``data`` to the log under the current stage's prefix.

        Tensors are detached and converted to numpy first. The value is
        stored under ``"<stage>_<key>"`` and, when training or when
        ``args.validate`` is set, additionally under ``"iter_<key>"``.
        """
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        log = self._log
        if self.args.validate or self.stage == 'train':
            log["iter_" + key].append(data)
        log[self.stage + "_" + key].append(data)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``args.lr``."""
        optimizer = optim.Adam(self.parameters(), lr=self.args.lr)
        return optimizer

    def plot_empirical_and_true(self, empirical_dist, true_dist):
        """Plot up to four empirical distributions against the true density.

        Args:
            empirical_dist: indexable by dataset; ``empirical_dist[i]`` is
                drawn as a bar chart. Must expose ``shape[0]``.
            true_dist: ``true_dist[i]`` is drawn as a line over the bars.

        Returns:
            The rendered figure as an RGB PIL image.
        """
        num_datasets_to_plot = min(4, empirical_dist.shape[0])
        width = 1

        # squeeze=False keeps ``axes`` two-dimensional even when there is a
        # single row (one or two datasets); otherwise axes[row, col] fails.
        fig, axes = plt.subplots(math.ceil(num_datasets_to_plot / 2), 2, figsize=(10, 8), squeeze=False)
        for i in range(num_datasets_to_plot):
            row, col = i // 2, i % 2
            x = np.arange(len(empirical_dist[i]))
            axes[row, col].bar(x, empirical_dist[i], width, label='empirical')
            axes[row, col].plot(x, true_dist[i], label='true density', color='orange')
            axes[row, col].legend()
            axes[row, col].set_title(f'Sequence position {i + 1}')
            axes[row, col].set_xlabel('Category')
            axes[row, col].set_ylabel('Density')
        plt.tight_layout()

        pil_img = _figure_to_pil(fig)
        plt.close(fig)
        return pil_img

    def load_model(self, alphabet_size, num_cls):
        """Instantiate ``self.model`` according to ``args.model``.

        Supported values: ``'cnn'``, ``'mlp'``, ``'transformer'`` and
        ``'deepflybrain'``. ``num_cls`` is only used by the last one.

        Raises:
            NotImplementedError: for any other ``args.model``.
        """
        if self.args.model == 'cnn':
            self.model = CNNModel(self.args, alphabet_size=alphabet_size)
        elif self.args.model == 'mlp':
            self.model = MLPModel(input_dim=alphabet_size, time_dim=1, hidden_dim=self.args.hidden_dim, length=self.args.length)
        elif self.args.model == 'transformer':
            self.model = TransformerModel(alphabet_size=alphabet_size, seq_length=self.args.length, embed_dim=self.args.hidden_dim,
                                          num_layers=self.args.num_layers, num_heads=self.args.num_heads, dropout=self.args.dropout)
        elif self.args.model == 'deepflybrain':
            self.model = DeepFlyBrainModel(self.args, alphabet_size=alphabet_size,num_cls=num_cls)
        else:
            raise NotImplementedError()

    def plot_score_and_probs(self):
        """Plot true-class probability and score norm as functions of alpha.

        Reads the ``clss_noisycls``, ``logits_noisycls``, ``scores_noisycls``
        and ``alphas_noisycls`` lists of ``self.val_outputs``, bins the
        samples by alpha and plots per-bin means with a std band.

        Returns:
            ``(pil_probs, pil_score_norms)``: the two plots as PIL images.

        NOTE(review): nothing in this file fills these ``val_outputs`` keys;
        with empty lists ``torch.cat`` raises -- confirm the caller.
        """
        clss = torch.cat(self.val_outputs['clss_noisycls'])
        probs = torch.softmax(torch.cat(self.val_outputs['logits_noisycls']), dim=-1)
        scores = torch.cat(self.val_outputs['scores_noisycls']).cpu().numpy()
        score_norms = np.linalg.norm(scores, axis=-1)
        alphas = torch.cat(self.val_outputs['alphas_noisycls']).cpu().numpy()
        # Probability the model assigns to the true class of each sample.
        true_probs = probs[torch.arange(len(probs)), clss].cpu().numpy()

        # 20 edges from the smallest alpha up to the fixed upper bound 12.
        bins = np.linspace(min(alphas), 12, 20)
        indices = np.digitize(alphas, bins)

        fig = _plot_binned_stat(true_probs, indices, bins, 'Mean of predicted probs for true class')
        pil_probs = _figure_to_pil(fig)
        plt.close(fig)

        fig = _plot_binned_stat(score_norms, indices, bins, 'Mean of norm of the scores')
        pil_score_norms = _figure_to_pil(fig)
        # Close the second figure as well so repeated calls do not pile up open figures.
        plt.close(fig)
        return pil_probs, pil_score_norms

    def log_data_similarities(self, seq_pred):
        """Log how closely predicted sequences match the toy datasets.

        For each predicted sequence the per-position match rate against
        every reference sequence is computed; both the best match and the
        mean match are logged for class 1, class 2 and their union.

        NOTE(review): ``self.toy_data`` is not set anywhere in this file --
        confirm where it comes from.
        """
        pred = seq_pred.cpu()[:, None, :]
        similarities1 = pred.eq(self.toy_data.data_class1[None, :, :])  # batchsize, dataset_size, seq_len
        similarities2 = pred.eq(self.toy_data.data_class2[None, :, :])  # batchsize, dataset_size, seq_len
        similarities = pred.eq(torch.cat([self.toy_data.data_class2[None, :, :], self.toy_data.data_class1[None, :, :]],dim=1))  # batchsize, dataset_size, seq_len

        self.lg('data1_sim', similarities1.float().mean(-1).max(-1)[0])
        self.lg('data2_sim', similarities2.float().mean(-1).max(-1)[0])
        self.lg('data_sim', similarities.float().mean(-1).max(-1)[0])
        self.lg('mean_data1_sim', similarities1.float().mean(-1).mean(-1))
        self.lg('mean_data2_sim', similarities2.float().mean(-1).mean(-1))
        self.lg('mean_data_sim', similarities.float().mean(-1).mean(-1))