|
|
import os |
|
|
import sys |
|
|
sys.path.append("./BranchSBM") |
|
|
import torch |
|
|
import wandb |
|
|
import matplotlib.pyplot as plt |
|
|
import pytorch_lightning as pl |
|
|
from torch.optim import AdamW |
|
|
from torchmetrics.functional import mean_squared_error |
|
|
from torchdyn.core import NeuralODE |
|
|
from networks.utils import flow_model_torch_wrapper |
|
|
from utils import wasserstein_distance, plot_lidar |
|
|
from branchsbm.ema import EMA |
|
|
|
|
|
class BranchFlowNetTrainBase(pl.LightningModule):
    """Lightning module that trains one velocity network per branch by flow matching.

    The loss is the sum, over branches, of the MSE between each branch network's
    output ``vt`` and the target ``ut`` returned by
    ``flow_matcher.sample_location_and_conditional_flow``.

    Args:
        flow_matcher: object exposing
            ``sample_location_and_conditional_flow(x0, x1, t_start, t_end, branch_idx)``
            (returning ``(t, xt, ut)``) and an ``alpha`` attribute; when
            ``alpha != 0`` it must also expose ``geopath_net_output``.
        flow_nets: indexable collection of networks called as ``net(t, x)``;
            its length is the number of branches.
        skipped_time_points: optional sequence of held-out time-point indices;
            only its first element is consulted.
        ot_sampler: optional object whose ``sample_plan(x0, x1, replace=True)``
            re-pairs source/target samples before flow matching.
        args: configuration namespace. Despite the ``None`` default it is
            required: ``flow_optimizer``, ``flow_lr``, ``flow_weight_decay``,
            ``whiten``, ``working_dir`` and ``data_type`` are read from it.
    """

    def __init__(
        self,
        flow_matcher,
        flow_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
    ):
        super().__init__()
        self.args = args

        self.flow_matcher = flow_matcher
        self.flow_nets = flow_nets
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        self.optimizer_name = args.flow_optimizer
        self.lr = args.flow_lr
        self.weight_decay = args.flow_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        self.branches = len(flow_nets)

    def forward(self, t, xt, branch_idx):
        """Evaluate the network of branch ``branch_idx`` at time ``t`` and state ``xt``."""
        return self.flow_nets[branch_idx](t, xt)

    def _compute_loss(self, main_batch):
        """Return the flow-matching MSE summed over all branches.

        Each batch entry is indexed with ``[0]`` to obtain the samples. With more
        than one branch, the target samples of branch ``i`` are read from key
        ``f"x1_{i + 1}"``; with a single branch, from ``"x1"``.

        Raises:
            ValueError: if the number of target sets does not match the number
                of branches (e.g. ``flow_nets`` is empty).
        """
        x0s = [main_batch["x0"][0]]

        if self.branches > 1:
            target_keys = [f"x1_{i + 1}" for i in range(self.branches)]
        else:
            target_keys = ["x1"]
        x1s_list = [[main_batch[key][0]] for key in target_keys]

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(x1s_list) != self.branches:
            raise ValueError("Mismatch between x1s_list and expected branches")

        loss = 0
        for branch_idx, x1s in enumerate(x1s_list):
            ts, xts, uts = self._process_flow(x0s, x1s, branch_idx)

            t = torch.cat(ts)
            xt = torch.cat(xts)
            ut = torch.cat(uts)
            vt = self(t[:, None], xt, branch_idx)

            loss += mean_squared_error(vt, ut)

        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample flow-matching locations/targets for each consecutive (x0, x1) pair.

        Pair ``i`` spans ``self.timesteps[i] -> self.timesteps[i + 1]``, or
        ``self.timesteps[i + 2]`` once ``i + 1`` reaches the first skipped time
        point (bridging over the held-out time). Requires ``self.timesteps``
        to have been set by ``training_step``/``validation_step``.

        Returns:
            Three lists ``(ts, xts, uts)`` with one entry per pair, as produced by
            ``flow_matcher.sample_location_and_conditional_flow``.
        """
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]

        for i, (x0, x1) in enumerate(zip(x0s, x1s)):
            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)

            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(
                    x0,
                    x1,
                    replace=True,
                )

            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]

            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx
            )

            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next

        return ts, xts, uts

    def _main_batch(self, batch, split_key):
        """Return the first sample dict stored under ``split_key``.

        For ``data_type`` "scrna"/"tahoe" the split dict is read from
        ``batch[0]``; otherwise directly from ``batch``.
        """
        if self.args.data_type in ["scrna", "tahoe"]:
            batch = batch[0]
        return batch[split_key][0]

    @staticmethod
    def _timesteps_from_batch(main_batch):
        """Return ``len(main_batch["x0"])`` evenly spaced times in [0, 1] as a list."""
        return torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()

    def training_step(self, batch, batch_idx):
        """Compute and log the training flow-matching loss for one batch."""
        main_batch = self._main_batch(batch, "train_samples")

        self.timesteps = self._timesteps_from_batch(main_batch)
        loss = self._compute_loss(main_batch)

        if self.flow_matcher.alpha != 0:
            self.log(
                "FlowNet/mean_geopath_cfm",
                (self.flow_matcher.geopath_net_output.abs().mean()),
                on_step=False,
                on_epoch=True,
                prog_bar=True,
            )

        self.log(
            "FlowNet/train_loss_cfm",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation flow-matching loss for one batch."""
        main_batch = self._main_batch(batch, "val_samples")

        self.timesteps = self._timesteps_from_batch(main_batch)
        val_loss = self._compute_loss(main_batch)
        self.log(
            "FlowNet/val_loss_cfm",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return val_loss

    def optimizer_step(self, *args, **kwargs):
        """Run the standard optimizer step, then refresh EMA weights of EMA-wrapped nets."""
        super().optimizer_step(*args, **kwargs)

        for net in self.flow_nets:
            if isinstance(net, EMA):
                net.update_ema()

    def configure_optimizers(self):
        """Build the optimizer named by ``args.flow_optimizer`` over all parameters.

        Supported names are ``"adamw"`` (learning rate and weight decay) and
        ``"adam"`` (learning rate only).

        Raises:
            ValueError: for any other optimizer name.
        """
        if self.optimizer_name == "adamw":
            return AdamW(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.weight_decay,
            )
        if self.optimizer_name == "adam":
            return torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
            )
        # Without this, an unknown name left ``optimizer`` unbound (UnboundLocalError).
        raise ValueError(
            f"Unsupported flow optimizer {self.optimizer_name!r}; expected 'adamw' or 'adam'."
        )
