| | |
| |
|
| | import pytorch_lightning as pl |
| | import torch |
| | from sklearn.cluster import KMeans |
| | import numpy as np |
| |
|
class RBFNetwork(pl.LightningModule):
    """Radial-basis-function network fitted on KMeans cluster centres.

    At the start of training, KMeans is fitted on the ``metric_samples`` of the
    first training batch. The resulting centres (``self.C``) and per-cluster
    bandwidths (``self.lamda``) stay fixed; only the non-negative weights
    ``self.W`` are trained, with the loss ``mean((1 - h(x)) ** 2)`` pushing the
    network output ``h(x)`` towards 1 on the training samples.
    ``compute_metric`` turns ``h(x)`` into a reciprocal-style quantity.

    Args:
        current_timestep: Stored on the instance; not used inside this class.
        next_timestep: Stored on the instance; not used inside this class.
        n_centers: Number of KMeans clusters / RBF centres (``self.K``).
        kappa: Multiplier applied to each cluster's sigma when computing the
            bandwidth ``lamda = 0.5 / (kappa * sigma) ** 2``.
        lr: Learning rate handed to Adam.
        datamodule: Stored on the instance; the training data itself is read
            from ``self.trainer.datamodule``.
        image_data: Selects the image-specific centroid and sigma computation
            in ``on_train_start``.
        args: Configuration object; ``args.data_type`` and ``args.branches``
            are read when unpacking batches.
    """

    def __init__(
        self,
        current_timestep,
        next_timestep,
        n_centers: int = 100,
        kappa: float = 1.0,
        lr=1e-2,
        datamodule=None,
        image_data=False,
        args=None,
    ):
        super().__init__()
        self.K = n_centers
        self.current_timestep = current_timestep
        self.next_timestep = next_timestep
        self.clustering_model = KMeans(n_clusters=self.K)
        self.kappa = kappa
        # Placeholder until the first validation step stores a real loss.
        self.last_val_loss = 1
        self.lr = lr
        # One trainable weight per RBF centre, shape (K, 1).
        self.W = torch.nn.Parameter(torch.rand(self.K, 1))
        self.datamodule = datamodule
        self.image_data = image_data
        self.args = args

    @staticmethod
    def _to_numpy(data):
        """Return ``data`` as a NumPy array, detaching and moving tensors to CPU."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    def _collect_inputs(self, batch, split_key):
        """Unpack ``batch[...][split_key]`` and concatenate its x0/x1 samples.

        Args:
            batch: Batch object delivered by the dataloader.
            split_key: ``"train_samples"`` or ``"val_samples"``.

        Returns:
            All start and end samples concatenated along dim 0, moved to
            ``self.device``.
        """
        if self.args.data_type in ("scrna", "tahoe"):
            main_batch = batch[0][split_key][0]
        else:
            main_batch = batch[split_key][0]

        parts = [main_batch["x0"][0]]
        if self.args.branches == 1:
            parts.append(main_batch["x1"][0])
        else:
            parts.append(main_batch["x1_1"][0])
            parts.append(main_batch["x1_2"][0])
        return torch.cat(parts, dim=0).to(self.device)

    def on_before_zero_grad(self, *args, **kwargs):
        """Keep every weight in ``W`` at or above 1e-4."""
        with torch.no_grad():
            self.W.clamp_(min=0.0001)

    def on_train_start(self):
        """Fit KMeans on the first training batch and derive ``C`` and ``lamda``.

        Sets ``self.C`` (cluster centres, float32) and ``self.lamda``
        (per-centre bandwidth, shape (K, 1), float32) on ``self.device``.
        """
        with torch.no_grad():
            batch = next(iter(self.trainer.datamodule.train_dataloader()))
            # NOTE(review): batch[0] is indexed here regardless of
            # args.data_type, unlike the step methods - confirm this matches
            # the dataloader layout for every data type.
            metric_samples = batch[0]["metric_samples"][0]
            all_data = torch.cat(metric_samples)

        # KMeans and the statistics below are NumPy-based; convert once so the
        # code does not rely on implicit Tensor/ndarray mixing (which breaks
        # for tensors that live on an accelerator).
        data_np = self._to_numpy(all_data)

        print("Fitting Clustering model...")
        self.clustering_model.fit(data_np)
        labels = self.clustering_model.labels_

        clusters = (
            self.calculate_centroids(data_np, labels)
            if self.image_data
            else self.clustering_model.cluster_centers_
        )
        self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device)

        sigmas = np.zeros((self.K, 1))
        for k in range(self.K):
            points = data_np[labels == k, :]
            if points.shape[0] == 0:
                # An empty cluster would give a NaN variance, which the floor
                # below cannot repair; leave it at 0 so the floor applies.
                continue
            variance = ((points - clusters[k]) ** 2).mean(axis=0)
            sigmas[k, :] = np.sqrt(
                variance.sum() if self.image_data else variance.mean()
            )

        # Floor sigma to avoid division by zero for degenerate clusters.
        sigmas = np.maximum(sigmas, 1e-6)

        self.lamda = torch.tensor(
            0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32
        ).to(self.device)

    def forward(self, x):
        """Evaluate the RBF network.

        Args:
            x: Input samples; inputs with more than two dimensions are
                flattened to ``(x.shape[0], -1)``.

        Returns:
            Tensor of shape ``(N, 1)``: the ``W``-weighted sum of the RBF
            activations. The activations themselves are kept in
            ``self.phi_x`` (shape ``(N, K, 1)``).
        """
        if len(x.shape) > 2:
            x = x.reshape(x.shape[0], -1)
        x = x.to(self.C.device)

        dist2 = torch.cdist(x, self.C) ** 2
        # NOTE(review): lamda already carries a 0.5 factor (see on_train_start)
        # and another 0.5 is applied here - confirm this is intended.
        self.phi_x = torch.exp(-0.5 * self.lamda[None, :, :] * dist2[:, :, None])

        h_x = (self.W.to(x.device) * self.phi_x).sum(dim=1)
        return h_x

    def training_step(self, batch, batch_idx):
        """Return ``mean((1 - h(x)) ** 2)`` over the batch's training samples."""
        inputs = self._collect_inputs(batch, "train_samples")
        loss = ((1 - self.forward(inputs)) ** 2).mean()
        self.log(
            "MetricModel/train_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss and remember it in ``last_val_loss``."""
        inputs = self._collect_inputs(batch, "val_samples")
        h = self.forward(inputs)
        loss = ((1 - h) ** 2).mean()
        self.log(
            "MetricModel/val_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        # Read back by compute_metric when it is called with a negative epsilon.
        self.last_val_loss = loss.detach()
        return loss

    def calculate_centroids(self, all_data, labels):
        """Return the mean of ``all_data`` rows for each distinct label.

        Args:
            all_data: 2-D samples (tensor or array), one row per sample.
            labels: Cluster label of each row.

        Returns:
            Array of shape ``(n_distinct_labels, all_data.shape[1])`` whose
            rows follow the sorted order of the distinct labels.
        """
        data_np = self._to_numpy(all_data)
        labels = np.asarray(labels)
        unique_labels = np.unique(labels)
        centroids = np.zeros((len(unique_labels), data_np.shape[1]))
        for i, label in enumerate(unique_labels):
            centroids[i] = data_np[labels == label].mean(axis=0)
        return centroids

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False):
        """Compute the reciprocal quantity ``M(x)`` from the network output.

        Args:
            x: Input samples passed to ``forward``.
            alpha: Exponent applied in the denominator.
            epsilon: Additive term in the denominator. A negative value is
                replaced by ``(1 - last_val_loss) / abs(epsilon)``.
            image_hx: If true, ``h`` is first folded to ``1 - |1 - h|`` and
                the result is ``1 / (h ** alpha + epsilon)``; otherwise the
                result is ``1 / (h + epsilon) ** alpha``.

        Returns:
            Tensor with the same shape as ``forward(x)``.
        """
        if epsilon < 0:
            epsilon = (1 - float(self.last_val_loss)) / abs(epsilon)
        h_x = self.forward(x)
        if image_hx:
            h_x = 1 - torch.abs(1 - h_x)
            M_x = 1 / (h_x**alpha + epsilon)
        else:
            M_x = 1 / (h_x + epsilon) ** alpha
        return M_x