|
|
import copy |
|
|
import math |
|
|
from collections import defaultdict |
|
|
|
|
|
import PIL |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch, time, os |
|
|
import wandb |
|
|
import seaborn as sns |
|
|
import yaml |
|
|
|
|
|
sns.set_style('whitegrid')  # global seaborn theme for every figure produced in this module
|
|
from matplotlib import pyplot as plt |
|
|
from torch import optim |
|
|
|
|
|
from models.dna_models import MLPModel, CNNModel, TransformerModel, DeepFlyBrainModel |
|
|
from utils.flow_utils import DirichletConditionalFlow, expand_simplex, sample_cond_prob_path, simplex_proj, \ |
|
|
get_wasserstein_dist, update_ema, load_flybrain_designed_seqs |
|
|
from modules.general_module import GeneralModule |
|
|
from utils.log import get_logger |
|
|
|
|
|
from flow_matching.path import MixtureDiscreteProbPath |
|
|
from flow_matching.path.scheduler import PolynomialConvexScheduler |
|
|
from flow_matching.solver import MixtureDiscreteEulerSolver |
|
|
from flow_matching.utils import ModelWrapper |
|
|
from flow_matching.loss import MixturePathGeneralizedKL |
|
|
|
|
|
import pdb |
|
|
|
|
|
|
|
|
logger = get_logger(__name__)  # module-level logger from the project's logging helper
|
|
|
|
|
|
|
|
class DNAModule(GeneralModule):
    """Training module for discrete flow matching over DNA token sequences.

    Combines a sequence model (CNN / MLP / transformer / DeepFlyBrain) with a
    ``MixtureDiscreteProbPath`` (polynomial convex scheduler) and the
    generalized-KL mixture-path loss.  Also carries train/val metric logging,
    mid-epoch checkpointing, Dirichlet-flow inference, and plotting helpers.
    """

    def __init__(self, args, alphabet_size, num_cls, source_distribution="uniform"):
        """Build the module.

        Args:
            args: experiment config namespace (model type, lr, scheduler_n,
                taskiran_seq_path, wandb flag, ...).
            alphabet_size: number of sequence tokens before any mask token.
            num_cls: number of class labels (forwarded to some models).
            source_distribution: ``"uniform"`` draws x_0 uniformly over the
                alphabet; ``"mask"`` starts from an all-mask sequence and
                appends one extra mask token to the alphabet.

        Raises:
            NotImplementedError: for an unrecognized source distribution.
        """
        super().__init__(args)
        self.alphabet_size = alphabet_size
        self.source_distribution = source_distribution
        # Sample t in [0, 1 - epsilon) to stay away from the path endpoint.
        self.epsilon = 1e-3

        if source_distribution == "uniform":
            added_token = 0
        elif source_distribution == "mask":
            # The mask token id is appended right after the regular alphabet.
            self.mask_token = alphabet_size
            added_token = 1
        else:
            raise NotImplementedError
        self.alphabet_size += added_token

        self.load_model(self.alphabet_size, num_cls)

        self.scheduler = PolynomialConvexScheduler(n=args.scheduler_n)
        self.path = MixtureDiscreteProbPath(scheduler=self.scheduler)
        self.loss_fn = MixturePathGeneralizedKL(path=self.path)

        self.val_outputs = defaultdict(list)
        self.train_outputs = defaultdict(list)
        self.train_out_initialized = False
        self.mean_log_ema = {}  # exponential moving average of epoch metrics
        if self.args.taskiran_seq_path is not None:
            self.taskiran_fly_seqs = load_flybrain_designed_seqs(self.args.taskiran_seq_path).to(self.device)

    def on_load_checkpoint(self, checkpoint):
        """Strip classifier / distillation weights before the state dict is loaded."""
        checkpoint['state_dict'] = {k: v for k, v in checkpoint['state_dict'].items()
                                    if 'cls_model' not in k and 'distill_model' not in k}

    def training_step(self, batch, batch_idx):
        """Run one training step; optionally save a mid-epoch checkpoint.

        Returns:
            Scalar training loss for the optimizer.
        """
        self.stage = 'train'
        loss = self.general_step(batch, batch_idx)
        # Save extra checkpoints at user-requested global-step milestones.
        if self.args.ckpt_iterations is not None and self.trainer.global_step in self.args.ckpt_iterations:
            self.trainer.save_checkpoint(os.path.join(
                os.environ["MODEL_DIR"],
                f"epoch={self.trainer.current_epoch}-step={self.trainer.global_step}.ckpt"))
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step (metrics are logged inside general_step).

        Fix: the computed loss was previously assigned and discarded; it is now
        returned so the framework can aggregate/inspect it.
        """
        self.stage = 'val'
        return self.general_step(batch, batch_idx)

    def general_step(self, batch, batch_idx=None):
        """Shared train/val step: sample the mixture path, score it, compute loss.

        Args:
            batch: ``(x_1, cls)`` where x_1 is the target token tensor (B, L).
            batch_idx: unused; kept for the framework's signature.

        Returns:
            Scalar generalized-KL loss.
        """
        # NOTE(review): iter_step / _log appear to be maintained by
        # GeneralModule — confirm against modules/general_module.py.
        self.iter_step += 1
        x_1, cls = batch
        B, L = x_1.shape
        x_1 = x_1.to(self.device)

        # Draw the source sequence x_0 according to the configured distribution.
        if self.source_distribution == "uniform":
            x_0 = torch.randint_like(x_1, high=self.alphabet_size)
        elif self.source_distribution == "mask":
            x_0 = torch.full_like(x_1, self.mask_token)
        else:
            raise NotImplementedError

        # t ~ U[0, 1 - epsilon); drawn on CPU (keeps the RNG stream unchanged
        # for reproducibility), then moved to the data device.
        t = torch.rand(x_1.shape[0]) * (1 - self.epsilon)
        t = t.to(x_1.device)
        path_sample = self.path.sample(t=t, x_0=x_0, x_1=x_1)

        logits = self.model(x_t=path_sample.x_t, t=path_sample.t)
        loss = self.loss_fn(logits=logits, x_1=x_1, x_t=path_sample.x_t, t=path_sample.t)

        self.lg('loss', loss)
        if self.stage == "val":
            # Token-level accuracy of the argmax prediction against the target.
            predicted = logits.argmax(dim=-1)
            accuracy = (predicted == x_1).float().mean()
            self.lg('acc', accuracy)
        self.last_log_time = time.time()
        return loss

    @torch.no_grad()
    def dirichlet_flow_inference(self, seq, cls, model, args):
        """Integrate the Dirichlet conditional flow starting from a Dirichlet prior.

        Args:
            seq: (B, L) token tensor; only its shape and device are used.
            cls: class labels (unused here; kept for interface parity).
            model: score model exposing ``alphabet_size`` and ``__call__``.
            args: needs alpha_max, num_integration_steps, prior_pseudocount,
                flow_temp.

        Returns:
            ``(logits, x0)``: the model logits from the final integration step
            and the sampled prior point on the simplex.
        """
        B, L = seq.shape
        K = model.alphabet_size
        # x0 ~ Dirichlet(1, ..., 1): a uniform sample on the probability simplex.
        x0 = torch.distributions.Dirichlet(torch.ones(B, L, model.alphabet_size, device=seq.device)).sample()
        eye = torch.eye(K).to(x0)
        xt = x0.clone()

        t_span = torch.linspace(1, args.alpha_max, self.args.num_integration_steps, device=self.device)
        for i, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])):
            xt_expanded, prior_weights = expand_simplex(xt, s[None].expand(B), args.prior_pseudocount)

            logits = model(xt_expanded, t=s[None].expand(B))
            flow_probs = torch.nn.functional.softmax(logits / args.flow_temp, -1)

            # Numerical drift can push the probabilities off the simplex;
            # project back when sums or signs are off.
            if not torch.allclose(flow_probs.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (flow_probs >= 0).all():
                print(f'WARNING: flow_probs.min(): {flow_probs.min()}. Some values of flow_probs do not lie on the simplex. There are {(flow_probs < 0).sum()} negative values in flow_probs of shape {flow_probs.shape}. Projecting them onto the simplex.')
                flow_probs = simplex_proj(flow_probs)

            # NOTE(review): self.condflow (presumably a DirichletConditionalFlow)
            # is never created in __init__ — confirm it is set by a subclass or
            # elsewhere before this method is called.
            c_factor = self.condflow.c_factor(xt.cpu().numpy(), s.item())
            c_factor = torch.from_numpy(c_factor).to(xt)

            self.inf_counter += 1

            if not (flow_probs >= 0).all(): print(f'flow_probs.min(): {flow_probs.min()}')
            # Conditional flows toward each simplex vertex, mixed by flow_probs.
            cond_flows = (eye - xt.unsqueeze(-1)) * c_factor.unsqueeze(-2)
            flow = (flow_probs.unsqueeze(-2) * cond_flows).sum(-1)

            # Explicit Euler step along the time grid.
            xt = xt + flow * (t - s)

            if not torch.allclose(xt.sum(2), torch.ones((B, L), device=self.device), atol=1e-4) or not (xt >= 0).all():
                print(f'WARNING: xt.min(): {xt.min()}. Some values of xt do not lie on the simplex. There are {(xt < 0).sum()} negative values in xt of shape {xt.shape}. Projecting them onto the simplex.')
                xt = simplex_proj(xt)
        return logits, x0

    def on_validation_epoch_start(self):
        """Reset inference counters (inf_counter starts at 1 to avoid /0)."""
        self.inf_counter = 1
        self.nan_inf_counter = 0

    def on_validation_epoch_end(self):
        """Gather val_* metrics across ranks, log means (+EMA), dump CSV, reset."""
        self.generator = np.random.default_rng()
        log = self._log
        log = {key: log[key] for key in log if "val_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({'val_nan_inf_step_fraction': self.nan_inf_counter / self.inf_counter})

        mean_log.update({'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})

        # Smooth reported metrics with an exponential moving average (gamma=0.9).
        self.mean_log_ema = update_ema(current_dict=mean_log, prev_ema=self.mean_log_ema, gamma=0.9)
        mean_log.update(self.mean_log_ema)
        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

            # Persist the raw (non-averaged) validation log for this step.
            path = os.path.join(os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv")
            pd.DataFrame(log).to_csv(path)

        # Clear accumulated val_* entries so the next epoch starts fresh.
        for key in list(log.keys()):
            if "val_" in key:
                del self._log[key]
        self.val_outputs = defaultdict(list)

    def on_train_epoch_start(self) -> None:
        """Reset inference counters (inf_counter starts at 1 to avoid /0)."""
        self.inf_counter = 1
        self.nan_inf_counter = 0

    def on_train_epoch_end(self):
        """Gather train_* metrics across ranks, log their means, and reset."""
        self.train_out_initialized = True
        log = self._log
        log = {key: log[key] for key in log if "train_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update(
            {'epoch': float(self.trainer.current_epoch), 'step': float(self.trainer.global_step), 'iter_step': float(self.iter_step)})

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

        # Clear accumulated train_* entries so the next epoch starts fresh.
        for key in list(log.keys()):
            if "train_" in key:
                del self._log[key]

    def lg(self, key, data):
        """Append a metric value under 'iter_<key>' (train/validate-only) and '<stage>_<key>'."""
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        log = self._log
        if self.args.validate or self.stage == 'train':
            log["iter_" + key].append(data)
        log[self.stage + "_" + key].append(data)

    def configure_optimizers(self):
        """Plain Adam over all parameters with the configured learning rate."""
        optimizer = optim.Adam(self.parameters(), lr=self.args.lr)
        return optimizer

    def plot_empirical_and_true(self, empirical_dist, true_dist):
        """Plot up to four per-position empirical bar charts against true densities.

        Args:
            empirical_dist: array-like (num_positions, num_categories).
            true_dist: matching array of true densities.

        Returns:
            PIL.Image with the rendered 2-column grid of subplots.
        """
        num_datasets_to_plot = min(4, empirical_dist.shape[0])
        width = 1

        # squeeze=False fixes an IndexError: with <= 2 datasets subplots()
        # would return a 1-D axes array and axes[row, col] would fail.
        fig, axes = plt.subplots(math.ceil(num_datasets_to_plot / 2), 2, figsize=(10, 8), squeeze=False)
        for i in range(num_datasets_to_plot):
            row, col = i // 2, i % 2
            x = np.arange(len(empirical_dist[i]))
            axes[row, col].bar(x, empirical_dist[i], width, label=f'empirical')
            axes[row, col].plot(x, true_dist[i], label=f'true density', color='orange')
            axes[row, col].legend()
            axes[row, col].set_title(f'Sequence position {i + 1}')
            axes[row, col].set_xlabel('Category')
            axes[row, col].set_ylabel('Density')
        plt.tight_layout()
        fig.canvas.draw()
        pil_img = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        plt.close()
        return pil_img

    def load_model(self, alphabet_size, num_cls):
        """Instantiate self.model according to args.model.

        Raises:
            NotImplementedError: for an unrecognized model name.
        """
        if self.args.model == 'cnn':
            self.model = CNNModel(self.args, alphabet_size=alphabet_size)
        elif self.args.model == 'mlp':
            self.model = MLPModel(input_dim=alphabet_size, time_dim=1, hidden_dim=self.args.hidden_dim, length=self.args.length)
        elif self.args.model == 'transformer':
            self.model = TransformerModel(alphabet_size=alphabet_size, seq_length=self.args.length, embed_dim=self.args.hidden_dim,
                                          num_layers=self.args.num_layers, num_heads=self.args.num_heads, dropout=self.args.dropout)
        elif self.args.model == 'deepflybrain':
            self.model = DeepFlyBrainModel(self.args, alphabet_size=alphabet_size, num_cls=num_cls)
        else:
            raise NotImplementedError()

    def plot_score_and_probs(self):
        """Plot binned (by alpha) true-class probabilities and score norms.

        Reads the ``*_noisycls`` entries accumulated in self.val_outputs.

        Returns:
            ``(pil_probs, pil_score_norms)``: two PIL images — mean predicted
            probability of the true class vs. binned alphas, and mean score
            norm vs. binned alphas, each with asymmetric std shading.
        """
        clss = torch.cat(self.val_outputs['clss_noisycls'])
        probs = torch.softmax(torch.cat(self.val_outputs['logits_noisycls']), dim=-1)
        scores = torch.cat(self.val_outputs['scores_noisycls']).cpu().numpy()
        score_norms = np.linalg.norm(scores, axis=-1)
        alphas = torch.cat(self.val_outputs['alphas_noisycls']).cpu().numpy()
        # Probability assigned to the true class of each sample.
        true_probs = probs[torch.arange(len(probs)), clss].cpu().numpy()
        bins = np.linspace(min(alphas), 12, 20)
        indices = np.digitize(alphas, bins)
        bin_means = [np.mean(true_probs[indices == i]) for i in range(1, len(bins))]
        bin_std = [np.std(true_probs[indices == i]) for i in range(1, len(bins))]
        bin_centers = 0.5 * (bins[:-1] + bins[1:])

        # Asymmetric spread: std of values above / below each bin's mean.
        bin_pos_std = [np.std(true_probs[indices == i][true_probs[indices == i] > np.mean(true_probs[indices == i])]) for i in range(1, len(bins))]
        bin_neg_std = [np.std(true_probs[indices == i][true_probs[indices == i] < np.mean(true_probs[indices == i])]) for i in range(1, len(bins))]
        plot_data = pd.DataFrame({'Alphas': bin_centers, 'Means': bin_means, 'Std': bin_std, 'Pos_Std': bin_pos_std, 'Neg_Std': bin_neg_std})
        plt.figure(figsize=(10, 6))
        sns.lineplot(x='Alphas', y='Means', data=plot_data)
        plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'], plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3)
        plt.xlabel('Binned alphas values')
        plt.ylabel('Mean of predicted probs for true class')
        fig = plt.gcf()
        fig.canvas.draw()
        pil_probs = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())

        plt.close()
        # Same binning, now for the norms of the scores.
        bin_means = [np.mean(score_norms[indices == i]) for i in range(1, len(bins))]
        bin_std = [np.std(score_norms[indices == i]) for i in range(1, len(bins))]
        bin_pos_std = [np.std(score_norms[indices == i][score_norms[indices == i] > np.mean(score_norms[indices == i])]) for i in range(1, len(bins))]
        bin_neg_std = [np.std(score_norms[indices == i][score_norms[indices == i] < np.mean(score_norms[indices == i])]) for i in range(1, len(bins))]
        plot_data = pd.DataFrame({'Alphas': bin_centers, 'Means': bin_means, 'Std': bin_std, 'Pos_Std': bin_pos_std, 'Neg_Std': bin_neg_std})
        plt.figure(figsize=(10, 6))
        sns.lineplot(x='Alphas', y='Means', data=plot_data)
        plt.fill_between(plot_data['Alphas'], plot_data['Means'] - plot_data['Neg_Std'],
                         plot_data['Means'] + plot_data['Pos_Std'], alpha=0.3)
        plt.xlabel('Binned alphas values')
        plt.ylabel('Mean of norm of the scores')
        fig = plt.gcf()
        fig.canvas.draw()
        pil_score_norms = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        return pil_probs, pil_score_norms

    def log_data_similarities(self, seq_pred):
        """Log max/mean positional-match similarity of predictions to toy data.

        NOTE(review): self.toy_data is never set in this class — presumably
        provided by a subclass or GeneralModule; confirm before relying on it.

        Args:
            seq_pred: (B, L) predicted token tensor.
        """
        similarities1 = seq_pred.cpu()[:, None, :].eq(self.toy_data.data_class1[None, :, :])
        similarities2 = seq_pred.cpu()[:, None, :].eq(self.toy_data.data_class2[None, :, :])
        similarities = seq_pred.cpu()[:, None, :].eq(torch.cat([self.toy_data.data_class2[None, :, :], self.toy_data.data_class1[None, :, :]], dim=1))
        self.lg('data1_sim', similarities1.float().mean(-1).max(-1)[0])
        self.lg('data2_sim', similarities2.float().mean(-1).max(-1)[0])
        self.lg('data_sim', similarities.float().mean(-1).max(-1)[0])
        self.lg('mean_data1_sim', similarities1.float().mean(-1).mean(-1))
        self.lg('mean_data2_sim', similarities2.float().mean(-1).mean(-1))
        self.lg('mean_data_sim', similarities.float().mean(-1).mean(-1))