| | import os |
| | import sys |
| | import torch |
| | import wandb |
| | import matplotlib.pyplot as plt |
| | import pytorch_lightning as pl |
| | from torch.optim import AdamW |
| | from torchmetrics.functional import mean_squared_error |
| | from torchdyn.core import NeuralODE |
| | import numpy as np |
| | import lpips |
| | from .networks.utils import flow_model_torch_wrapper |
| | from .utils import plot_lidar |
| | from .ema import EMA |
| | from torchdiffeq import odeint as odeint2 |
| | from .losses.energy_loss import EnergySolver, ReconsLoss |
| |
|
class GrowthNetTrain(pl.LightningModule):
    """Trains per-branch growth networks on top of (optionally frozen) flow
    networks, using an energy-regularised ODE rollout per branch.
    """

    def __init__(
        self,
        flow_nets,
        growth_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
        state_cost=None,
        data_manifold_metric=None,
        joint = False
    ):
        """Store networks, hyper-parameters and loss weights.

        Args:
            flow_nets: container (one per branch) of flow/velocity networks.
            growth_nets: container (one per branch) of growth-rate networks;
                its length defines the number of branches.
            skipped_time_points: time points excluded from training, if any.
            ot_sampler: optional optimal-transport sampler for pairing samples.
            args: argparse namespace providing optimiser and loss settings.
            state_cost: optional state-cost term passed to the energy solver.
            data_manifold_metric: optional data-manifold metric (used when
                ``args.manifold`` is set).
            joint: when True, flow nets are trained jointly with growth nets;
                otherwise flow-net parameters are frozen.
        """
        super().__init__()

        self.flow_nets = flow_nets

        # In two-stage training the flow nets are pre-trained; freeze them
        # unless we are doing joint optimisation.
        if not joint:
            for param in self.flow_nets.parameters():
                param.requires_grad = False

        self.growth_nets = growth_nets

        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        # Optimiser hyper-parameters from the argparse namespace.
        self.optimizer_name = args.growth_optimizer
        self.lr = args.growth_lr
        self.weight_decay = args.growth_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        self.args = args

        self.state_cost = state_cost
        self.data_manifold_metric = data_manifold_metric
        self.branches = len(growth_nets)
        self.metric_clusters = args.metric_clusters

        self.recons_loss = ReconsLoss()

        # Relative weights for the four loss components.
        self.lambda_energy = args.lambda_energy
        self.lambda_mass = args.lambda_mass
        self.lambda_match = args.lambda_match
        self.lambda_recons = args.lambda_recons

        self.joint = joint
