# BranchSBM / branchsbm / branch_interpolant_train.py
# Commit 5a87d8d by sophiat44: "model upload"
import sys
sys.path.append("./BranchSBM")
import torch
import pytorch_lightning as pl
from branchsbm.ema import EMA
import itertools
from utils import wasserstein_distance, plot_lidar
import matplotlib.pyplot as plt
class BranchInterpolantTrain(pl.LightningModule):
    """Lightning module that trains one learned interpolant ("geopath") network per branch.

    Every branch shares the source samples ``x0`` and has its own target samples
    (``x1_1``, ``x1_2``, ... or ``x1`` when there is a single branch). The loss
    is ``mean(velocity ** 2)`` pooled over all branches. The velocity comes from
    ``data_manifold_metric.calculate_velocity`` when ``args.manifold`` is true;
    otherwise it is ``sqrt(|u_t|^2 + state_cost(x_t))``. Train and validation
    losses are divided by a reference loss (``self.first_loss``) once it has been
    computed in :meth:`on_train_start`.

    Args:
        flow_matcher: Object providing ``geopath_nets`` (an indexable, iterable
            collection of per-branch networks called as ``net(x0, x1, t)``), a
            mutable ``alpha`` attribute, ``sample_location_and_conditional_flow``
            and ``geopath_net_output``.
        args: Configuration namespace. This class reads ``metric_clusters``,
            ``geopath_optimizer``, ``geopath_lr``, ``geopath_weight_decay``,
            ``whiten``, ``manifold`` and ``data_type``. :meth:`test_step` also
            reads an optional ``figure_dir``.
        skipped_time_points: Time-point indices left out of training. ``None``
            or an empty list means no time point is skipped.
        ot_sampler: Optional object whose ``sample_plan(x0, x1, replace=True)``
            re-pairs source and target samples.
        state_cost: Callable ``state_cost(xt)``, used when ``args.manifold`` is false.
        data_manifold_metric: Object exposing ``calculate_velocity``, used when
            ``args.manifold`` is true.
    """

    # Output directory for test-time figures when ``args.figure_dir`` is not set.
    _DEFAULT_FIGURE_DIR = "/raid/st512/branchsbm/figures"

    def __init__(
        self,
        flow_matcher,
        args,
        skipped_time_points: list = None,
        ot_sampler=None,
        state_cost=None,
        data_manifold_metric=None,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.flow_matcher = flow_matcher
        # One interpolant network per branch.
        self.geopath_nets = flow_matcher.geopath_nets
        self.branches = len(self.geopath_nets)
        self.metric_clusters = args.metric_clusters
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points if skipped_time_points else []
        self.optimizer_name = args.geopath_optimizer
        self.lr = args.geopath_lr
        self.weight_decay = args.geopath_weight_decay
        # Validation and reference-loss batches are tiled this many times along dim 0.
        self.multiply_validation = 4
        # Reference loss used to normalise train/val losses; set in on_train_start.
        self.first_loss = None
        # Time grid; populated by compute_initial_loss.
        self.timesteps = None
        # True only while compute_initial_loss is running.
        self.computing_reference_loss = False
        self.state_cost = state_cost
        self.data_manifold_metric = data_manifold_metric
        self.whiten = args.whiten

    def forward(self, x0, x1, t, branch_idx):
        """Evaluate the interpolant network of branch ``branch_idx`` at ``(x0, x1, t)``."""
        return self.geopath_nets[branch_idx](x0, x1, t)

    def on_train_start(self):
        """Compute and store the reference loss used to normalise later losses."""
        self.first_loss = self.compute_initial_loss()
        print("first loss")
        print(self.first_loss)

    def compute_initial_loss(self):
        """Return the average training-set loss computed with ``flow_matcher.alpha = 0``.

        The method first draws the fixed evaluation times ``self.t_val``, which
        :meth:`_process_flow` reuses during validation. It then runs every
        training batch through :meth:`_compute_loss` under ``torch.no_grad()``,
        with the networks in eval mode. ``flow_matcher.alpha``,
        ``computing_reference_loss`` and the networks' training mode are restored
        even if the pass raises.

        Returns:
            float: Mean of the per-batch losses, or ``1.0`` when the training
            dataloader yields no batches.
        """
        datamodule = self.trainer.datamodule
        for net in self.geopath_nets:
            net.train(mode=False)
        total_loss = 0
        total_count = 0
        old_alpha = self.flow_matcher.alpha
        try:
            # One fixed vector of times per non-skipped time point.
            num_intervals = datamodule.num_timesteps - len(self.skipped_time_points)
            with torch.enable_grad():
                self.t_val = [
                    torch.rand(
                        datamodule.batch_size * self.multiply_validation,
                        requires_grad=True,
                    )
                    for _ in range(num_intervals)
                ]
            self.computing_reference_loss = True
            with torch.no_grad():
                self.flow_matcher.alpha = 0
                for batch in datamodule.train_dataloader():
                    # NOTE(review): the payload is taken from batch[0], presumably because
                    # iterating the dataloader directly wraps each batch; confirm.
                    main_batch = batch[0]["train_samples"][0]
                    # One grid point per entry of main_batch (len() counts its entries).
                    self.timesteps = torch.linspace(0.0, 1.0, len(main_batch))
                    loss = self._compute_loss(main_batch, batch[0]["metric_samples"][0])
                    print("initial loss")
                    print(loss)
                    total_loss += loss.item()
                    total_count += 1
        finally:
            self.flow_matcher.alpha = old_alpha
            self.computing_reference_loss = False
            for net in self.geopath_nets:
                net.train(mode=True)
        return total_loss / total_count if total_count > 0 else 1.0

    def _branch_sample_pairs(self, metric_samples_batch):
        """Return the ``(start, end)`` metric-sample clusters for each branch, indexed by branch."""
        m = metric_samples_batch
        if self.metric_clusters == 4:
            return [(m[0], m[1]), (m[0], m[2]), (m[0], m[3])]
        if self.metric_clusters == 3:
            return [(m[0], m[1]), (m[0], m[2])]
        if self.metric_clusters == 2 and self.branches == 2:
            # Only two clusters are available: both branches share the same pair.
            return [(m[0], m[1]), (m[0], m[1])]
        return [(m[0], m[1])]

    def _compute_loss(self, main_batch, metric_samples_batch=None):
        """Return the mean squared interpolant velocity pooled over all branches.

        Args:
            main_batch: Mapping with ``"x0"`` and either ``"x1_1"``, ``"x1_2"``, ...
                (one per branch, when there are several branches) or ``"x1"``.
                Element ``[0]`` of each entry is used as the sample tensor.
            metric_samples_batch: Indexable collection of metric-sample clusters.
                It is only read when ``args.manifold`` is true.

        Returns:
            torch.Tensor: ``mean(cat(velocities) ** 2)``, also logged as
            ``BranchPathNet/mean_velocity_geopath``.
        """
        x0s = [main_batch["x0"][0]]
        if self.branches > 1:
            x1s_list = [[main_batch[f"x1_{i + 1}"][0]] for i in range(self.branches)]
        else:
            x1s_list = [[main_batch["x1"][0]]]
        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        use_manifold = self.args.manifold
        branch_sample_pairs = (
            self._branch_sample_pairs(metric_samples_batch) if use_manifold else None
        )

        velocities = []
        for branch_idx in range(self.branches):
            _, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)
            if use_manifold:
                # The metric reference cloud does not depend on the time index; build it once.
                samples = torch.cat(branch_sample_pairs[branch_idx], dim=0)
            for i, (xt, ut) in enumerate(zip(xts, uts)):
                if use_manifold:
                    vel, _, _ = self.data_manifold_metric.calculate_velocity(
                        xt, ut, samples, i
                    )
                else:
                    vel = torch.sqrt((ut ** 2).sum(dim=-1) + self.state_cost(xt))
                velocities.append(vel)

        loss = torch.mean(torch.cat(velocities) ** 2)
        self.log(
            "BranchPathNet/mean_velocity_geopath",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample interpolant locations and velocities for one branch.

        Args:
            x0s: Sequence of source sample tensors.
            x1s: Sequence of target sample tensors, paired element-wise with ``x0s``.
            branch_idx: Index of the branch whose interpolant is sampled.

        Returns:
            tuple: ``(ts, xts, uts)``, three lists with one entry per ``(x0, x1)``
            pair, as returned by ``flow_matcher.sample_location_and_conditional_flow``.
        """
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]
        # Validation and the reference-loss pass tile the batch and use the fixed times
        # in self.t_val. NOTE(review): self.t_val is only created in compute_initial_loss,
        # so validation that runs before on_train_start would not find it; confirm.
        fixed_eval = self.trainer.validating or self.computing_reference_loss
        for i, (x0, x1) in enumerate(zip(x0s, x1s)):
            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
            if fixed_eval:
                repeat_tuple = (self.multiply_validation, 1) + (1,) * (len(x0.shape) - 2)
                x0 = x0.repeat(repeat_tuple)
                x1 = x1.repeat(repeat_tuple)
            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(x0, x1, replace=True)
            # Once the first skipped time point is reached, jump over it on the grid.
            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]
            t = self.t_val[i] if fixed_eval else None
            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx, training_geopath_net=True, t=t
            )
            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next
        return ts, xts, uts

    def _split_batch(self, batch, samples_key):
        """Return ``(main_batch, metric_batch)`` from a dataloader batch.

        For ``scrna`` and ``tahoe`` data the payload sits at ``batch[0]``. For all
        other data types it is ``batch`` itself.
        """
        if self.args.data_type in ["scrna", "tahoe"]:
            batch = batch[0]
        return batch[samples_key][0], batch["metric_samples"][0]

    def _normalize_loss(self, loss):
        """Divide ``loss`` by the reference loss when one is available (truthy)."""
        return loss / self.first_loss if self.first_loss else loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the normalised training loss for one batch."""
        main_batch, metric_batch = self._split_batch(batch, "train_samples")
        tangential_velocity_loss = self._normalize_loss(
            self._compute_loss(main_batch, metric_batch)
        )
        self.log(
            "BranchPathNet/mean_geopath_geopath",
            self.flow_matcher.geopath_net_output.abs().mean(),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "BranchPathNet/train_loss_geopath",
            tangential_velocity_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return tangential_velocity_loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the normalised validation loss for one batch."""
        main_batch, metric_batch = self._split_batch(batch, "val_samples")
        tangential_velocity_loss = self._normalize_loss(
            self._compute_loss(main_batch, metric_batch)
        )
        self.log(
            "BranchPathNet/val_loss_geopath",
            tangential_velocity_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return tangential_velocity_loss

    def test_step(self, batch, batch_idx):
        """Save one 3-D scatter plot per branch showing its interpolant at t = 1/4, 1/2, 3/4.

        A plot is saved for every branch ``i`` whose target ``x1_{i+1}`` is present
        in the test batch. Each plot shows:

        * the full point cloud ``main_batch["dataset"][0]``, drawn with ``plot_lidar``;
        * ``x0``;
        * the branch interpolant at three intermediate times;
        * the branch target ``x1``.

        When ``self.whiten`` is set, points are mapped back through
        ``datamodule.scaler.inverse_transform``. Figures are written to
        ``<figure_dir>/<data_type>/lidar_geopath_branch_<i+1>.png``, creating the
        directory if needed. ``figure_dir`` is ``args.figure_dir`` when set,
        otherwise ``_DEFAULT_FIGURE_DIR``.
        """
        import os

        main_batch = batch["test_samples"][0]
        x0 = main_batch["x0"][0].to(self.device)  # presumably [B, D]
        cloud_points = main_batch["dataset"][0].to(self.device)  # full dataset, presumably [N, D]
        t_vals = [0.25, 0.5, 0.75]
        t_labels = ["t=1/4", "t=1/2", "t=3/4"]
        colors = {
            "x0": "#4D176C",
            "t=1/4": "#5C3B9D",
            "t=1/2": "#6172B9",
            "t=3/4": "#AC4E51",
            "x1": "#771F4F",
        }
        scatter_kwargs = dict(
            s=15, alpha=1.0, depthshade=True, edgecolors="white", linewidths=0.3
        )

        def to_plot_space(points):
            # Move to CPU and undo whitening when enabled.
            points = points.cpu()
            if self.whiten:
                return torch.tensor(
                    self.trainer.datamodule.scaler.inverse_transform(points.numpy())
                )
            return points

        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(cloud_points.cpu().numpy())
            )
        x0_plot = to_plot_space(x0)

        figure_dir = getattr(self.args, "figure_dir", None) or self._DEFAULT_FIGURE_DIR
        out_dir = os.path.join(figure_dir, str(self.args.data_type))

        for i in range(self.branches):
            geopath = self.geopath_nets[i]
            x1_key = f"x1_{i + 1}"
            if x1_key not in main_batch:
                print(f"Skipping branch {i + 1}: no final distribution {x1_key}")
                continue
            x1 = main_batch[x1_key][0].to(self.device)
            print(x1.shape)
            print(x0.shape)

            with torch.no_grad():
                interpolated_points = [
                    to_plot_space(
                        geopath(
                            x0,
                            x1,
                            torch.full((x0.shape[0], 1), t_scalar, device=self.device),
                        )
                    )
                    for t_scalar in t_vals
                ]
            x1_plot = to_plot_space(x1)

            fig = plt.figure(figsize=(6, 5))
            try:
                ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
                ax.view_init(elev=30, azim=-115, roll=0)
                plot_lidar(ax, cloud_points)
                ax.scatter(
                    x0_plot[:, 0], x0_plot[:, 1], x0_plot[:, 2],
                    color=colors["x0"], label="x₀", **scatter_kwargs,
                )
                for xt, t_label in zip(interpolated_points, t_labels):
                    ax.scatter(
                        xt[:, 0], xt[:, 1], xt[:, 2],
                        color=colors[t_label], label=t_label, **scatter_kwargs,
                    )
                ax.scatter(
                    x1_plot[:, 0], x1_plot[:, 1], x1_plot[:, 2],
                    color=colors["x1"], label="x₁", **scatter_kwargs,
                )
                ax.legend()
                os.makedirs(out_dir, exist_ok=True)
                save_path = os.path.join(out_dir, f"lidar_geopath_branch_{i + 1}.png")
                fig.savefig(save_path, dpi=300)
            finally:
                # Always release the figure, even if drawing or saving fails.
                plt.close(fig)

    def optimizer_step(self, *args, **kwargs):
        """Run the optimizer step, then refresh the EMA weights of every EMA-wrapped net.

        ``self.geopath_nets`` is a per-branch collection, so each element is
        checked for ``EMA`` individually. A single ``EMA`` object stored directly
        in ``geopath_nets`` is also handled.
        """
        super().optimizer_step(*args, **kwargs)
        if isinstance(self.geopath_nets, EMA):
            nets = [self.geopath_nets]
        else:
            nets = self.geopath_nets
        for net in nets:
            if isinstance(net, EMA):
                net.update_ema()

    def configure_optimizers(self):
        """Build the optimizer over the parameters of all branch networks.

        Returns:
            torch.optim.Optimizer: ``Adam`` with ``lr`` for ``"adam"``, or ``AdamW``
            with ``lr`` and ``weight_decay`` (AdamW's default when the configured
            value is ``None``) for ``"adamw"``.

        Raises:
            ValueError: If ``args.geopath_optimizer`` is neither ``"adam"`` nor ``"adamw"``.
        """
        params = itertools.chain.from_iterable(net.parameters() for net in self.geopath_nets)
        if self.optimizer_name == "adam":
            return torch.optim.Adam(params, lr=self.lr)
        if self.optimizer_name == "adamw":
            extra = {} if self.weight_decay is None else {"weight_decay": self.weight_decay}
            return torch.optim.AdamW(params, lr=self.lr, **extra)
        raise ValueError(
            f"Unsupported geopath_optimizer {self.optimizer_name!r}; expected 'adam' or 'adamw'."
        )