| |
|
|
| import sys |
| import torch |
| import pytorch_lightning as pl |
| from pytorch_lightning.loggers import WandbLogger |
| from torch.utils.data import Dataset, DataLoader |
|
|
| from .land import land_metric_tensor |
| from .rbf import RBFNetwork |
|
|
| class DataManifoldMetric: |
| def __init__( |
| self, |
| args, |
| skipped_time_points=None, |
| datamodule=None, |
| ): |
| self.skipped_time_points = skipped_time_points |
| self.datamodule = datamodule |
|
|
| self.gamma = args.gamma_current |
| self.rho = args.rho |
| self.metric = args.velocity_metric |
| self.n_centers = args.n_centers |
| self.kappa = args.kappa |
| self.metric_epochs = args.metric_epochs |
| self.metric_patience = args.metric_patience |
| self.lr = args.metric_lr |
| self.alpha_metric = args.alpha_metric |
| self.image_data = args.data_type == "image" |
| self.accelerator = args.accelerator |
|
|
| self.called_first_time = True |
| self.args = args |
|
|
| def calculate_metric(self, x_t, samples, current_timestep): |
| if self.metric == "land": |
| M_dd_x_t = ( |
| land_metric_tensor(x_t, samples, self.gamma, self.rho) |
| ** self.alpha_metric |
| ) |
| |
| elif self.metric == "rbf": |
| if self.called_first_time: |
| |
| print("Learning single RBF network for all timesteps") |
| self.rbf_network = RBFNetwork( |
| current_timestep=0, |
| next_timestep=self.datamodule.num_timesteps - 1, |
| n_centers=self.n_centers, |
| kappa=self.kappa, |
| lr=self.lr, |
| datamodule=self.datamodule, |
| args=self.args |
| ) |
| early_stop_callback = pl.callbacks.EarlyStopping( |
| monitor="MetricModel/train_loss_learn_metric_epoch", |
| patience=self.metric_patience, |
| mode="min", |
| ) |
| trainer = pl.Trainer( |
| max_epochs=self.metric_epochs, |
| accelerator=self.accelerator, |
| logger=WandbLogger(), |
| num_sanity_val_steps=0, |
| callbacks=( |
| [early_stop_callback] if not self.image_data else None |
| ), |
| ) |
| if self.image_data: |
| self.dataloader = DataLoader( |
| self.datamodule.all_data, |
| batch_size=128, |
| shuffle=True, |
| ) |
| trainer.fit(self.rbf_network, self.dataloader) |
| else: |
| trainer.fit(self.rbf_network, self.datamodule) |
| self.called_first_time = False |
| print("Learning RBF network... Done") |
| M_dd_x_t = self.rbf_network.compute_metric( |
| x_t, |
| epsilon=self.rho, |
| alpha=self.alpha_metric, |
| image_hx=self.image_data, |
| ) |
| return M_dd_x_t |
|
|
| def calculate_velocity(self, x_t, u_t, samples, timestep): |
|
|
| if len(u_t.shape) > 2: |
| u_t = u_t.reshape(u_t.shape[0], -1) |
| x_t = x_t.reshape(x_t.shape[0], -1) |
| M_dd_x_t = self.calculate_metric(x_t, samples, timestep).to(u_t.device) |
|
|
| |
| |
| velocity = torch.sqrt(torch.clamp(((u_t**2) * M_dd_x_t).sum(dim=-1), min=0)) |
| ut_sum = (u_t**2).sum(dim=-1) |
| metric_sum = M_dd_x_t.sum(dim=-1) |
| return velocity, ut_sum, metric_sum |
|
|