|
| | def forward(self, t, xt, branch_idx): |
| | |
| | return self.growth_nets[branch_idx](t, xt) |
| |
|
    def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
        """Roll the coupled (position, weight, energy) state forward over
        ``self.timesteps`` for every branch and assemble the training loss.

        Args:
            main_batch: dict with ``"x0"`` -> (positions, weights) and either
                ``"x1"`` (single branch) or ``"x1_{k}"`` (multi-branch)
                endpoint samples and weights.
            metric_samples_batch: per-cluster sample tensors for the data
                manifold metric; only read when ``self.args.manifold`` is set.
            validation: selects the ``val_*`` metric names when logging.

        Returns:
            Scalar tensor: lambda-weighted sum of energy, mass (incl. a
            negative-weight penalty), match and reconstruction losses.
        """
        x0s = main_batch["x0"][0]
        w0s = main_batch["x0"][1]
        x1s_list = []
        w1s_list = []

        # Gather per-branch endpoint samples/weights.
        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
                w1s_list.append([main_batch[f"x1_{i+1}"][1]])
        else:
            x1s_list.append([main_batch["x1"][0]])
            w1s_list.append([main_batch["x1"][1]])

        if self.args.manifold:
            # Pair (source cluster, target cluster) metric samples per branch.
            # The pairing layout depends on the dataset's cluster count.
            if self.metric_clusters == 7 and self.branches == 6:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                    (metric_samples_batch[0], metric_samples_batch[3]),
                    (metric_samples_batch[0], metric_samples_batch[4]),
                    (metric_samples_batch[0], metric_samples_batch[5]),
                    (metric_samples_batch[0], metric_samples_batch[6]),
                ]
            elif self.metric_clusters == 4:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                    (metric_samples_batch[0], metric_samples_batch[3]),
                ]
            elif self.metric_clusters == 3:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[2]),
                ]
            elif self.metric_clusters == 2 and self.branches == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                    (metric_samples_batch[0], metric_samples_batch[1]),
                ]
            elif self.metric_clusters == 2:
                # Two clusters but an arbitrary number of branches: every
                # branch shares the same cluster pair.
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1])
                ] * self.branches
            else:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),
                ]

        batch_size = x0s.shape[0]  # NOTE(review): currently unused below.

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        # Per-branch accumulators; mass terms are shared across branches.
        energy_loss = [0.] * self.branches
        mass_loss = 0.
        neg_weight_penalty = 0.
        match_loss = [0.] * self.branches
        recons_loss = [0.] * self.branches

        dtype = x0s[0].dtype

        # Initial state: zero accumulated energy; all mass starts on branch 0.
        m0s = torch.zeros_like(w0s, dtype=dtype)
        start_state = (x0s, w0s, m0s)

        xt = [x0s.clone() for _ in range(self.branches)]
        w0_branch = torch.zeros_like(w0s, dtype=dtype)
        w0_branches = []
        w0_branches.append(w0s)
        for _ in range(self.branches - 1):
            w0_branches.append(w0_branch)

        wt = w0_branches

        mt = [m0s.clone() for _ in range(self.branches)]

        # March each branch one time interval at a time.
        for step_idx, (s, t) in enumerate(zip(self.timesteps[:-1], self.timesteps[1:])):
            time = torch.Tensor([s, t])

            total_w_t = 0

            for i in range(self.branches):
                # Metric samples conditioning the energy solver, if any.
                if self.args.manifold:
                    start_samples, end_samples = branch_sample_pairs[i]
                    samples = torch.cat([start_samples, end_samples], dim=0)
                else:
                    samples = None

                start_state = (xt[i], wt[i], mt[i])

                xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx)

                # Keep only the state at the end of the interval.
                xt_last = xt_next[-1]
                wt_last = wt_next[-1]
                mt_last = mt_next[-1]

                total_w_t += wt_last

                # Energy accrued over this interval, plus a penalty that
                # discourages negative particle weights.
                energy_loss[i] += (mt_last - mt[i])
                neg_weight_penalty += torch.relu(-wt_last).sum()

                # Detach so each interval is optimised independently.
                xt[i] = xt_last.clone().detach()
                wt[i] = wt_last.clone().detach()
                mt[i] = mt_last.clone().detach()

            # Total mass across branches should stay normalised to one.
            target = torch.ones_like(total_w_t)
            mass_loss += mean_squared_error(total_w_t, target)

        # Endpoint losses: match final weights and reconstruct final positions.
        for i in range(self.branches):
            match_loss[i] = mean_squared_error(wt[i], w1s_list[i][0])

            recons_loss[i] = self.recons_loss(xt[i], x1s_list[i][0])

        # Average the mass loss over the number of intervals.
        mass_loss = mass_loss / max(len(self.timesteps) - 1, 1)

        # Re-weight branches inversely to their cluster size (when the
        # datamodule exposes it) so small clusters are not dominated.
        if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'):
            cluster_sizes = self.trainer.datamodule.cluster_sizes
            max_size = max(cluster_sizes)

            branch_weights = torch.tensor([max_size / size for size in cluster_sizes],
                dtype=energy_loss[0].dtype, device=energy_loss[0].device)

            # Normalise so the weights sum to the number of branches.
            branch_weights = branch_weights * self.branches / branch_weights.sum()

            energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss]) * branch_weights)
            match_loss = torch.mean(torch.stack(match_loss) * branch_weights)
            recons_loss = torch.mean(torch.stack(recons_loss) * branch_weights)
        else:
            # Unweighted averages across branches.
            energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss]))
            match_loss = torch.mean(torch.stack(match_loss))
            recons_loss = torch.mean(torch.stack(recons_loss))

        loss = (self.lambda_energy * energy_loss) + (self.lambda_mass * (mass_loss + neg_weight_penalty)) + (self.lambda_match * match_loss) \
            + (self.lambda_recons * recons_loss)

        # Log individual components under the matching experiment prefix.
        if self.joint:
            if validation:
                self.log("JointTrain/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
            else:
                self.log("JointTrain/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
        else:
            if validation:
                self.log("GrowthNet/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
            else:
                self.log("GrowthNet/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss
| | |
| | def take_step(self, t, start_state, branch_idx, samples=None, timestep_idx=0): |
| | |
| | flow_net = self.flow_nets[branch_idx] |
| | growth_net = self.growth_nets[branch_idx] |
| | |
| | |
| | x_t, w_t, m_t = odeint2(EnergySolver(flow_net, growth_net, self.state_cost, self.data_manifold_metric, samples, timestep_idx), start_state, t, options=dict(step_size=0.1),method='euler') |
| | |
| | return x_t, w_t, m_t |
| | |
| | def training_step(self, batch, batch_idx): |
| | if isinstance(batch, (list, tuple)): |
| | batch = batch[0] |
| | if isinstance(batch, dict) and "train_samples" in batch: |
| | main_batch = batch["train_samples"] |
| | metric_batch = batch["metric_samples"] |
| | if isinstance(main_batch, tuple): |
| | main_batch = main_batch[0] |
| | if isinstance(metric_batch, tuple): |
| | metric_batch = metric_batch[0] |
| | else: |
| | |
| | main_batch = batch.get("train_samples", batch) |
| | metric_batch = batch.get("metric_samples", []) |
| | |
| | self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() |
| | loss = self._compute_loss(main_batch, metric_batch, validation=False) |
| | |
| | if self.joint: |
| | self.log( |
| | "JointTrain/train_loss", |
| | loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | else: |
| | self.log( |
| | "GrowthNet/train_loss", |
| | loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | |
| | return loss |
| |
|
| | def validation_step(self, batch, batch_idx): |
| | if isinstance(batch, (list, tuple)): |
| | batch = batch[0] |
| | if isinstance(batch, dict) and "val_samples" in batch: |
| | main_batch = batch["val_samples"] |
| | metric_batch = batch["metric_samples"] |
| | if isinstance(main_batch, tuple): |
| | main_batch = main_batch[0] |
| | if isinstance(metric_batch, tuple): |
| | metric_batch = metric_batch[0] |
| | else: |
| | |
| | main_batch = batch.get("val_samples", batch) |
| | metric_batch = batch.get("metric_samples", []) |
| |
|
| | self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() |
| | val_loss = self._compute_loss(main_batch, metric_batch, validation=True) |
| | |
| | if self.joint: |
| | self.log( |
| | "JointTrain/val_loss", |
| | val_loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | else: |
| | self.log( |
| | "GrowthNet/val_loss", |
| | val_loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | return val_loss |
| |
|
| | def optimizer_step(self, *args, **kwargs): |
| | super().optimizer_step(*args, **kwargs) |
| | for net in self.growth_nets: |
| | if isinstance(net, EMA): |
| | net.update_ema() |
| | if self.joint: |
| | for net in self.flow_nets: |
| | if isinstance(net, EMA): |
| | net.update_ema() |
| |
|
| | def configure_optimizers(self): |
| | params = [] |
| | for net in self.growth_nets: |
| | params += list(net.parameters()) |
| | |
| | if self.joint: |
| | for net in self.flow_nets: |
| | params += list(net.parameters()) |
| | |
| | if self.optimizer_name == "adamw": |
| | optimizer = AdamW( |
| | params, |
| | lr=self.lr, |
| | weight_decay=self.weight_decay, |
| | ) |
| | elif self.optimizer_name == "adam": |
| | optimizer = torch.optim.Adam( |
| | params, |
| | lr=self.lr, |
| | ) |
| |
|
| | return optimizer |
| | |
| | @torch.no_grad() |
| | def get_mass_and_position(self, main_batch, metric_samples_batch=None): |
| | if isinstance(main_batch, dict): |
| | main_batch = main_batch |
| | else: |
| | main_batch = main_batch[0] |
| | |
| | x0s = main_batch["x0"][0] |
| | w0s = main_batch["x0"][1] |
| |
|
| | if self.args.manifold: |
| | if self.metric_clusters == 4: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | (metric_samples_batch[0], metric_samples_batch[2]), |
| | (metric_samples_batch[0], metric_samples_batch[3]), |
| | ] |
| | elif self.metric_clusters == 3: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | (metric_samples_batch[0], metric_samples_batch[2]), |
| | ] |
| | elif self.metric_clusters == 2 and self.branches == 2: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | ] |
| | elif self.metric_clusters == 2: |
| | |
| | |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]) |
| | ] * self.branches |
| | else: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | ] |
| |
|
| | batch_size = x0s.shape[0] |
| | dtype = x0s[0].dtype |
| |
|
| | m0s = torch.zeros_like(w0s, dtype=dtype) |
| | xt = [x0s.clone() for _ in range(self.branches)] |
| |
|
| | w0_branch = torch.zeros_like(w0s, dtype=dtype) |
| | w0_branches = [] |
| | w0_branches.append(w0s) |
| | for _ in range(self.branches - 1): |
| | w0_branches.append(w0_branch) |
| |
|
| | wt = w0_branches |
| | mt = [m0s.clone() for _ in range(self.branches)] |
| |
|
| | time_points = [] |
| | mass_over_time = [[] for _ in range(self.branches)] |
| | energy_over_time = [[] for _ in range(self.branches)] |
| | |
| | weights_over_time = [[] for _ in range(self.branches)] |
| | all_trajs = [[] for _ in range(self.branches)] |
| |
|
| | t_span = torch.linspace(0, 1, 101) |
| | for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])): |
| | time_points.append(t.item()) |
| | time = torch.Tensor([s, t]) |
| |
|
| | for i in range(self.branches): |
| | if self.args.manifold: |
| | start_samples, end_samples = branch_sample_pairs[i] |
| | samples = torch.cat([start_samples, end_samples], dim=0) |
| | else: |
| | samples = None |
| |
|
| | start_state = (xt[i], wt[i], mt[i]) |
| | xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx) |
| |
|
| | xt[i] = xt_next[-1].clone().detach() |
| | wt[i] = wt_next[-1].clone().detach() |
| | mt[i] = mt_next[-1].clone().detach() |
| |
|
| | all_trajs[i].append(xt[i].clone().detach()) |
| | mass_over_time[i].append(wt[i].mean().item()) |
| | energy_over_time[i].append(mt[i].mean().item()) |
| | |
| | try: |
| | weights_over_time[i].append(wt[i].clone().detach()) |
| | except Exception: |
| | |
| | weights_over_time[i].append(torch.tensor(wt[i].mean().item()).unsqueeze(0)) |
| | |
| | return time_points, xt, all_trajs, mass_over_time, energy_over_time, weights_over_time |
| |
|
| | @torch.no_grad() |
| | def _plot_mass_and_energy(self, main_batch, metric_samples_batch=None, save_dir=None): |
| | x0s = main_batch["x0"][0] |
| | w0s = main_batch["x0"][1] |
| |
|
| | if self.args.manifold: |
| | if self.metric_clusters == 7 and self.branches == 6: |
| | |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | (metric_samples_batch[0], metric_samples_batch[2]), |
| | (metric_samples_batch[0], metric_samples_batch[3]), |
| | (metric_samples_batch[0], metric_samples_batch[4]), |
| | (metric_samples_batch[0], metric_samples_batch[5]), |
| | (metric_samples_batch[0], metric_samples_batch[6]), |
| | ] |
| | elif self.metric_clusters == 4: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | (metric_samples_batch[0], metric_samples_batch[2]), |
| | (metric_samples_batch[0], metric_samples_batch[3]), |
| | ] |
| | elif self.metric_clusters == 3: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | (metric_samples_batch[0], metric_samples_batch[2]), |
| | ] |
| | elif self.metric_clusters == 2 and self.branches == 2: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | ] |
| | else: |
| | branch_sample_pairs = [ |
| | (metric_samples_batch[0], metric_samples_batch[1]), |
| | ] |
| |
|
| | batch_size = x0s.shape[0] |
| | dtype = x0s[0].dtype |
| |
|
| | m0s = torch.zeros_like(w0s, dtype=dtype) |
| | xt = [x0s.clone() for _ in range(self.branches)] |
| |
|
| | w0_branch = torch.zeros_like(w0s, dtype=dtype) |
| | w0_branches = [] |
| | w0_branches.append(w0s) |
| | for _ in range(self.branches - 1): |
| | w0_branches.append(w0_branch) |
| |
|
| | wt = w0_branches |
| | mt = [m0s.clone() for _ in range(self.branches)] |
| |
|
| | time_points = [] |
| | mass_over_time = [[] for _ in range(self.branches)] |
| | energy_over_time = [[] for _ in range(self.branches)] |
| |
|
| | t_span = torch.linspace(0, 1, 101) |
| | for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])): |
| | time_points.append(t.item()) |
| | time = torch.Tensor([s, t]) |
| |
|
| | for i in range(self.branches): |
| | if self.args.manifold: |
| | start_samples, end_samples = branch_sample_pairs[i] |
| | samples = torch.cat([start_samples, end_samples], dim=0) |
| | else: |
| | samples = None |
| |
|
| | start_state = (xt[i], wt[i], mt[i]) |
| | xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx) |
| |
|
| | xt[i] = xt_next[-1].clone().detach() |
| | wt[i] = wt_next[-1].clone().detach() |
| | mt[i] = mt_next[-1].clone().detach() |
| |
|
| | mass_over_time[i].append(wt[i].mean().item()) |
| | energy_over_time[i].append(mt[i].mean().item()) |
| |
|
| | if save_dir is None: |
| | run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name |
| | save_dir = os.path.join(self.args.working_dir, 'results', run_name, 'figures') |
| | os.makedirs(save_dir, exist_ok=True) |
| |
|
| | |
| | if self.args.branches == 3: |
| | branch_colors = ['#9793F8', '#50B2D7', '#D577FF'] |
| | else: |
| | branch_colors = ['#50B2D7', '#D577FF'] |
| |
|
| | |
| | plt.figure(figsize=(8, 5)) |
| | for i in range(self.branches): |
| | color = branch_colors[i] |
| | plt.plot(time_points, mass_over_time[i], color=color, linewidth=2.5, label=f"Mass Branch {i}") |
| | plt.xlabel("Time") |
| | plt.ylabel("Mass") |
| | plt.title("Mass Evolution per Branch") |
| | plt.legend() |
| | plt.grid(True) |
| | if self.joint: |
| | mass_path = os.path.join(save_dir, f"{self.args.data_name}_joint_mass.png") |
| | else: |
| | mass_path = os.path.join(save_dir, f"{self.args.data_name}_growth_mass.png") |
| | plt.savefig(mass_path, dpi=300, bbox_inches="tight") |
| | plt.close() |
| |
|
| | |
| | plt.figure(figsize=(8, 5)) |
| | for i in range(self.branches): |
| | color = branch_colors[i] |
| | plt.plot(time_points, energy_over_time[i], color=color, linewidth=2.5, label=f"Energy Branch {i}") |
| | plt.xlabel("Time") |
| | plt.ylabel("Energy") |
| | plt.title("Energy Evolution per Branch") |
| | plt.legend() |
| | plt.grid(True) |
| | if self.joint: |
| | energy_path = os.path.join(save_dir, f"{self.args.data_name}_joint_energy.png") |
| | else: |
| | energy_path = os.path.join(save_dir, f"{self.args.data_name}_growth_energy.png") |
| | plt.savefig(energy_path, dpi=300, bbox_inches="tight") |
| | plt.close() |
| | |
| | |
class GrowthNetTrainLidar(GrowthNetTrain):
    """Lidar variant: testing additionally renders 3D branch trajectories
    over the point cloud."""

    def test_step(self, batch, batch_idx):
        """Plot mass/energy curves and lidar trajectory figures for all branches."""
        # Batch layout differs between dict- and tuple-style dataloaders.
        if isinstance(batch, dict):
            main_batch = batch["test_samples"][0]
            metric_batch = batch["metric_samples"][0]
        else:
            main_batch = batch[0][0]
            metric_batch = batch[1][0]

        self._plot_mass_and_energy(main_batch, metric_batch)

        x0 = main_batch["x0"][0]
        cloud_points = main_batch["dataset"][0]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []

        # Integrate each branch's flow field from the shared start points.
        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()

            # Undo dataset whitening so trajectories are plotted in the
            # original coordinate frame (assumes 3-D points — the reshape
            # to (-1, 3) hard-codes this).
            if self.whiten:
                traj_shape = traj.shape
                traj = traj.reshape(-1, 3)
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            # (time, batch, dim) -> (batch, time, dim) for plotting.
            traj = torch.tensor(traj)
            traj = torch.transpose(traj, 0, 1)
            all_trajs.append(traj)

        # De-whiten the background point cloud as well.
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # Combined figure with every branch overlaid.
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        lidar_fig_dir = os.path.join(results_dir, 'figures')
        os.makedirs(lidar_fig_dir, exist_ok=True)
        if self.joint:
            plt.savefig(os.path.join(lidar_fig_dir, 'joint_lidar_all_branches.png'), dpi=300)
        else:
            plt.savefig(os.path.join(lidar_fig_dir, 'growth_lidar_all_branches.png'), dpi=300)
        plt.close()

        # One figure per branch.
        for i, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            if self.joint:
                plt.savefig(os.path.join(lidar_fig_dir, f'joint_lidar_branch_{i + 1}.png'), dpi=300)
            else:
                plt.savefig(os.path.join(lidar_fig_dir, f'growth_lidar_branch_{i + 1}.png'), dpi=300)
            plt.close()
