File size: 5,180 Bytes
5a87d8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import pytorch_lightning as pl
import torch
from sklearn.cluster import KMeans
import numpy as np
class RBFNetwork(pl.LightningModule):
    """Radial-basis-function network with K-means-initialised centres.

    The network evaluates

        h(x) = sum_k W[k] * exp(-0.5 * lamda[k] * ||x - C[k]||^2)

    where the centres ``C`` and bandwidths ``lamda`` are computed once in
    ``on_train_start`` from a K-means fit, and only the weights ``W`` are
    trainable. Training and validation minimise ``mean((1 - h(x)) ** 2)``
    over the samples taken from each batch. ``compute_metric`` turns ``h(x)``
    into an inverse-density style value ``1 / (h(x) + epsilon) ** alpha``.

    ``C`` and ``lamda`` are plain attributes that only exist after
    ``on_train_start`` has run, so ``forward`` cannot be called before that.
    NOTE(review): they are not registered as buffers, so presumably they are
    not part of ``state_dict`` -- confirm if checkpoint reloads are needed.
    """

    # Lower bound applied to every entry of ``W`` after each optimiser step.
    _MIN_WEIGHT = 0.0001

    def __init__(
        self,
        current_timestep,
        next_timestep,
        n_centers: int = 100,
        kappa: float = 1.0,
        lr=1e-2,
        datamodule=None,
        image_data=False,
        args=None,
    ):
        """Store the configuration and create the trainable weights.

        Args:
            current_timestep: Stored as ``self.current_timestep``.
            next_timestep: Stored as ``self.next_timestep``.
            n_centers: Number of K-means clusters / RBF centres (``self.K``).
            kappa: Scale applied to each cluster's sigma when computing
                ``lamda``.
            lr: Learning rate handed to Adam.
            datamodule: Stored as ``self.datamodule``.
            image_data: Selects how centroids and sigmas are computed in
                ``on_train_start``.
            args: Configuration object; the step methods read
                ``args.data_type`` and ``args.branches`` from it.
        """
        super().__init__()
        self.K = n_centers
        self.current_timestep = current_timestep
        self.next_timestep = next_timestep
        self.clustering_model = KMeans(n_clusters=self.K)
        self.kappa = kappa
        # Plain number until the first validation step replaces it with a
        # detached tensor; ``compute_metric`` handles both forms.
        self.last_val_loss = 1
        self.lr = lr
        self.W = torch.nn.Parameter(torch.rand(self.K, 1))
        self.datamodule = datamodule
        self.image_data = image_data
        self.args = args

    def on_before_zero_grad(self, *args, **kwargs):
        """Clamp the weights to at least ``_MIN_WEIGHT``."""
        with torch.no_grad():
            self.W.clamp_(min=self._MIN_WEIGHT)

    def on_train_start(self):
        """Fit K-means on the first training batch and derive ``C``/``lamda``.

        Reads ``batch[0]["metric_samples"][0]`` from the first batch of the
        trainer's datamodule, concatenates it, fits the clustering model and
        then computes one sigma per cluster from the spread of its members.
        """
        with torch.no_grad():
            batch = next(iter(self.trainer.datamodule.train_dataloader()))
            metric_samples = batch[0]["metric_samples"][0]
            all_data = torch.cat(metric_samples)

        # Everything below is scikit-learn / NumPy work, so convert once
        # instead of mixing tensors with ndarrays in indexing and arithmetic.
        data_np = all_data.detach().cpu().numpy()

        print("Fitting Clustering model...")
        self.clustering_model.fit(data_np)
        labels = self.clustering_model.labels_

        if self.image_data:
            clusters = self.calculate_centroids(data_np, labels)
        else:
            clusters = self.clustering_model.cluster_centers_
        self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device)

        sigmas = np.zeros((self.K, 1))
        for k in range(self.K):
            points = data_np[labels == k, :]
            # Per-feature variance of the cluster members around the centre.
            variance = ((points - clusters[k]) ** 2).mean(axis=0)
            spread = variance.sum() if self.image_data else variance.mean()
            sigmas[k, :] = np.sqrt(spread)

        # NOTE(review): a cluster with zero spread gives sigma == 0 and hence
        # an infinite lamda -- confirm whether a floor is wanted here.
        self.lamda = torch.tensor(
            0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32
        ).to(self.device)

    def forward(self, x):
        """Evaluate ``h(x)`` for a batch of points.

        Inputs with more than two dimensions are flattened to
        ``(x.shape[0], -1)`` and moved to the device of ``self.C``. The
        per-centre activations are kept on ``self.phi_x`` as a side effect.

        Returns:
            The weighted activations summed over the centre dimension.
        """
        if x.dim() > 2:
            x = x.reshape(x.shape[0], -1)
        x = x.to(self.C.device)
        dist2 = torch.cdist(x, self.C) ** 2
        self.phi_x = torch.exp(
            -0.5 * self.lamda[None, :, :] * dist2[:, :, None]
        )
        return (self.W.to(x.device) * self.phi_x).sum(dim=1)

    def _collect_inputs(self, batch, split_key):
        """Concatenate the start/end samples of one split into one tensor.

        ``split_key`` is ``"train_samples"`` or ``"val_samples"``. For the
        ``"scrna"`` and ``"tahoe"`` data types the split is read from
        ``batch[0]``, otherwise from ``batch`` directly. With a single branch
        the result is ``[x0, x1]``; otherwise ``[x0, x1_1, x1_2]``.
        """
        if self.args.data_type in ("scrna", "tahoe"):
            main_batch = batch[0][split_key][0]
        else:
            main_batch = batch[split_key][0]

        if self.args.branches == 1:
            sample_keys = ("x0", "x1")
        else:
            sample_keys = ("x0", "x1_1", "x1_2")

        samples = [main_batch[key][0] for key in sample_keys]
        return torch.cat(samples, dim=0).to(self.device)

    def training_step(self, batch, batch_idx):
        """Return and log the loss ``mean((1 - h(x)) ** 2)`` for a batch."""
        inputs = self._collect_inputs(batch, "train_samples")
        loss = ((1 - self.forward(inputs)) ** 2).mean()
        self.log(
            "MetricModel/train_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Return and log the validation loss; remember it for later use.

        The detached loss is stored on ``self.last_val_loss``, which
        ``compute_metric`` reads when called with a negative ``epsilon``.
        """
        inputs = self._collect_inputs(batch, "val_samples")
        h = self.forward(inputs)
        loss = ((1 - h) ** 2).mean()
        self.log(
            "MetricModel/val_loss_learn_metric",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        self.last_val_loss = loss.detach()
        return loss

    def calculate_centroids(self, all_data, labels):
        """Return the mean of ``all_data`` rows for each distinct label.

        Rows of the result follow the sorted order of ``np.unique(labels)``.
        """
        unique_labels = np.unique(labels)
        centroids = np.zeros((len(unique_labels), all_data.shape[1]))
        for i, label in enumerate(unique_labels):
            centroids[i] = all_data[labels == label].mean(axis=0)
        return centroids

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False):
        """Compute the metric value ``M(x)`` from ``h(x)``.

        Args:
            x: Points passed to ``forward``.
            alpha: Exponent applied in the denominator.
            epsilon: Additive term in the denominator. A negative value is
                replaced by ``(1 - last_val_loss) / abs(epsilon)``.
            image_hx: If true, ``h`` is first folded to ``1 - |1 - h|`` and
                the result is ``1 / (h ** alpha + epsilon)``; otherwise the
                result is ``1 / (h + epsilon) ** alpha``.

        Returns:
            ``M(x)``, with the same shape as the output of ``forward``.
        """
        if epsilon < 0:
            # ``last_val_loss`` is a plain number before the first validation
            # step and a 0-dim tensor afterwards; float() accepts both.
            epsilon = (1 - float(self.last_val_loss)) / abs(epsilon)
        h_x = self.forward(x)
        if image_hx:
            h_x = 1 - torch.abs(1 - h_x)
            return 1 / (h_x**alpha + epsilon)
        return 1 / (h_x + epsilon) ** alpha