|
|
|
|
|
|
|
|
class FlowNetTrainTrajectory(BranchFlowNetTrainBase):
    """Evaluates interpolation to a held-out time point and logs the EMD."""

    def test_step(self, batch, batch_idx):
        """Integrate to the first skipped time point and log ``test_EMD``.

        ``batch[k]`` is used as the sample set of time point ``k``. Integration
        starts from ``batch[t_exclude - 1]`` at ``self.timesteps[t_exclude - 1]``
        and ends at ``self.timesteps[t_exclude]``. The prediction is compared
        with ``batch[t_exclude]`` via a 1-Wasserstein distance. ``self.timesteps``
        must already have been set by ``training_step``/``validation_step``.

        Returns ``None`` without logging when no time point was skipped, since
        there is then no held-out target to evaluate against.
        """
        if not self.skipped_time_points:
            # Previously this fell through to an unassigned ``X_mid_pred`` (NameError).
            return None
        t_exclude = self.skipped_time_points[0]

        # NOTE(review): the whole ``self.flow_nets`` collection is wrapped here,
        # whereas the other test steps wrap one net per branch -- confirm that
        # ``flow_model_torch_wrapper`` supports this.
        node = NeuralODE(
            flow_model_torch_wrapper(self.flow_nets),
            solver="euler",
            sensitivity="adjoint",
            atol=1e-5,
            rtol=1e-5,
        )

        t_span = torch.linspace(
            self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101
        )
        X_mid_pred = node.trajectory(batch[t_exclude - 1], t_span=t_span)[-1]

        EMD = wasserstein_distance(X_mid_pred, batch[t_exclude], p=1)
        self.final_EMD = EMD

        self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True)