| | |
class GrowthNetTrainCell(GrowthNetTrain):
    """Cell-data variant: testing only produces mass/energy diagnostics."""

    def test_step(self, batch, batch_idx):
        """Unpack the test batch (layout differs per dataset) and plot diagnostics."""
        nested = self.args.data_type in ["scrna", "tahoe"]
        source = batch[0] if nested else batch
        main_batch = source["test_samples"][0]
        metric_batch = source["metric_samples"][0]
        self._plot_mass_and_energy(main_batch, metric_batch)
| |
|
| |
|
class SequentialGrowthNetTrain(pl.LightningModule):
    """
    Sequential growth network training for multi-timepoint data.
    Learns growth rates for transitions between consecutive timepoints.
    """
    def __init__(
        self,
        flow_nets,
        growth_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
        data_manifold_metric=None,
        joint=False
    ):
        """Store networks, hyper-parameters and loss weights.

        Args:
            flow_nets: container (one per branch) of flow/velocity networks.
            growth_nets: container (one per branch) of growth-rate networks;
                its length defines the number of branches.
            skipped_time_points: time points excluded from training, if any.
            ot_sampler: optional optimal-transport sampler for pairing samples.
            args: argparse namespace providing optimiser and loss settings.
            data_manifold_metric: optional data-manifold metric (used when
                ``args.manifold`` is set).
            joint: when True, flow nets are trained jointly; otherwise their
                parameters are frozen.
        """
        super().__init__()
        self.flow_nets = flow_nets

        # Freeze flow nets unless training jointly.
        if not joint:
            for param in self.flow_nets.parameters():
                param.requires_grad = False

        self.growth_nets = growth_nets
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        # Optimiser hyper-parameters from the argparse namespace.
        self.optimizer_name = args.growth_optimizer
        self.lr = args.growth_lr
        self.weight_decay = args.growth_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        self.args = args
        self.data_manifold_metric = data_manifold_metric
        self.branches = len(growth_nets)
        self.metric_clusters = args.metric_clusters

        self.recons_loss = ReconsLoss()

        # Relative weights for the four loss components.
        self.lambda_energy = args.lambda_energy
        self.lambda_mass = args.lambda_mass
        self.lambda_match = args.lambda_match
        self.lambda_recons = args.lambda_recons

        self.joint = joint
        # Populated lazily in setup() from the datamodule.
        self.num_timepoints = None
        self.timepoint_keys = None
