import sys sys.path.append("./BranchSBM") import torch import pytorch_lightning as pl from pytorch_lightning.loggers import WandbLogger from torch.utils.data import Dataset, DataLoader from state_costs.land import land_metric_tensor from state_costs.rbf import RBFNetwork class DataManifoldMetric: def __init__( self, args, skipped_time_points=None, datamodule=None, ): self.skipped_time_points = skipped_time_points self.datamodule = datamodule self.gamma = args.gamma_current self.rho = args.rho self.metric = args.velocity_metric self.n_centers = args.n_centers self.kappa = args.kappa self.metric_epochs = args.metric_epochs self.metric_patience = args.metric_patience self.lr = args.metric_lr self.alpha_metric = args.alpha_metric self.image_data = args.data_type == "image" self.accelerator = args.accelerator self.called_first_time = True self.args = args def calculate_metric(self, x_t, samples, current_timestep): if self.metric == "land": M_dd_x_t = ( land_metric_tensor(x_t, samples, self.gamma, self.rho) ** self.alpha_metric ) elif self.metric == "rbf": if self.called_first_time: self.rbf_networks = [] for timestep in range(self.datamodule.num_timesteps - 1): if timestep in self.skipped_time_points: continue print("Learning RBF networks, timestep: ", timestep) rbf_network = RBFNetwork( current_timestep=timestep, next_timestep=timestep + 1 + (1 if timestep + 1 in self.skipped_time_points else 0), n_centers=self.n_centers, kappa=self.kappa, lr=self.lr, datamodule=self.datamodule, args=self.args ) early_stop_callback = pl.callbacks.EarlyStopping( monitor="MetricModel/val_loss_learn_metric", patience=self.metric_patience, mode="min", ) trainer = pl.Trainer( max_epochs=self.metric_epochs, accelerator=self.accelerator, logger=WandbLogger(), num_sanity_val_steps=0, callbacks=( [early_stop_callback] if not self.image_data else None ), ) if self.image_data: self.dataloader = DataLoader( self.datamodule.all_data, batch_size=128, shuffle=True, ) trainer.fit(rbf_network, self.dataloader) else: trainer.fit(rbf_network, self.datamodule) self.rbf_networks.append(rbf_network) self.called_first_time = False print("Learning RBF networksss... Done") M_dd_x_t = self.rbf_networks[current_timestep].compute_metric( x_t, epsilon=self.rho, alpha=self.alpha_metric, image_hx=self.image_data, ) return M_dd_x_t def calculate_velocity(self, x_t, u_t, samples, timestep): if len(u_t.shape) > 2: u_t = u_t.reshape(u_t.shape[0], -1) x_t = x_t.reshape(x_t.shape[0], -1) M_dd_x_t = self.calculate_metric(x_t, samples, timestep).to(u_t.device) velocity = torch.sqrt(((u_t**2) * M_dd_x_t).sum(dim=-1)) ut_sum = (u_t**2).sum(dim=-1) metric_sum = M_dd_x_t.sum(dim=-1) return velocity, ut_sum, metric_sum