BranchSBM / branchsbm /branch_flow_net_train.py
sophiat44
model upload
5a87d8d
import os
import sys
sys.path.append("./BranchSBM")
import torch
import wandb
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.optim import AdamW
from torchmetrics.functional import mean_squared_error
from torchdyn.core import NeuralODE
from networks.utils import flow_model_torch_wrapper
from utils import wasserstein_distance, plot_lidar
from branchsbm.ema import EMA
class BranchFlowNetTrainBase(pl.LightningModule):
def __init__(
self,
flow_matcher,
flow_nets,
skipped_time_points=None,
ot_sampler=None,
args=None,
):
super().__init__()
self.args = args
self.flow_matcher = flow_matcher
self.flow_nets = flow_nets # list of flow networks for each branch
self.ot_sampler = ot_sampler
self.skipped_time_points = skipped_time_points
self.optimizer_name = args.flow_optimizer
self.lr = args.flow_lr
self.weight_decay = args.flow_weight_decay
self.whiten = args.whiten
self.working_dir = args.working_dir
#branching
self.branches = len(flow_nets)
def forward(self, t, xt, branch_idx):
# output velocity given branch_idx
return self.flow_nets[branch_idx](t, xt)
def _compute_loss(self, main_batch):
x0s = [main_batch["x0"][0]]
w0s = [main_batch["x0"][1]]
x1s_list = []
w1s_list = []
if self.branches > 1:
for i in range(self.branches):
x1s_list.append([main_batch[f"x1_{i+1}"][0]])
w1s_list.append([main_batch[f"x1_{i+1}"][1]])
else:
x1s_list.append([main_batch["x1"][0]])
w1s_list.append([main_batch["x1"][1]])
assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"
loss = 0
for branch_idx in range(self.branches):
ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)
t = torch.cat(ts)
xt = torch.cat(xts)
ut = torch.cat(uts)
vt = self(t[:, None], xt, branch_idx)
loss += mean_squared_error(vt, ut)
return loss
def _process_flow(self, x0s, x1s, branch_idx):
ts, xts, uts = [], [], []
t_start = self.timesteps[0]
for i, (x0, x1) in enumerate(zip(x0s, x1s)):
x0, x1 = torch.squeeze(x0), torch.squeeze(x1)
if self.ot_sampler is not None:
x0, x1 = self.ot_sampler.sample_plan(
x0,
x1,
replace=True,
)
if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
t_start_next = self.timesteps[i + 2]
else:
t_start_next = self.timesteps[i + 1]
# edit to sample from correct flow matcher
t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
x0, x1, t_start, t_start_next, branch_idx
)
ts.append(t)
xts.append(xt)
uts.append(ut)
t_start = t_start_next
return ts, xts, uts
def training_step(self, batch, batch_idx):
if self.args.data_type in ["scrna", "tahoe"]:
main_batch = batch[0]["train_samples"][0]
else:
main_batch = batch["train_samples"][0]
print("Main batch length")
print(len(main_batch["x0"]))
self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
loss = self._compute_loss(main_batch)
if self.flow_matcher.alpha != 0:
self.log(
"FlowNet/mean_geopath_cfm",
(self.flow_matcher.geopath_net_output.abs().mean()),
on_step=False,
on_epoch=True,
prog_bar=True,
)
self.log(
"FlowNet/train_loss_cfm",
loss,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
def validation_step(self, batch, batch_idx):
if self.args.data_type in ["scrna", "tahoe"]:
main_batch = batch[0]["val_samples"][0]
else:
main_batch = batch["val_samples"][0]
self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
val_loss = self._compute_loss(main_batch)
self.log(
"FlowNet/val_loss_cfm",
val_loss,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return val_loss
def optimizer_step(self, *args, **kwargs):
super().optimizer_step(*args, **kwargs)
for net in self.flow_nets:
if isinstance(net, EMA):
net.update_ema()
def configure_optimizers(self):
if self.optimizer_name == "adamw":
optimizer = AdamW(
self.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
elif self.optimizer_name == "adam":
optimizer = torch.optim.Adam(
self.parameters(),
lr=self.lr,
)
return optimizer
class FlowNetTrainTrajectory(BranchFlowNetTrainBase):
def test_step(self, batch, batch_idx):
data_type = self.args.data_type
node = NeuralODE(
flow_model_torch_wrapper(self.flow_nets),
solver="euler",
sensitivity="adjoint",
atol=1e-5,
rtol=1e-5,
)
t_exclude = self.skipped_time_points[0] if self.skipped_time_points else None
if t_exclude is not None:
traj = node.trajectory(
batch[t_exclude - 1],
t_span=torch.linspace(
self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101
),
)
X_mid_pred = traj[-1]
traj = node.trajectory(
batch[t_exclude - 1],
t_span=torch.linspace(
self.timesteps[t_exclude - 1],
self.timesteps[t_exclude + 1],
101,
),
)
EMD = wasserstein_distance(X_mid_pred, batch[t_exclude], p=1)
self.final_EMD = EMD
self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True)
class FlowNetTrainCell(BranchFlowNetTrainBase):
def test_step(self, batch, batch_idx):
x0 = batch[0]["test_samples"][0]["x0"][0] # [B, D]
dataset_points = batch[0]["test_samples"][0]["dataset"][0] # full dataset, [N, D]
t_span = torch.linspace(0, 1, 101)
all_trajs = []
for i, flow_net in enumerate(self.flow_nets):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
)
with torch.no_grad():
traj = node.trajectory(x0, t_span).cpu() # [T, B, D]
if self.whiten:
traj_shape = traj.shape
traj = traj.reshape(-1, traj.shape[-1])
traj = self.trainer.datamodule.scaler.inverse_transform(
traj.cpu().detach().numpy()
).reshape(traj_shape)
dataset_points = self.trainer.datamodule.scaler.inverse_transform(
dataset_points.cpu().detach().numpy()
)
traj = torch.tensor(traj)
traj = torch.transpose(traj, 0, 1) # [B, T, D]
all_trajs.append(traj)
dataset_2d = dataset_points[:, :2] if isinstance(dataset_points, torch.Tensor) else dataset_points[:, :2]
# ===== Plot all 2D trajectories together with dataset and start/end points =====
fig, ax = plt.subplots(figsize=(6, 5))
dataset_2d = dataset_2d.cpu().numpy()
ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
for traj in all_trajs:
traj_2d = traj[..., :2] # [B, T, 2]
for i in range(traj_2d.shape[0]):
ax.plot(traj_2d[i, :, 0], traj_2d[i, :, 1], alpha=0.8, zorder=2)
ax.scatter(traj_2d[i, 0, 0], traj_2d[i, 0, 1], c='green', s=10, label="t=0" if i == 0 else "", zorder=3)
ax.scatter(traj_2d[i, -1, 0], traj_2d[i, -1, 1], c='red', s=10, label="t=1" if i == 0 else "", zorder=3)
ax.set_title("All Branch Trajectories (2D) with Dataset")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.axis("equal")
handles, labels = ax.get_legend_handles_labels()
if labels:
ax.legend()
save_path = f'./figures/{self.args.data_name}'
os.makedirs(save_path, exist_ok=True)
plt.savefig(f'{save_path}/{self.args.data_name}_all_branches.png', dpi=300)
plt.close()
# ===== Plot each 2D trajectory separately with dataset and endpoints =====
for i, traj in enumerate(all_trajs):
traj_2d = traj[..., :2]
fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
for j in range(traj_2d.shape[0]):
ax.plot(traj_2d[j, :, 0], traj_2d[j, :, 1], alpha=0.9, zorder=2)
ax.scatter(traj_2d[j, 0, 0], traj_2d[j, 0, 1], c='green', s=12, label="t=0" if j == 0 else "", zorder=3)
ax.scatter(traj_2d[j, -1, 0], traj_2d[j, -1, 1], c='red', s=12, label="t=1" if j == 0 else "", zorder=3)
ax.set_title(f"Branch {i + 1} Trajectories (2D) with Dataset")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.axis("equal")
handles, labels = ax.get_legend_handles_labels()
if labels:
ax.legend()
plt.savefig(f'{save_path}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
plt.close()
class FlowNetTrainLidar(BranchFlowNetTrainBase):
def test_step(self, batch, batch_idx):
main_batch = batch["test_samples"][0]
metric_batch = batch["metric_samples"][0]
x0 = main_batch["x0"][0] # [B, D]
cloud_points = main_batch["dataset"][0] # full dataset, [N, D]
t_span = torch.linspace(0, 1, 101)
all_trajs = []
for i, flow_net in enumerate(self.flow_nets):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
)
with torch.no_grad():
traj = node.trajectory(x0, t_span).cpu() # [T, B, D]
if self.whiten:
traj_shape = traj.shape
traj = traj.reshape(-1, 3)
traj = self.trainer.datamodule.scaler.inverse_transform(
traj.cpu().detach().numpy()
).reshape(traj_shape)
traj = torch.tensor(traj)
traj = torch.transpose(traj, 0, 1) # [B, T, D]
all_trajs.append(traj)
# Inverse-transform the point cloud once
if self.whiten:
cloud_points = torch.tensor(
self.trainer.datamodule.scaler.inverse_transform(
cloud_points.cpu().detach().numpy()
)
)
# ===== Plot all trajectories together =====
fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
ax.view_init(elev=30, azim=-115, roll=0)
for i, traj in enumerate(all_trajs):
plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
plt.savefig('./figures/lidar/lidar_all_branches.png', dpi=300)
plt.close()
# ===== Plot each trajectory separately =====
for i, traj in enumerate(all_trajs):
fig = plt.figure(figsize=(6, 5))
ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
ax.view_init(elev=30, azim=-115, roll=0)
plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
plt.savefig(f'./figures/lidar/lidar_branch_{i + 1}.png', dpi=300)
plt.close()