|
| | def forward(self, t, xt, branch_idx): |
| | return self.growth_nets[branch_idx](t, xt) |
| |
|
| | def setup(self, stage=None): |
| | """Initialize timepoint keys before training/validation starts.""" |
| | if self.timepoint_keys is None: |
| | timepoint_data = self.trainer.datamodule.get_timepoint_data() |
| | self.timepoint_keys = [k for k in sorted(timepoint_data.keys()) |
| | if not any(x in k for x in ['_', 'time_labels'])] |
| | self.num_timepoints = len(self.timepoint_keys) |
| | print(f"Training sequential growth for {self.num_timepoints} timepoints: {self.timepoint_keys}") |
| |
|
    def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
        """Compute loss for sequential growth between timepoints.

        Iterates over consecutive timepoint pairs; intermediate transitions
        are scored per branch with unit weight, while the final transition
        (into the per-branch endpoints ``x1_{b}``) is re-weighted inversely
        to cluster size when the datamodule exposes ``cluster_sizes``.

        Returns:
            The summed (not averaged) total loss across all transitions;
            note the logged component metrics are per-transition averages.
        """
        x0s = main_batch["x0"][0]
        w0s = main_batch["x0"][1]
        # NOTE(review): x0s/w0s are read but not used below — the per-
        # transition states come from the timepoint keys instead; confirm.

        if self.args.manifold:
            # Pair (source, target) metric samples per branch; fall back to
            # cluster 1 when there are fewer clusters than branches.
            if self.metric_clusters == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1])
                ] * self.branches
            else:
                branch_sample_pairs = []
                for b in range(self.branches):
                    if b + 1 < len(metric_samples_batch):
                        branch_sample_pairs.append(
                            (metric_samples_batch[0], metric_samples_batch[b + 1])
                        )
                    else:
                        branch_sample_pairs.append(
                            (metric_samples_batch[0], metric_samples_batch[1])
                        )

        total_loss = 0
        total_energy_loss = 0
        total_mass_loss = 0
        total_match_loss = 0
        total_recons_loss = 0
        num_transitions = 0

        # One transition per consecutive pair of timepoint keys.
        for i in range(len(self.timepoint_keys) - 1):
            t_curr_key = self.timepoint_keys[i]
            t_next_key = self.timepoint_keys[i + 1]

            # Map a timepoint key like "t0" -> batch key "x0", and
            # "tfinal" -> "x1".
            batch_curr_key = f"x{t_curr_key.replace('t', '').replace('final', '1')}"
            x_curr = main_batch[batch_curr_key][0]
            w_curr = main_batch[batch_curr_key][1]

            if i == len(self.timepoint_keys) - 2:
                # Final transition: branch-specific endpoints, re-weighted by
                # inverse cluster size when available.
                if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'):
                    cluster_sizes = self.trainer.datamodule.cluster_sizes
                    max_size = max(cluster_sizes)

                    branch_weights = [max_size / size for size in cluster_sizes]
                else:
                    branch_weights = [1.0] * self.branches

                for b in range(self.branches):
                    x_next = main_batch[f"x1_{b+1}"][0]
                    w_next = main_batch[f"x1_{b+1}"][1]

                    loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss(
                        x_curr, w_curr, x_next, w_next, b, i,
                        branch_sample_pairs[b] if self.args.manifold else None
                    )

                    total_loss += loss * branch_weights[b]
                    total_energy_loss += energy_l * branch_weights[b]
                    total_mass_loss += mass_l * branch_weights[b]
                    total_match_loss += match_l * branch_weights[b]
                    total_recons_loss += recons_l * branch_weights[b]
                    num_transitions += 1
            else:
                # Intermediate transition: all branches target the same next
                # timepoint, unweighted.
                batch_next_key = f"x{t_next_key.replace('t', '').replace('final', '1')}"
                x_next = main_batch[batch_next_key][0]
                w_next = main_batch[batch_next_key][1]

                for b in range(self.branches):
                    loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss(
                        x_curr, w_curr, x_next, w_next, b, i,
                        branch_sample_pairs[b] if self.args.manifold else None
                    )
                    total_loss += loss
                    total_energy_loss += energy_l
                    total_mass_loss += mass_l
                    total_match_loss += match_l
                    total_recons_loss += recons_l
                    num_transitions += 1

        # Per-transition averages for logging only.
        avg_energy_loss = total_energy_loss / num_transitions if num_transitions > 0 else total_energy_loss
        avg_mass_loss = total_mass_loss / num_transitions if num_transitions > 0 else total_mass_loss
        avg_match_loss = total_match_loss / num_transitions if num_transitions > 0 else total_match_loss
        avg_recons_loss = total_recons_loss / num_transitions if num_transitions > 0 else total_recons_loss

        # Log components under the matching experiment prefix.
        if self.joint:
            if validation:
                self.log("JointTrain/val_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)
            else:
                self.log("JointTrain/train_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)
        else:
            if validation:
                self.log("GrowthNet/val_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)
            else:
                self.log("GrowthNet/train_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True)

        return total_loss
| |
|
    def _compute_transition_loss(self, x0, w0, x1, w1, branch_idx, transition_idx, metric_pair):
        """Compute loss for a single timepoint transition.

        Rolls the flow forward from ``x0`` over 10 time samples (without
        gradients), evaluates the branch's growth net along the trajectory,
        and scores energy, mass conservation, weight matching and
        reconstruction.

        Returns:
            Tuple ``(total_loss, energy_loss, mass_loss + neg_weight_penalty,
            match_loss, recons_loss)``.
        """
        # Optional optimal-transport re-pairing of start/end samples.
        if self.ot_sampler is not None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1, replace=True)

        t_span = torch.linspace(0, 1, 10, device=x0.device)

        flow_model = flow_model_torch_wrapper(self.flow_nets[branch_idx])
        node = NeuralODE(flow_model, solver="euler", sensitivity="adjoint")

        # The trajectory carries no gradients; only the growth-net
        # evaluations along it are trained here.
        with torch.no_grad():
            traj = node.trajectory(x0, t_span)

        energy_loss = 0
        mass_loss = 0
        neg_weight_penalty = 0

        for t_idx in range(len(t_span)):
            t = t_span[t_idx]
            xt = traj[t_idx]

            # Per-sample growth rate at time t (t broadcast over the batch).
            growth = self.growth_nets[branch_idx](t.unsqueeze(0).expand(xt.shape[0]), xt)

            # Energy: manifold metric when available, else squared growth.
            if self.args.manifold and metric_pair is not None:
                start_samples, end_samples = metric_pair
                samples = torch.cat([start_samples, end_samples], dim=0)
                _, kinetic, potential = self.data_manifold_metric.calculate_velocity(
                    xt, torch.zeros_like(xt), samples, transition_idx
                )
                energy = kinetic + potential
            else:
                energy = (growth ** 2).sum(dim=-1)

            energy_loss += energy.mean()

            # Exponential mass update driven by the growth rate; penalise
            # total-mass drift against the target and any negative weights.
            growth_sum = growth.sum(dim=-1, keepdim=True)
            wt = w0 * torch.exp(growth_sum)
            mass = wt.sum()
            mass_loss += (mass - w1.sum()).abs()
            neg_weight_penalty += torch.relu(-wt).sum()

        # Endpoint terms. NOTE(review): ``wt`` here is the value from the
        # final loop iteration (t = 1), so match_loss scores the final-time
        # weights only — confirm this is intended.
        xt_final = traj[-1]
        match_loss = mean_squared_error(wt, w1)
        recons_loss = self.recons_loss(xt_final, x1)

        total_loss = (
            self.lambda_energy * energy_loss +
            self.lambda_mass * (mass_loss + neg_weight_penalty) +
            self.lambda_match * match_loss +
            self.lambda_recons * recons_loss
        )

        return total_loss, energy_loss, mass_loss + neg_weight_penalty, match_loss, recons_loss
| |
|
| | def training_step(self, batch, batch_idx): |
| | if isinstance(batch, (list, tuple)): |
| | batch = batch[0] |
| | main_batch = batch["train_samples"] |
| | metric_batch = batch["metric_samples"] |
| | if isinstance(main_batch, tuple): |
| | main_batch = main_batch[0] |
| | if isinstance(metric_batch, tuple): |
| | metric_batch = metric_batch[0] |
| | |
| | loss = self._compute_loss(main_batch, metric_batch) |
| | |
| | if self.joint: |
| | self.log( |
| | "JointTrain/train_loss", |
| | loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | else: |
| | self.log( |
| | "GrowthNet/train_loss", |
| | loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | |
| | return loss |
| |
|
| | def validation_step(self, batch, batch_idx): |
| | if isinstance(batch, (list, tuple)): |
| | batch = batch[0] |
| | main_batch = batch["val_samples"] |
| | metric_batch = batch["metric_samples"] |
| | if isinstance(main_batch, tuple): |
| | main_batch = main_batch[0] |
| | if isinstance(metric_batch, tuple): |
| | metric_batch = metric_batch[0] |
| | |
| | loss = self._compute_loss(main_batch, metric_batch, validation=True) |
| | |
| | if self.joint: |
| | self.log( |
| | "JointTrain/val_loss", |
| | loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | else: |
| | self.log( |
| | "GrowthNet/val_loss", |
| | loss, |
| | on_step=False, |
| | on_epoch=True, |
| | prog_bar=True, |
| | logger=True, |
| | ) |
| | |
| | return loss |
| |
|
| | def configure_optimizers(self): |
| | import itertools |
| | params = list(itertools.chain(*[net.parameters() for net in self.growth_nets])) |
| | if self.joint: |
| | params += list(itertools.chain(*[net.parameters() for net in self.flow_nets])) |
| | |
| | if self.optimizer_name == "adam": |
| | optimizer = torch.optim.Adam(params, lr=self.lr) |
| | elif self.optimizer_name == "adamw": |
| | optimizer = torch.optim.AdamW( |
| | params, |
| | lr=self.lr, |
| | weight_decay=self.weight_decay, |
| | ) |
| | return optimizer |