|
|
import sys |
|
|
sys.path.append("./BranchSBM") |
|
|
import torch |
|
|
import pytorch_lightning as pl |
|
|
from branchsbm.ema import EMA |
|
|
import itertools |
|
|
from utils import wasserstein_distance, plot_lidar |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
class BranchInterpolantTrain(pl.LightningModule):
    """Lightning module that trains the per-branch interpolant ("geopath") networks.

    One network per branch is taken from ``flow_matcher.geopath_nets``. The
    training objective is ``mean(v ** 2)`` over velocities ``v`` evaluated at
    points sampled along the interpolated paths of every branch. After
    ``compute_initial_loss`` has run, train/val losses are divided by that
    reference value.
    """

    def __init__(
        self,
        flow_matcher,
        args,
        skipped_time_points: list = None,
        ot_sampler=None,
        state_cost=None,
        data_manifold_metric=None,
    ):
        """
        Args:
            flow_matcher: Object exposing ``geopath_nets`` (one net per branch),
                ``alpha``, ``geopath_net_output`` and
                ``sample_location_and_conditional_flow``.
            args: Config namespace. Read here: ``metric_clusters``,
                ``geopath_optimizer``, ``geopath_lr``, ``geopath_weight_decay``,
                ``whiten``. Read by other methods: ``manifold``, ``data_type``
                and (optionally) ``figures_dir``.
            skipped_time_points: Time-point indices to skip; ``None`` means none.
            ot_sampler: Optional object whose ``sample_plan`` re-pairs x0 / x1.
            state_cost: Callable added to the squared speed when
                ``args.manifold`` is false.
            data_manifold_metric: Object providing ``calculate_velocity``; used
                when ``args.manifold`` is true.
        """
        super().__init__()
        self.save_hyperparameters()
        self.args = args

        self.flow_matcher = flow_matcher
        self.geopath_nets = flow_matcher.geopath_nets
        self.branches = len(self.geopath_nets)
        self.metric_clusters = args.metric_clusters

        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points if skipped_time_points else []

        self.optimizer_name = args.geopath_optimizer
        self.lr = args.geopath_lr
        self.weight_decay = args.geopath_weight_decay
        # Validation / reference-loss inputs are tiled this many times along dim 0.
        self.multiply_validation = 4

        # Reference loss from compute_initial_loss(); divides train/val losses.
        self.first_loss = None
        # Time grid; assigned in compute_initial_loss() and read by _process_flow().
        self.timesteps = None
        # True only while compute_initial_loss() is running.
        self.computing_reference_loss = False

        self.state_cost = state_cost
        self.data_manifold_metric = data_manifold_metric
        self.whiten = args.whiten

    def forward(self, x0, x1, t, branch_idx):
        """Evaluate the interpolant network of branch ``branch_idx`` at time ``t``."""
        return self.geopath_nets[branch_idx](x0, x1, t)

    def on_train_start(self):
        """Compute and print the reference loss used to normalize later losses."""
        self.first_loss = self.compute_initial_loss()
        print("first loss")
        print(self.first_loss)

    def compute_initial_loss(self):
        """Return the average path loss over the training loader with ``alpha = 0``.

        The pass runs with:

        * the nets in eval mode;
        * ``flow_matcher.alpha`` set to 0;
        * ``computing_reference_loss`` set, so ``_process_flow`` tiles its inputs
          and uses the fixed times in ``self.t_val``.

        All of that state is restored in a ``finally`` block, so an exception
        cannot leave the module in reference-loss mode.

        Side effects: sets ``self.t_val`` and ``self.timesteps``.

        Returns:
            Mean of the per-batch losses, or 1.0 if the loader yields no batches.
        """
        for net in self.geopath_nets:
            net.train(mode=False)

        datamodule = self.trainer.datamodule
        total_loss = 0
        total_count = 0
        old_alpha = self.flow_matcher.alpha
        try:
            with torch.enable_grad():
                # Fixed random times, indexed by segment in _process_flow.
                num_time_sets = datamodule.num_timesteps - len(self.skipped_time_points)
                self.t_val = [
                    torch.rand(
                        datamodule.batch_size * self.multiply_validation,
                        requires_grad=True,
                    )
                    for _ in range(num_time_sets)
                ]

                self.computing_reference_loss = True
                with torch.no_grad():
                    self.flow_matcher.alpha = 0
                    for batch in datamodule.train_dataloader():
                        # NOTE(review): this always indexes batch[0], whereas
                        # training_step only does so for "scrna"/"tahoe" — confirm
                        # this matches what iterating the loader directly yields.
                        main_batch = batch[0]["train_samples"][0]
                        metric_batch = batch[0]["metric_samples"][0]

                        # NOTE(review): the grid size is the number of entries in
                        # main_batch; presumably one per time point — confirm.
                        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch))

                        loss = self._compute_loss(main_batch, metric_batch)
                        print("initial loss")
                        print(loss)
                        total_loss += loss.item()
                        total_count += 1
        finally:
            self.flow_matcher.alpha = old_alpha
            self.computing_reference_loss = False
            for net in self.geopath_nets:
                net.train(mode=True)

        return total_loss / total_count if total_count > 0 else 1.0

    def _metric_sample_pairs(self, metric_samples_batch):
        """Return the (start_samples, end_samples) pairs, indexed by branch."""
        first = metric_samples_batch[0]
        if self.metric_clusters in (3, 4):
            # Cluster 0 is paired with each of the remaining clusters in turn.
            return [
                (first, metric_samples_batch[k])
                for k in range(1, self.metric_clusters)
            ]
        if self.metric_clusters == 2 and self.branches == 2:
            # Both branches share the same pair of clusters.
            return [(first, metric_samples_batch[1])] * 2
        return [(first, metric_samples_batch[1])]

    def _compute_loss(self, main_batch, metric_samples_batch=None):
        """Return ``mean(v ** 2)`` over all path velocities of all branches.

        Args:
            main_batch: Mapping with an ``"x0"`` entry and either ``"x1"`` (one
                branch) or ``"x1_1"``, ``"x1_2"``, ... (one per branch). The
                samples are read from ``entry[0]``.
            metric_samples_batch: Indexable collection of sample clusters. It is
                only used when ``args.manifold`` is true.

        Returns:
            Scalar tensor, also logged as ``BranchPathNet/mean_velocity_geopath``.
        """
        x0s = [main_batch["x0"][0]]
        if self.branches > 1:
            x1s_list = [[main_batch[f"x1_{i + 1}"][0]] for i in range(self.branches)]
        else:
            x1s_list = [[main_batch["x1"][0]]]

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        if self.args.manifold:
            branch_sample_pairs = self._metric_sample_pairs(metric_samples_batch)

        velocities = []
        for branch_idx in range(self.branches):
            _, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)

            samples = None
            if self.args.manifold:
                # Loop-invariant for this branch: all metric samples for its pair.
                start_samples, end_samples = branch_sample_pairs[branch_idx]
                samples = torch.cat([start_samples, end_samples], dim=0)

            for i, (xt, ut) in enumerate(zip(xts, uts)):
                if self.args.manifold:
                    vel, _, _ = self.data_manifold_metric.calculate_velocity(
                        xt, ut, samples, i
                    )
                else:
                    # Euclidean speed augmented by the state cost.
                    vel = torch.sqrt((ut ** 2).sum(dim=-1) + self.state_cost(xt))
                velocities.append(vel)

        loss = torch.mean(torch.cat(velocities) ** 2)

        self.log(
            "BranchPathNet/mean_velocity_geopath",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample times, interpolated points and conditional velocities for a branch.

        Each ``(x0, x1)`` pair from ``zip(x0s, x1s)`` is one segment.

        * Segment ``i`` starts at the previous segment's end time (initially
          ``self.timesteps[0]``).
        * It ends at ``self.timesteps[i + 1]``, or at ``self.timesteps[i + 2]``
          once ``i + 1 >= self.skipped_time_points[0]``.
        * In validation, or while computing the reference loss, both ends are
          tiled ``self.multiply_validation`` times and ``self.t_val[i]`` is
          passed as ``t``. Otherwise ``t=None`` is passed.

        NOTE(review): ``self.timesteps`` / ``self.t_val`` are only set by
        compute_initial_loss (called from on_train_start); validation before that
        (e.g. a sanity check) would find them unset — confirm trainer config.

        Returns:
            ``(ts, xts, uts)`` lists with one entry per segment.
        """
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]
        use_fixed_times = self.trainer.validating or self.computing_reference_loss

        for i, (x0, x1) in enumerate(zip(x0s, x1s)):
            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
            if use_fixed_times:
                repeat_tuple = (self.multiply_validation, 1) + (1,) * (len(x0.shape) - 2)
                x0 = x0.repeat(repeat_tuple)
                x1 = x1.repeat(repeat_tuple)

            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(x0, x1, replace=True)

            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]

            t = self.t_val[i] if use_fixed_times else None
            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx, training_geopath_net=True, t=t
            )
            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next

        return ts, xts, uts

    def _unpack_batch(self, batch, samples_key):
        """Return ``(main_batch, metric_batch)`` for ``samples_key``.

        For ``data_type`` "scrna"/"tahoe" the mappings are read through
        ``batch[0]``; otherwise directly from ``batch``.
        """
        if self.args.data_type in ["scrna", "tahoe"]:
            batch = batch[0]
        return batch[samples_key][0], batch["metric_samples"][0]

    def training_step(self, batch, batch_idx):
        """Compute, normalize and log the training path loss."""
        main_batch, metric_batch = self._unpack_batch(batch, "train_samples")

        tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
        if self.first_loss:
            tangential_velocity_loss = tangential_velocity_loss / self.first_loss

        self.log(
            "BranchPathNet/mean_geopath_geopath",
            (self.flow_matcher.geopath_net_output.abs().mean()),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        self.log(
            "BranchPathNet/train_loss_geopath",
            tangential_velocity_loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return tangential_velocity_loss

    def validation_step(self, batch, batch_idx):
        """Compute, normalize and log the validation path loss."""
        main_batch, metric_batch = self._unpack_batch(batch, "val_samples")

        tangential_velocity_loss = self._compute_loss(main_batch, metric_batch)
        if self.first_loss:
            tangential_velocity_loss = tangential_velocity_loss / self.first_loss

        self.log(
            "BranchPathNet/val_loss_geopath",
            tangential_velocity_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return tangential_velocity_loss

    def _inverse_whiten(self, x):
        """Apply the datamodule scaler's ``inverse_transform`` to ``x``.

        Returns the result as a CPU tensor.
        """
        return torch.tensor(
            self.trainer.datamodule.scaler.inverse_transform(x.detach().cpu().numpy())
        )

    def test_step(self, batch, batch_idx):
        """Plot each branch's interpolant at t = 1/4, 1/2, 3/4 over the point cloud.

        For every branch ``k`` whose ``"x1_<k>"`` entry is in the test batch,
        a 3-D scatter is drawn on top of ``plot_lidar``. It shows x0, the three
        interpolated point sets and x1, using only the first three coordinates.

        The plot is saved to
        ``<figures_dir>/<args.data_type>/lidar_geopath_branch_<k>.png``.
        ``figures_dir`` comes from ``args.figures_dir`` when set; otherwise the
        original hard-coded location is used. Missing directories are created.
        """
        import os

        main_batch = batch["test_samples"][0]
        x0 = main_batch["x0"][0].to(self.device)
        cloud_points = main_batch["dataset"][0].to(self.device)

        t_vals = [0.25, 0.5, 0.75]
        t_labels = ["t=1/4", "t=1/2", "t=3/4"]
        colors = {
            "x0": "#4D176C",
            "t=1/4": "#5C3B9D",
            "t=1/2": "#6172B9",
            "t=3/4": "#AC4E51",
            "x1": "#771F4F",
        }

        figures_dir = getattr(self.args, "figures_dir", None) or "/raid/st512/branchsbm/figures"
        out_dir = os.path.join(figures_dir, f"{self.args.data_type}")

        if self.whiten:
            cloud_points = self._inverse_whiten(cloud_points)
        # x0 is the same for every branch, so its plot copy is built once.
        x0_plot = self._inverse_whiten(x0) if self.whiten else x0.cpu()

        for i, geopath in enumerate(self.geopath_nets):
            x1_key = f"x1_{i + 1}"
            if x1_key not in main_batch:
                print(f"Skipping branch {i + 1}: no final distribution {x1_key}")
                continue

            x1 = main_batch[x1_key][0].to(self.device)
            print(x1.shape)
            print(x0.shape)

            interpolated_points = []
            with torch.no_grad():
                for t_scalar in t_vals:
                    t_tensor = torch.full((x0.shape[0], 1), t_scalar, device=self.device)
                    xt = geopath(x0, x1, t_tensor).cpu()
                    if self.whiten:
                        xt = self._inverse_whiten(xt)
                    interpolated_points.append(xt)

            x1_plot = self._inverse_whiten(x1) if self.whiten else x1.cpu()

            # Drawing order matters for the legend: x0, t=1/4, t=1/2, t=3/4, x1.
            layers = [(x0_plot, colors["x0"], "x₀")]
            layers += [
                (xt, colors[t_label], t_label)
                for xt, t_label in zip(interpolated_points, t_labels)
            ]
            layers.append((x1_plot, colors["x1"], "x₁"))

            fig = plt.figure(figsize=(6, 5))
            try:
                ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
                ax.view_init(elev=30, azim=-115, roll=0)
                plot_lidar(ax, cloud_points)

                for points, color, label in layers:
                    ax.scatter(
                        points[:, 0], points[:, 1], points[:, 2],
                        s=15, alpha=1.0, color=color, label=label, depthshade=True,
                        edgecolors="white",
                        linewidths=0.3,
                    )

                ax.legend()
                os.makedirs(out_dir, exist_ok=True)
                fig.savefig(
                    os.path.join(out_dir, f"lidar_geopath_branch_{i + 1}.png"), dpi=300
                )
            finally:
                # Always release the figure, even if plotting or saving fails.
                plt.close(fig)

    def optimizer_step(self, *args, **kwargs):
        """Run the optimizer step, then update any EMA-wrapped geopath networks.

        Handles both a single EMA wrapper around ``geopath_nets`` and individual
        EMA-wrapped nets inside it.
        """
        super().optimizer_step(*args, **kwargs)
        if isinstance(self.geopath_nets, EMA):
            self.geopath_nets.update_ema()
            return
        for net in self.geopath_nets:
            if isinstance(net, EMA):
                net.update_ema()

    def configure_optimizers(self):
        """Build the optimizer over the parameters of every branch network.

        ``args.geopath_optimizer`` selects ``"adam"`` or ``"adamw"``. AdamW
        receives ``args.geopath_weight_decay``; if that is ``None``, PyTorch's
        default weight decay is used.

        Raises:
            ValueError: If the optimizer name is neither ``"adam"`` nor ``"adamw"``.
        """
        params = itertools.chain.from_iterable(
            net.parameters() for net in self.geopath_nets
        )
        if self.optimizer_name == "adam":
            return torch.optim.Adam(params, lr=self.lr)
        if self.optimizer_name == "adamw":
            extra = {} if self.weight_decay is None else {"weight_decay": self.weight_decay}
            return torch.optim.AdamW(params, lr=self.lr, **extra)
        raise ValueError(
            f"Unsupported geopath optimizer {self.optimizer_name!r}; "
            "expected 'adam' or 'adamw'."
        )
|
|
|