|
|
|
|
|
class FlowNetTrainCell(BranchFlowNetTrainBase):
    """Integrates every branch from the test ``x0`` and saves 2-D trajectory plots.

    Figures are written to ``./figures/<args.data_name>/``. There is one combined
    plot, ``<data_name>_all_branches.png``, and one plot per branch,
    ``<data_name>_branch_<k>.png``. Only the first two feature dimensions are
    plotted.
    """

    def _unwhiten(self, points):
        """Apply ``datamodule.scaler.inverse_transform`` to a tensor; returns its output."""
        return self.trainer.datamodule.scaler.inverse_transform(
            points.cpu().detach().numpy()
        )

    def _integrate_branches(self, x0, t_span):
        """Integrate ``x0`` over ``t_span`` with each branch network (Euler solver).

        Returns one CPU tensor per branch with the first two axes of
        ``node.trajectory``'s output swapped. This presumably gives
        (sample, time, dim) -- confirm against torchdyn's output layout.
        """
        trajs = []
        for flow_net in self.flow_nets:
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )
            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()

            if self.whiten:
                traj_shape = traj.shape
                traj = self._unwhiten(traj.reshape(-1, traj.shape[-1])).reshape(traj_shape)

            # as_tensor avoids copy-constructing (and warning) when traj is already a tensor.
            trajs.append(torch.transpose(torch.as_tensor(traj), 0, 1))
        return trajs

    @staticmethod
    def _new_axes(dataset_2d):
        """Create a figure whose axes show the dataset as a gray background scatter."""
        fig, ax = plt.subplots(figsize=(6, 5))
        ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
        return fig, ax

    @staticmethod
    def _plot_trajectories(ax, traj_2d, marker_size, alpha, with_labels):
        """Draw each trajectory line with green start and red end markers.

        Only the first trajectory gets "t=0"/"t=1" legend labels, and only when
        ``with_labels`` is true.
        """
        for j in range(traj_2d.shape[0]):
            labelled = with_labels and j == 0
            ax.plot(traj_2d[j, :, 0], traj_2d[j, :, 1], alpha=alpha, zorder=2)
            ax.scatter(traj_2d[j, 0, 0], traj_2d[j, 0, 1], c='green', s=marker_size,
                       label="t=0" if labelled else "", zorder=3)
            ax.scatter(traj_2d[j, -1, 0], traj_2d[j, -1, 1], c='red', s=marker_size,
                       label="t=1" if labelled else "", zorder=3)

    @staticmethod
    def _finish_and_save(fig, ax, title, path):
        """Title/label the axes, add a legend if any labels exist, save at 300 dpi, close."""
        ax.set_title(title)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.axis("equal")
        _, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend()
        fig.savefig(path, dpi=300)
        plt.close(fig)

    def test_step(self, batch, batch_idx):
        """Integrate all branches from the test ``x0`` and save the trajectory figures."""
        test_sample = batch[0]["test_samples"][0]
        x0 = test_sample["x0"][0]
        dataset_points = test_sample["dataset"][0]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = self._integrate_branches(x0, t_span)

        # Un-whiten the reference cloud exactly once. The previous per-branch
        # transform compounded with several branches and produced a NumPy
        # array that then crashed on ``.cpu()``.
        if self.whiten:
            dataset_2d = self._unwhiten(dataset_points)[:, :2]
        else:
            dataset_2d = dataset_points[:, :2].detach().cpu().numpy()

        save_path = f'./figures/{self.args.data_name}'
        os.makedirs(save_path, exist_ok=True)

        # Combined plot: only the first branch contributes t=0/t=1 legend entries.
        fig, ax = self._new_axes(dataset_2d)
        for branch_idx, traj in enumerate(all_trajs):
            self._plot_trajectories(ax, traj[..., :2], marker_size=10, alpha=0.8,
                                    with_labels=branch_idx == 0)
        self._finish_and_save(fig, ax, "All Branch Trajectories (2D) with Dataset",
                              f'{save_path}/{self.args.data_name}_all_branches.png')

        # One plot per branch.
        for branch_idx, traj in enumerate(all_trajs):
            fig, ax = self._new_axes(dataset_2d)
            self._plot_trajectories(ax, traj[..., :2], marker_size=12, alpha=0.9,
                                    with_labels=True)
            self._finish_and_save(fig, ax, f"Branch {branch_idx + 1} Trajectories (2D) with Dataset",
                                  f'{save_path}/{self.args.data_name}_branch_{branch_idx + 1}.png')
|
|
|
|
|
class FlowNetTrainLidar(BranchFlowNetTrainBase):
    """Integrates every branch from the test ``x0`` and saves 3-D LiDAR plots.

    Figures are written to ``./figures/lidar/``. There is one combined plot,
    ``lidar_all_branches.png``, and one plot per branch,
    ``lidar_branch_<k>.png``.
    """

    def test_step(self, batch, batch_idx):
        """Integrate all branches from the test ``x0`` and save the LiDAR figures."""
        main_batch = batch["test_samples"][0]

        x0 = main_batch["x0"][0]
        cloud_points = main_batch["dataset"][0]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []
        for flow_net in self.flow_nets:
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()

            if self.whiten:
                traj_shape = traj.shape
                # Flatten to (points, features) using the actual feature width
                # instead of a hard-coded 3.
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.reshape(-1, traj_shape[-1]).cpu().detach().numpy()
                ).reshape(traj_shape)

            # as_tensor avoids copy-constructing (and warning) when traj is already a tensor.
            # Swapping axes 0/1 presumably yields (sample, time, dim) -- confirm vs torchdyn.
            all_trajs.append(torch.transpose(torch.as_tensor(traj), 0, 1))

        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # The figure directory was previously assumed to exist (FileNotFoundError otherwise).
        save_dir = './figures/lidar'
        os.makedirs(save_dir, exist_ok=True)

        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for branch_idx, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=branch_idx)
        plt.savefig(f'{save_dir}/lidar_all_branches.png', dpi=300)
        plt.close()

        for branch_idx, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=branch_idx)
            plt.savefig(f'{save_dir}/lidar_branch_{branch_idx + 1}.png', dpi=300)
            plt.close()