import sys

sys.path.append("./BranchSBM")

import itertools

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch

from branchsbm.ema import EMA
from utils import wasserstein_distance, plot_lidar


class BranchInterpolantTrain(pl.LightningModule):
    """Lightning module that trains the per-branch GeoPath interpolant networks.

    Each branch owns one geopath net (taken from ``flow_matcher.geopath_nets``);
    the training objective is the mean squared velocity of the interpolant
    (kinetic energy, optionally on a data manifold metric), normalised by the
    loss of the untrained networks (``first_loss``).
    """

    def __init__(
        self,
        flow_matcher,
        args,
        skipped_time_points: list = None,
        ot_sampler=None,
        state_cost=None,
        data_manifold_metric=None,
    ):
        """Store the flow matcher, its geopath nets, and training hyper-parameters.

        Args:
            flow_matcher: Provides ``geopath_nets`` (one net per branch),
                ``alpha`` and ``sample_location_and_conditional_flow``.
            args: Namespace with ``metric_clusters``, ``geopath_optimizer``,
                ``geopath_lr``, ``geopath_weight_decay``, ``manifold``,
                ``data_type`` and ``whiten``.
            skipped_time_points: Optional list of time indices to skip when
                building the per-interval timesteps.
            ot_sampler: Optional OT coupling sampler used to pair x0/x1.
            state_cost: Potential-energy term used when not on a manifold.
            data_manifold_metric: Metric object used when ``args.manifold``.
        """
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.flow_matcher = flow_matcher
        # List of geopath nets, one per branch.
        self.geopath_nets = flow_matcher.geopath_nets
        self.branches = len(self.geopath_nets)
        self.metric_clusters = args.metric_clusters
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points if skipped_time_points else []
        self.optimizer_name = args.geopath_optimizer
        self.lr = args.geopath_lr
        self.weight_decay = args.geopath_weight_decay
        # Validation batches are repeated this many times for a smoother estimate.
        self.multiply_validation = 4
        # Reference loss of the untrained nets; set in on_train_start and used
        # to normalise train/val losses.
        self.first_loss = None
        self.timesteps = None
        self.computing_reference_loss = False
        self.state_cost = state_cost
        self.data_manifold_metric = data_manifold_metric
        self.whiten = args.whiten

    def forward(self, x0, x1, t, branch_idx):
        """Evaluate the interpolant of one branch at time ``t``."""
        return self.geopath_nets[branch_idx](x0, x1, t)

    def on_train_start(self):
        # Compute the reference (untrained) loss once, so later losses can be
        # reported relative to it.
        self.first_loss = self.compute_initial_loss()
        print("first loss")
        print(self.first_loss)

    def compute_initial_loss(self):
        """Average the loss over one pass of the train loader with alpha=0.

        The geopath nets are put in eval mode and ``flow_matcher.alpha`` is
        temporarily zeroed so the returned value reflects the untrained
        baseline. Returns 1.0 if the loader yields no batches.
        """
        # Set all GeoPath networks to eval mode.
        for net in self.geopath_nets:
            net.train(mode=False)
        total_loss = 0
        total_count = 0
        with torch.enable_grad():
            # Pre-sample the validation time points (with grad, as they are
            # also used when self.trainer.validating).
            self.t_val = []
            for i in range(
                self.trainer.datamodule.num_timesteps - len(self.skipped_time_points)
            ):
                self.t_val.append(
                    torch.rand(
                        self.trainer.datamodule.batch_size * self.multiply_validation,
                        requires_grad=True,
                    )
                )
        self.computing_reference_loss = True
        with torch.no_grad():
            old_alpha = self.flow_matcher.alpha
            self.flow_matcher.alpha = 0
            for batch in self.trainer.datamodule.train_dataloader():
                self.timesteps = torch.linspace(
                    0.0, 1.0, len(batch[0]["train_samples"][0])
                )
                loss = self._compute_loss(
                    batch[0]["train_samples"][0],
                    batch[0]["metric_samples"][0],
                )
                print("initial loss")
                print(loss)
                total_loss += loss.item()
                total_count += 1
            self.flow_matcher.alpha = old_alpha
        self.computing_reference_loss = False
        # Set all GeoPath networks back to training mode.
        for net in self.geopath_nets:
            net.train(mode=True)
        return total_loss / total_count if total_count > 0 else 1.0

    def _compute_loss(self, main_batch, metric_samples_batch=None):
        """Mean squared interpolant velocity summed over all branches.

        Args:
            main_batch: Dict with ``"x0"`` and either ``"x1"`` (single branch)
                or ``"x1_{i}"`` per branch; each value is a (samples, weights)
                pair.
            metric_samples_batch: Cluster samples used to build the manifold
                metric when ``args.manifold`` is set.
        """
        x0s = [main_batch["x0"][0]]
        w0s = [main_batch["x0"][1]]
        x1s_list = []
        w1s_list = []
        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
                w1s_list.append([main_batch[f"x1_{i+1}"][1]])
        else:
            x1s_list.append([main_batch["x1"][0]])
            w1s_list.append([main_batch["x1"][1]])

        if self.args.manifold:
            # Pair (start, end) metric samples for each branch; cluster 0 is
            # always the source distribution.
            if self.metric_clusters == 4:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                    (metric_samples_batch[0], metric_samples_batch[3]),
                ]
            elif self.metric_clusters == 3:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]),  # x0 → x1_2 (branch 2)
                ]
            elif self.metric_clusters == 2 and self.branches == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_2 (branch 2)
                ]
            else:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]),  # x0 → x1_1 (branch 1)
                ]

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        # Collect per-sample velocities across all branches and time intervals.
        velocities = []
        for branch_idx in range(self.branches):
            ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)
            for i in range(len(ts)):
                # Kinetic (and potential) energy of the predicted interpolant.
                if self.args.manifold:
                    start_samples, end_samples = branch_sample_pairs[branch_idx]
                    samples = torch.cat([start_samples, end_samples], dim=0)
                    vel, _, _ = self.data_manifold_metric.calculate_velocity(
                        xts[i], uts[i], samples, i
                    )
                else:
                    vel = torch.sqrt((uts[i] ** 2).sum(dim=-1) + self.state_cost(xts[i]))
                velocities.append(vel)

        loss = torch.mean(torch.cat(velocities) ** 2)
        self.log(
            "BranchPathNet/mean_velocity_geopath",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample (t, x_t, u_t) for each consecutive (x0, x1) pair of a branch.

        During validation / reference-loss computation the batches are repeated
        ``multiply_validation`` times and the pre-sampled ``t_val`` times are
        reused for determinism.
        """
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]
        i_start = 0
        for i, (x0, x1) in enumerate(zip(x0s, x1s)):
            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
            if self.trainer.validating or self.computing_reference_loss:
                repeat_tuple = (self.multiply_validation, 1) + (1,) * (len(x0.shape) - 2)
                x0 = x0.repeat(repeat_tuple)
                x1 = x1.repeat(repeat_tuple)
            if self.ot_sampler is not None:
                # Re-pair x0/x1 with an optimal-transport coupling.
                x0, x1 = self.ot_sampler.sample_plan(
                    x0,
                    x1,
                    replace=True,
                )
            # Jump over skipped time points when choosing the interval end.
            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]
            t = None
            if self.trainer.validating or self.computing_reference_loss:
                t = self.t_val[i]
            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx, training_geopath_net=True, t=t
            )
            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next
        return ts, xts, uts

    def training_step(self, batch, batch_idx):
        """One optimisation step; loss is normalised by the reference loss."""
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["train_samples"][0]
            metric_batch = batch[0]["metric_samples"][0]
        else:
            main_batch = batch["train_samples"][0]
            metric_batch = batch["metric_samples"][0]
        tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
        if self.first_loss:
            tangential_velocity_loss = tangential_velocity_loss / self.first_loss
        self.log(
            "BranchPathNet/mean_geopath_geopath",
            (self.flow_matcher.geopath_net_output.abs().mean()),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "BranchPathNet/train_loss_geopath",
            tangential_velocity_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return tangential_velocity_loss

    def validation_step(self, batch, batch_idx):
        """Validation loss, normalised by the reference loss."""
        if self.args.data_type in ["scrna", "tahoe"]:
            main_batch = batch[0]["val_samples"][0]
            metric_batch = batch[0]["metric_samples"][0]
        else:
            main_batch = batch["val_samples"][0]
            metric_batch = batch["metric_samples"][0]
        tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
        if self.first_loss:
            tangential_velocity_loss = tangential_velocity_loss / self.first_loss
        self.log(
            "BranchPathNet/val_loss_geopath",
            tangential_velocity_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return tangential_velocity_loss

    def test_step(self, batch, batch_idx):
        """Plot interpolated point clouds (t = 1/4, 1/2, 3/4) for each branch.

        Produces one 3D scatter figure per branch and saves it to disk; no
        metric is returned.
        """
        main_batch = batch["test_samples"][0]
        metric_batch = batch["metric_samples"][0]

        x0 = main_batch["x0"][0]  # [B, D]
        cloud_points = main_batch["dataset"][0]  # full dataset, [N, D]

        x0 = x0.to(self.device)
        cloud_points = cloud_points.to(self.device)

        t_vals = [0.25, 0.5, 0.75]
        t_labels = ["t=1/4", "t=1/2", "t=3/4"]
        colors = {
            "x0": "#4D176C",
            "t=1/4": "#5C3B9D",
            "t=1/2": "#6172B9",
            "t=3/4": "#AC4E51",
            "x1": "#771F4F",
        }

        # Unwhiten cloud points if needed.
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(cloud_points.cpu().numpy())
            )

        for i in range(self.branches):
            geopath = self.geopath_nets[i]
            # NOTE(review): single-branch batches use key "x1" in
            # _compute_loss, but "x1_1" is expected here, so the single-branch
            # case would always be skipped — confirm intended.
            x1_key = f"x1_{i + 1}"
            if x1_key not in main_batch:
                print(f"Skipping branch {i + 1}: no final distribution {x1_key}")
                continue

            x1 = main_batch[x1_key][0].to(self.device)
            print(x1.shape)
            print(x0.shape)

            interpolated_points = []
            with torch.no_grad():
                for t_scalar in t_vals:
                    t_tensor = torch.full((x0.shape[0], 1), t_scalar, device=self.device)  # [B, 1]
                    xt = geopath(x0, x1, t_tensor).cpu()  # [B, D]
                    if self.whiten:
                        xt = torch.tensor(
                            self.trainer.datamodule.scaler.inverse_transform(xt.numpy())
                        )
                    interpolated_points.append(xt)

            if self.whiten:
                x0_plot = torch.tensor(
                    self.trainer.datamodule.scaler.inverse_transform(x0.cpu().numpy())
                )
                x1_plot = torch.tensor(
                    self.trainer.datamodule.scaler.inverse_transform(x1.cpu().numpy())
                )
            else:
                x0_plot = x0.cpu()
                x1_plot = x1.cpu()

            # Plot
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points)

            # Initial x₀
            ax.scatter(
                x0_plot[:, 0], x0_plot[:, 1], x0_plot[:, 2],
                s=15, alpha=1.0, color=colors["x0"], label="x₀",
                depthshade=True, edgecolors="white", linewidths=0.3
            )

            # Interpolated points
            for xt, t_label in zip(interpolated_points, t_labels):
                ax.scatter(
                    xt[:, 0], xt[:, 1], xt[:, 2],
                    s=15, alpha=1.0, color=colors[t_label], label=t_label,
                    depthshade=True, edgecolors="white", linewidths=0.3
                )

            # Final x₁
            ax.scatter(
                x1_plot[:, 0], x1_plot[:, 1], x1_plot[:, 2],
                s=15, alpha=1.0, color=colors["x1"], label="x₁",
                depthshade=True, edgecolors="white", linewidths=0.3
            )

            ax.legend()
            save_path = f"/raid/st512/branchsbm/figures/{self.args.data_type}/lidar_geopath_branch_{i+1}.png"
            plt.savefig(save_path, dpi=300)
            plt.close()

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        # NOTE(review): geopath_nets is iterated as a list elsewhere, so this
        # isinstance check can never be true and the EMA update appears
        # unreachable — confirm whether per-net EMA updates were intended.
        if isinstance(self.geopath_nets, EMA):
            self.geopath_nets.update_ema()

    def configure_optimizers(self):
        """Build one optimizer over the parameters of all geopath nets.

        Raises:
            ValueError: If ``args.geopath_optimizer`` is not "adam" or "adamw".
        """
        params = itertools.chain(*[net.parameters() for net in self.geopath_nets])
        if self.optimizer_name == "adam":
            optimizer = torch.optim.Adam(params, lr=self.lr)
        elif self.optimizer_name == "adamw":
            # Fix: pass the configured weight decay (it was previously dropped,
            # silently falling back to AdamW's default of 0.01).
            optimizer = torch.optim.AdamW(
                params, lr=self.lr, weight_decay=self.weight_decay
            )
        else:
            raise ValueError(f"Unknown geopath optimizer: {self.optimizer_name}")
        return optimizer