# BranchSBM / state_costs / metric_factory.py
# sophiat44 - model upload - 5a87d8d
import sys
sys.path.append("./BranchSBM")
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader
from state_costs.land import land_metric_tensor
from state_costs.rbf import RBFNetwork
class DataManifoldMetric:
    """Evaluate a data-dependent metric and the velocity norm it induces.

    Two back-ends are selected by ``args.velocity_metric``:

    * ``"land"`` -- calls ``land_metric_tensor`` on every request.
    * ``"rbf"`` -- trains one ``RBFNetwork`` per non-skipped timestep on the
      first request and reuses the trained networks afterwards.

    Args:
        args: Configuration namespace. The attributes read here are
            ``gamma_current``, ``rho``, ``velocity_metric``, ``n_centers``,
            ``kappa``, ``metric_epochs``, ``metric_patience``, ``metric_lr``,
            ``alpha_metric``, ``data_type`` and ``accelerator``. The whole
            namespace is also forwarded to every ``RBFNetwork``.
        skipped_time_points: Timestep indices for which no RBF network is
            trained. ``None`` (the default) means nothing is skipped.
        datamodule: Data source used for RBF training. It must expose
            ``num_timesteps`` and, for image data, ``all_data``.
    """

    def __init__(
        self,
        args,
        skipped_time_points=None,
        datamodule=None,
    ):
        self.skipped_time_points = skipped_time_points
        self.datamodule = datamodule
        self.gamma = args.gamma_current
        self.rho = args.rho
        self.metric = args.velocity_metric
        self.n_centers = args.n_centers
        self.kappa = args.kappa
        self.metric_epochs = args.metric_epochs
        self.metric_patience = args.metric_patience
        self.lr = args.metric_lr
        self.alpha_metric = args.alpha_metric
        self.image_data = args.data_type == "image"
        self.accelerator = args.accelerator
        # Flipped to False once the RBF networks have been trained.
        self.called_first_time = True
        self.args = args

    def calculate_metric(self, x_t, samples, current_timestep):
        """Return the metric evaluated at ``x_t``.

        Args:
            x_t: Points at which the metric is evaluated.
            samples: Reference samples; only used by the ``"land"`` metric.
            current_timestep: Position in ``self.rbf_networks`` of the
                network to query; only used by the ``"rbf"`` metric.
                NOTE(review): this is a list position, so it differs from
                the raw timestep index as soon as a timestep is skipped --
                confirm that callers pass the filtered index.

        Returns:
            Whatever the selected back-end produces for ``x_t``.

        Raises:
            ValueError: If ``self.metric`` is neither ``"land"`` nor
                ``"rbf"``.
        """
        if self.metric == "land":
            return (
                land_metric_tensor(x_t, samples, self.gamma, self.rho)
                ** self.alpha_metric
            )

        if self.metric == "rbf":
            # Training is expensive, so it only happens on the first request.
            if self.called_first_time:
                self._fit_rbf_networks()
            return self.rbf_networks[current_timestep].compute_metric(
                x_t,
                epsilon=self.rho,
                alpha=self.alpha_metric,
                image_hx=self.image_data,
            )

        raise ValueError(
            f"Unknown velocity metric {self.metric!r}; expected 'land' or 'rbf'."
        )

    def _fit_rbf_networks(self):
        """Train one ``RBFNetwork`` per non-skipped timestep and cache them."""
        # The constructor default is None, which does not support ``in``.
        skipped = (
            () if self.skipped_time_points is None else self.skipped_time_points
        )

        self.rbf_networks = []
        for timestep in range(self.datamodule.num_timesteps - 1):
            if timestep in skipped:
                continue
            print("Learning RBF networks, timestep: ", timestep)

            # When the following timestep is skipped, target the one after it.
            next_timestep = timestep + 1 + (1 if timestep + 1 in skipped else 0)
            rbf_network = RBFNetwork(
                current_timestep=timestep,
                next_timestep=next_timestep,
                n_centers=self.n_centers,
                kappa=self.kappa,
                lr=self.lr,
                datamodule=self.datamodule,
                args=self.args,
            )

            early_stop_callback = pl.callbacks.EarlyStopping(
                monitor="MetricModel/val_loss_learn_metric",
                patience=self.metric_patience,
                mode="min",
            )
            trainer = pl.Trainer(
                max_epochs=self.metric_epochs,
                accelerator=self.accelerator,
                logger=WandbLogger(),
                num_sanity_val_steps=0,
                # Early stopping is only enabled for non-image data.
                callbacks=(
                    [early_stop_callback] if not self.image_data else None
                ),
            )

            if self.image_data:
                # Image data is fitted from a plain DataLoader over all_data
                # instead of going through the datamodule.
                self.dataloader = DataLoader(
                    self.datamodule.all_data,
                    batch_size=128,
                    shuffle=True,
                )
                trainer.fit(rbf_network, self.dataloader)
            else:
                trainer.fit(rbf_network, self.datamodule)
            self.rbf_networks.append(rbf_network)

        # Only mark training as done after every network has been fitted, so
        # a failure part-way through triggers a fresh attempt on the next call.
        self.called_first_time = False
        print("Learning RBF networksss... Done")

    def calculate_velocity(self, x_t, u_t, samples, timestep):
        """Compute the metric-weighted norm of ``u_t`` at ``x_t``.

        Inputs with more than two dimensions are flattened to
        ``(shape[0], -1)`` before the metric is evaluated.

        Args:
            x_t: Points at which the metric is evaluated.
            u_t: Velocities at ``x_t``.
            samples: Forwarded to ``calculate_metric``.
            timestep: Forwarded to ``calculate_metric``.

        Returns:
            A tuple ``(velocity, ut_sum, metric_sum)`` where ``velocity`` is
            ``sqrt(sum(u_t**2 * M, dim=-1))``, ``ut_sum`` is
            ``sum(u_t**2, dim=-1)`` and ``metric_sum`` is ``sum(M, dim=-1)``,
            with ``M`` the result of ``calculate_metric``.
        """
        if len(u_t.shape) > 2:
            u_t = u_t.reshape(u_t.shape[0], -1)
            x_t = x_t.reshape(x_t.shape[0], -1)

        # The metric may be produced on another device than the velocities.
        M_dd_x_t = self.calculate_metric(x_t, samples, timestep).to(u_t.device)

        squared_u = u_t**2
        velocity = torch.sqrt((squared_u * M_dd_x_t).sum(dim=-1))
        ut_sum = squared_u.sum(dim=-1)
        metric_sum = M_dd_x_t.sum(dim=-1)
        return velocity, ut_sum, metric_sum