BranchSBM / src /branch_flow_net_test.py
Sophia Tang
Initial commit
b55bace
"""
Separate test classes for each BranchSBM experiment with specific plotting styles.
Each class handles testing and visualization for: LiDAR, Mouse, Clonidine, Trametinib, Veres.
"""
import os
import json
import csv
import torch
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import random
import ot
from torchdyn.core import NeuralODE
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.collections import LineCollection
from .networks.utils import flow_model_torch_wrapper
from .branch_flow_net_train import BranchFlowNetTrainBase
from .branch_growth_net_train import GrowthNetTrain
from .utils import wasserstein, mix_rbf_mmd2, plot_lidar
import json
def evaluate_model(gt_data, model_data, a, b):
# ensure inputs are tensors
if not isinstance(gt_data, torch.Tensor):
gt_data = torch.tensor(gt_data, dtype=torch.float32)
if not isinstance(model_data, torch.Tensor):
model_data = torch.tensor(model_data, dtype=torch.float32)
# choose device: prefer model_data's device if it's not CPU, otherwise use gt_data's device
try:
model_dev = model_data.device
except Exception:
model_dev = torch.device('cpu')
try:
gt_dev = gt_data.device
except Exception:
gt_dev = torch.device('cpu')
device = model_dev if model_dev.type != 'cpu' else gt_dev
gt = gt_data.to(device=device, dtype=torch.float32)
md = model_data.to(device=device, dtype=torch.float32)
M = torch.cdist(gt, md, p=2).cpu().numpy()
if np.isnan(M).any() or np.isinf(M).any():
return np.nan
return ot.emd2(a, b, M, numItermax=1e7)
def compute_distribution_distances(pred, true, pred_full=None, true_full=None):
w1 = wasserstein(pred, true, power=1)
w2 = wasserstein(pred, true, power=2)
# Use full dimensions for MMD if provided, otherwise use same as W1/W2
mmd_pred = pred_full if pred_full is not None else pred
mmd_true = true_full if true_full is not None else true
# MMD requires same number of samples — randomly subsample the larger set
n_pred, n_true = mmd_pred.shape[0], mmd_true.shape[0]
if n_pred > n_true:
perm = torch.randperm(n_pred)[:n_true]
mmd_pred = mmd_pred[perm]
elif n_true > n_pred:
perm = torch.randperm(n_true)[:n_pred]
mmd_true = mmd_true[perm]
mmd = mix_rbf_mmd2(mmd_pred, mmd_true, sigma_list=[0.01, 0.1, 1, 10, 100]).item()
return {"W1": w1, "W2": w2, "MMD": mmd}
def compute_tmv_from_mass_over_time(mass_over_time, all_endpoints, time_points=None, timepoint_data=None, time_index=None, target_time=None, gt_key_template='t1_{}', weights_over_time=None):
if weights_over_time is not None or mass_over_time is not None:
if time_index is None:
if target_time is not None and time_points is not None:
arr = np.array(time_points)
time_index = int(np.argmin(np.abs(arr - float(target_time))))
else:
# default to last index
ref_list = weights_over_time if weights_over_time is not None else mass_over_time
time_index = len(ref_list[0]) - 1
else:
# neither available; time_index not used
if time_index is None:
time_index = -1
n_branches = len(all_endpoints)
# initial total cells for normalization
n_initial = None
if timepoint_data is not None and 't0' in timepoint_data:
try:
n_initial = int(timepoint_data['t0'].shape[0])
except Exception:
n_initial = None
pred_masses = []
for i in range(n_branches):
# Use sum of actual particle weights if available, otherwise mean_weight * num_particles
if weights_over_time is not None:
try:
weights_tensor = weights_over_time[i][time_index]
# Sum all particle weights to get total mass for this branch
total_mass = float(weights_tensor.sum().item())
pred_masses.append(total_mass)
continue
except Exception:
pass # Fall through to mean weight calculation
# Fallback: mean weight from mass_over_time if available, otherwise assume weight=1
mean_w = 1.0
if mass_over_time is not None:
try:
mean_w = float(mass_over_time[i][time_index])
except Exception:
mean_w = 1.0
# determine number of particles for this branch
num_particles = 0
try:
if hasattr(all_endpoints[i], 'shape'):
num_particles = int(all_endpoints[i].shape[0])
else:
num_particles = int(len(all_endpoints[i]))
except Exception:
num_particles = 0
pred_masses.append(mean_w * float(num_particles))
# ground-truth masses per branch
gt_masses = []
if timepoint_data is not None:
for i in range(n_branches):
key1 = gt_key_template.format(i)
if key1 in timepoint_data:
gt_masses.append(float(timepoint_data[key1].shape[0]))
else:
base_key = gt_key_template.split("_")[0] if '_' in gt_key_template else gt_key_template
if base_key in timepoint_data:
gt_masses.append(float(timepoint_data[base_key].shape[0]))
else:
gt_masses.append(0.0)
else:
gt_masses = [0.0 for _ in range(n_branches)]
# determine normalization denominator
if n_initial is None:
s = float(sum(gt_masses))
if s > 0:
n_initial = s
else:
n_initial = float(sum(pred_masses)) if sum(pred_masses) > 0 else 1.0
pred_fracs = [m / float(n_initial) for m in pred_masses]
gt_fracs = [m / float(n_initial) for m in gt_masses]
tmv = 0.5 * float(np.sum(np.abs(np.array(pred_fracs) - np.array(gt_fracs))))
return {
'time_index': time_index,
'pred_masses': pred_masses,
'gt_masses': gt_masses,
'pred_fracs': pred_fracs,
'gt_fracs': gt_fracs,
'tmv': tmv,
}
class FlowNetTestLidar(GrowthNetTrain):
def test_step(self, batch, batch_idx):
# Unwrap CombinedLoader outer tuple if needed
if isinstance(batch, (list, tuple)) and len(batch) == 1:
batch = batch[0]
if isinstance(batch, dict) and "test_samples" in batch:
test_samples = batch["test_samples"]
metric_samples = batch["metric_samples"]
if isinstance(test_samples, (list, tuple)) and len(test_samples) >= 2 and isinstance(test_samples[-1], int):
test_samples = test_samples[0]
if isinstance(metric_samples, (list, tuple)) and len(metric_samples) >= 2 and isinstance(metric_samples[-1], int):
metric_samples = metric_samples[0]
if isinstance(test_samples, (list, tuple)) and len(test_samples) == 1:
test_samples = test_samples[0]
main_batch = test_samples
if isinstance(metric_samples, dict):
metric_batch = list(metric_samples.values())
elif isinstance(metric_samples, (list, tuple)):
metric_batch = [m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m for m in metric_samples]
else:
metric_batch = [metric_samples]
elif isinstance(batch, (list, tuple)) and len(batch) == 2:
# Old tuple format: (test_samples, metric_samples)
# Each could be dict or list
test_samples = batch[0]
metric_samples = batch[1]
if isinstance(test_samples, dict):
main_batch = test_samples
elif isinstance(test_samples, (list, tuple)):
main_batch = test_samples[0]
else:
main_batch = test_samples
if isinstance(metric_samples, dict):
metric_batch = list(metric_samples.values())
elif isinstance(metric_samples, (list, tuple)):
metric_batch = [m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m for m in metric_samples]
else:
metric_batch = [metric_samples]
else:
# Fallback
main_batch = batch
metric_batch = []
timepoint_data = self.trainer.datamodule.get_timepoint_data()
# main_batch is a dict like {"x0": (tensor, weights), ...}
if isinstance(main_batch, dict):
device = main_batch["x0"][0].device
else:
device = main_batch[0]["x0"][0].device
x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
full_batch = {"x0": (x0_all, w0_all)}
time_points, all_endpoints, all_trajs, mass_over_time, energy_over_time, weights_over_time = self.get_mass_and_position(full_batch, metric_batch)
cloud_points = main_batch["dataset"][0] # [N, 3]
# Run 5 trials with random subsampling for robust metrics
n_trials = 5
# Compute per-branch metrics
metrics_dict = {}
for i, endpoints in enumerate(all_endpoints):
true_data_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
true_data = torch.tensor(timepoint_data[true_data_key], dtype=torch.float32).to(endpoints.device)
w1_br, w2_br, mmd_br = [], [], []
for trial in range(n_trials):
n_min = min(endpoints.shape[0], true_data.shape[0])
perm_pred = torch.randperm(endpoints.shape[0])[:n_min]
perm_gt = torch.randperm(true_data.shape[0])[:n_min]
m = compute_distribution_distances(
endpoints[perm_pred, :2], true_data[perm_gt, :2],
pred_full=endpoints[perm_pred], true_full=true_data[perm_gt]
)
w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
metrics_dict[f"branch_{i+1}"] = {
"W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
"W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
"MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
}
self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
# Compute combined metrics across all branches (5 trials)
all_pred_combined = torch.cat(list(all_endpoints), dim=0)
all_true_list = []
for i in range(len(all_endpoints)):
true_data_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
all_true_list.append(torch.tensor(timepoint_data[true_data_key], dtype=torch.float32).to(all_pred_combined.device))
all_true_combined = torch.cat(all_true_list, dim=0)
w1_trials, w2_trials, mmd_trials = [], [], []
for trial in range(n_trials):
n_min = min(all_pred_combined.shape[0], all_true_combined.shape[0])
perm_pred = torch.randperm(all_pred_combined.shape[0])[:n_min]
perm_gt = torch.randperm(all_true_combined.shape[0])[:n_min]
m = compute_distribution_distances(
all_pred_combined[perm_pred, :2], all_true_combined[perm_gt, :2],
pred_full=all_pred_combined[perm_pred], true_full=all_true_combined[perm_gt]
)
w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])
w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
self.log("test/W1_combined", w1_mean, on_epoch=True)
self.log("test/W2_combined", w2_mean, on_epoch=True)
self.log("test/MMD_combined", mmd_mean, on_epoch=True)
metrics_dict["combined"] = {
"W1_mean": float(w1_mean), "W1_std": float(w1_std),
"W2_mean": float(w2_mean), "W2_std": float(w2_std),
"MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
"n_trials": n_trials,
}
print(f"\n=== Combined ===")
print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")
# Inverse-transform cloud points for visualization
if self.whiten:
cloud_points = torch.tensor(
self.trainer.datamodule.scaler.inverse_transform(
cloud_points.cpu().detach().numpy()
)
)
# Create results directory structure
run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
results_dir = os.path.join(self.args.working_dir, 'results', run_name)
figures_dir = f'{results_dir}/figures'
os.makedirs(figures_dir, exist_ok=True)
# Save metrics to JSON
metrics_path = f'{results_dir}/metrics.json'
with open(metrics_path, 'w') as f:
json.dump(metrics_dict, f, indent=2)
print(f"Metrics saved to {metrics_path}")
# Save detailed per-branch metrics to CSV
detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
with open(detailed_csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
for key in sorted(metrics_dict.keys()):
m = metrics_dict[key]
writer.writerow([key,
f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
print(f"Detailed metrics CSV saved to {detailed_csv_path}")
# Convert all_trajs from list of lists to stacked tensors for plotting
# all_trajs[i] is a list of T tensors of shape [B, D]
# Stack to get shape [B, T, D]
stacked_trajs = []
for traj_list in all_trajs:
# Stack along time dimension (dim=1) to get [B, T, D]
stacked_traj = torch.stack(traj_list, dim=1)
stacked_trajs.append(stacked_traj)
# Inverse-transform trajectories to match cloud_points coordinates
if self.whiten:
stacked_trajs_original = []
for traj in stacked_trajs:
B, T, D = traj.shape
# Reshape to [B*T, D] for inverse transform
traj_flat = traj.reshape(-1, D).cpu().detach().numpy()
traj_inv = self.trainer.datamodule.scaler.inverse_transform(traj_flat)
# Reshape back to [B, T, D]
traj_inv = torch.tensor(traj_inv).reshape(B, T, D)
stacked_trajs_original.append(traj_inv)
stacked_trajs = stacked_trajs_original
# ===== Plot all branches together =====
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
ax.view_init(elev=30, azim=-115, roll=0)
for i, traj in enumerate(stacked_trajs):
plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
plt.savefig(f'{figures_dir}/{self.args.data_name}_all_branches.png', dpi=300)
plt.close()
# ===== Plot each branch separately =====
for i, traj in enumerate(stacked_trajs):
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
ax.view_init(elev=30, azim=-115, roll=0)
plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
plt.savefig(f'{figures_dir}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
plt.close()
print(f"LiDAR figures saved to {figures_dir}")
class FlowNetTestMouse(GrowthNetTrain):
def test_step(self, batch, batch_idx):
# Handle both tuple and dict batch formats from CombinedLoader
if isinstance(batch, dict):
main_batch = batch.get("test_samples", batch)
if isinstance(main_batch, tuple):
main_batch = main_batch[0]
elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
if isinstance(batch[0], dict):
main_batch = batch[0].get("test_samples", batch[0])
if isinstance(main_batch, tuple):
main_batch = main_batch[0]
else:
main_batch = batch[0][0]
else:
main_batch = batch
device = main_batch["x0"][0].device
# Use val x0 as initial conditions
x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
# Get timepoint data for ground truth
timepoint_data = self.trainer.datamodule.get_timepoint_data()
# Ground truth at t1 (intermediate timepoint)
data_t1 = torch.tensor(timepoint_data['t1'], dtype=torch.float32)
# Define color schemes for mouse (2 branches)
custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)
t_span_full = torch.linspace(0, 1.0, 100).to(device)
all_trajs = []
for i, flow_net in enumerate(self.flow_nets):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
).to(device)
with torch.no_grad():
traj = node.trajectory(x0, t_span_full).cpu() # [T, B, D]
traj = torch.transpose(traj, 0, 1) # [B, T, D]
all_trajs.append(traj)
t_span_metric_t1 = torch.linspace(0, 0.5, 50).to(device)
t_span_metric_t2 = torch.linspace(0, 1.0, 100).to(device)
n_trials = 5
# Gather t2 branch ground truth
data_t2_branches = []
for i in range(len(self.flow_nets)):
key = f't2_{i+1}'
if key in timepoint_data:
data_t2_branches.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
elif i == 0 and 't2' in timepoint_data:
data_t2_branches.append(torch.tensor(timepoint_data['t2'], dtype=torch.float32))
else:
data_t2_branches.append(None)
# Combined t2 ground truth (all branches merged)
data_t2_all_list = [d for d in data_t2_branches if d is not None]
data_t2_combined = torch.cat(data_t2_all_list, dim=0) if data_t2_all_list else None
# ---- t1 combined metrics (all branches pooled, compared to t1) ----
w1_t1_trials, w2_t1_trials, mmd_t1_trials = [], [], []
for trial in range(n_trials):
all_preds = []
for i, flow_net in enumerate(self.flow_nets):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
).to(device)
with torch.no_grad():
traj = node.trajectory(x0, t_span_metric_t1) # [T, B, D]
x_final = traj[-1].cpu() # [B, D]
all_preds.append(x_final)
preds = torch.cat(all_preds, dim=0)
target_size = preds.shape[0]
perm = torch.randperm(data_t1.shape[0])[:target_size]
data_t1_reduced = data_t1[perm]
metrics = compute_distribution_distances(
preds[:, :2], data_t1_reduced[:, :2]
)
w1_t1_trials.append(metrics["W1"])
w2_t1_trials.append(metrics["W2"])
mmd_t1_trials.append(metrics["MMD"])
# ---- t2 per-branch metrics (each branch endpoint vs its own t2 cluster) ----
branch_t2_metrics = {}
for i, flow_net in enumerate(self.flow_nets):
if data_t2_branches[i] is None:
continue
w1_br, w2_br, mmd_br = [], [], []
for trial in range(n_trials):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
).to(device)
with torch.no_grad():
traj = node.trajectory(x0, t_span_metric_t2)
x_final = traj[-1].cpu()
gt = data_t2_branches[i]
n_min = min(x_final.shape[0], gt.shape[0])
perm_pred = torch.randperm(x_final.shape[0])[:n_min]
perm_gt = torch.randperm(gt.shape[0])[:n_min]
m = compute_distribution_distances(
x_final[perm_pred, :2], gt[perm_gt, :2]
)
w1_br.append(m["W1"])
w2_br.append(m["W2"])
mmd_br.append(m["MMD"])
branch_t2_metrics[f"branch_{i+1}_t2"] = {
"W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
"W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
"MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
}
print(f"Branch {i+1} @ t2 — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
# ---- t2 combined metrics (all branches pooled, compared to all t2) ----
w1_t2_trials, w2_t2_trials, mmd_t2_trials = [], [], []
if data_t2_combined is not None:
for trial in range(n_trials):
all_preds = []
for i, flow_net in enumerate(self.flow_nets):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
).to(device)
with torch.no_grad():
traj = node.trajectory(x0, t_span_metric_t2)
all_preds.append(traj[-1].cpu())
preds = torch.cat(all_preds, dim=0)
n_min = min(preds.shape[0], data_t2_combined.shape[0])
perm_pred = torch.randperm(preds.shape[0])[:n_min]
perm_gt = torch.randperm(data_t2_combined.shape[0])[:n_min]
m = compute_distribution_distances(
preds[perm_pred, :2], data_t2_combined[perm_gt, :2]
)
w1_t2_trials.append(m["W1"])
w2_t2_trials.append(m["W2"])
mmd_t2_trials.append(m["MMD"])
# Compute mean and std
w1_t1_mean, w1_t1_std = np.mean(w1_t1_trials), np.std(w1_t1_trials, ddof=1)
w2_t1_mean, w2_t1_std = np.mean(w2_t1_trials), np.std(w2_t1_trials, ddof=1)
mmd_t1_mean, mmd_t1_std = np.mean(mmd_t1_trials), np.std(mmd_t1_trials, ddof=1)
# Log metrics
self.log("test/W1_combined_t1", w1_t1_mean, on_epoch=True)
self.log("test/W2_combined_t1", w2_t1_mean, on_epoch=True)
self.log("test/MMD_combined_t1", mmd_t1_mean, on_epoch=True)
metrics_dict = {
"combined_t1": {
"W1_mean": float(w1_t1_mean), "W1_std": float(w1_t1_std),
"W2_mean": float(w2_t1_mean), "W2_std": float(w2_t1_std),
"MMD_mean": float(mmd_t1_mean), "MMD_std": float(mmd_t1_std),
"n_trials": n_trials,
}
}
metrics_dict.update(branch_t2_metrics)
if w1_t2_trials:
w1_t2_mean, w1_t2_std = np.mean(w1_t2_trials), np.std(w1_t2_trials, ddof=1)
w2_t2_mean, w2_t2_std = np.mean(w2_t2_trials), np.std(w2_t2_trials, ddof=1)
mmd_t2_mean, mmd_t2_std = np.mean(mmd_t2_trials), np.std(mmd_t2_trials, ddof=1)
self.log("test/W1_combined_t2", w1_t2_mean, on_epoch=True)
self.log("test/W2_combined_t2", w2_t2_mean, on_epoch=True)
self.log("test/MMD_combined_t2", mmd_t2_mean, on_epoch=True)
metrics_dict["combined_t2"] = {
"W1_mean": float(w1_t2_mean), "W1_std": float(w1_t2_std),
"W2_mean": float(w2_t2_mean), "W2_std": float(w2_t2_std),
"MMD_mean": float(mmd_t2_mean), "MMD_std": float(mmd_t2_std),
"n_trials": n_trials,
}
print(f"\n=== Combined @ t1 ===")
print(f"W1: {w1_t1_mean:.6f} ± {w1_t1_std:.6f}")
print(f"W2: {w2_t1_mean:.6f} ± {w2_t1_std:.6f}")
print(f"MMD: {mmd_t1_mean:.6f} ± {mmd_t1_std:.6f}")
if w1_t2_trials:
print(f"\n=== Combined @ t2 ===")
print(f"W1: {w1_t2_mean:.6f} ± {w1_t2_std:.6f}")
print(f"W2: {w2_t2_mean:.6f} ± {w2_t2_std:.6f}")
print(f"MMD: {mmd_t2_mean:.6f} ± {mmd_t2_std:.6f}")
# Create results directory structure
run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
results_dir = os.path.join(self.args.working_dir, 'results', run_name)
figures_dir = f'{results_dir}/figures'
os.makedirs(figures_dir, exist_ok=True)
# Save metrics to JSON
metrics_path = f'{results_dir}/metrics.json'
with open(metrics_path, 'w') as f:
json.dump(metrics_dict, f, indent=2)
print(f"Metrics saved to {metrics_path}")
# Save detailed metrics to CSV
detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
with open(detailed_csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
for key in sorted(metrics_dict.keys()):
m = metrics_dict[key]
writer.writerow([key,
f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}',
f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}',
f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}'])
print(f"Detailed metrics CSV saved to {detailed_csv_path}")
# ===== Plot individual branches (using full t_span trajectories) =====
self._plot_mouse_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
# ===== Plot all branches together =====
self._plot_mouse_combined(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
print(f"Mouse figures saved to {figures_dir}")
def _plot_mouse_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
"""Plot each branch separately with timepoint background."""
n_branches = len(all_trajs)
branch_names = [f'Branch {i+1}' for i in range(n_branches)]
branch_colors = ['#B83CFF', '#50B2D7'][:n_branches]
cmaps = [cmap1, cmap2][:n_branches]
# Stack list-of-tensors into [B, T, D] numpy arrays
all_trajs_np = []
for traj in all_trajs:
if isinstance(traj, list):
traj = torch.stack(traj, dim=1) # list of [B,D] -> [B,T,D]
all_trajs_np.append(traj.cpu().detach().numpy())
all_trajs = all_trajs_np
# Move timepoint data to numpy
for key in list(timepoint_data.keys()):
if torch.is_tensor(timepoint_data[key]):
timepoint_data[key] = timepoint_data[key].cpu().numpy()
# Compute global axis limits
all_coords = []
for key in ['t0', 't1', 't2', 't2_1', 't2_2']:
if key in timepoint_data:
all_coords.append(timepoint_data[key][:, :2])
for traj_np in all_trajs:
all_coords.append(traj_np.reshape(-1, traj_np.shape[-1])[:, :2])
all_coords = np.concatenate(all_coords, axis=0)
x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()
# Add margin
x_margin = 0.05 * (x_max - x_min)
y_margin = 0.05 * (y_max - y_min)
x_min -= x_margin
x_max += x_margin
y_min -= y_margin
y_max += y_margin
for i, traj in enumerate(all_trajs):
fig, ax = plt.subplots(figsize=(10, 8))
cmap = cmaps[i]
c_end = branch_colors[i]
# Plot timepoint background
t2_key = f't2_{i+1}' if f't2_{i+1}' in timepoint_data else 't2'
coords_list = [timepoint_data['t0'], timepoint_data['t1'], timepoint_data[t2_key]]
tp_colors = ['#05009E', '#A19EFF', c_end]
tp_labels = ["t=0", "t=1", f"t=2 (branch {i+1})"]
for coords, color, label in zip(coords_list, tp_colors, tp_labels):
alpha = 0.8 if color == '#05009E' else 0.6
ax.scatter(coords[:, 0], coords[:, 1],
c=color, s=80, alpha=alpha, marker='x',
label=f'{label} cells', linewidth=1.5)
# Plot continuous trajectories with LineCollection for speed
traj_2d = traj[:, :, :2]
n_time = traj_2d.shape[1]
color_vals = cmap(np.linspace(0, 1, n_time))
segments = []
seg_colors = []
for j in range(traj_2d.shape[0]):
pts = traj_2d[j] # [T, 2]
segs = np.stack([pts[:-1], pts[1:]], axis=1)
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
ax.add_collection(lc)
# Start and end points
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=30, marker='o', label='Trajectory Start',
zorder=5, edgecolors='white', linewidth=1)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=30, marker='o', label='Trajectory End',
zorder=5, edgecolors='white', linewidth=1)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel("PC1", fontsize=12)
ax.set_ylabel("PC2", fontsize=12)
ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=12, frameon=False)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
plt.close()
def _plot_mouse_combined(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
"""Plot all branches together."""
n_branches = len(all_trajs)
branch_names = [f'Branch {i+1}' for i in range(n_branches)]
branch_colors = ['#B83CFF', '#50B2D7'][:n_branches]
# Build timepoint key/color/label lists depending on branching
if 't2_1' in timepoint_data:
tp_keys = ['t0', 't1', 't2_1', 't2_2']
tp_colors = ['#05009E', '#A19EFF', '#B83CFF', '#50B2D7']
tp_labels = ['t=0', 't=1', 't=2 (branch 1)', 't=2 (branch 2)']
else:
tp_keys = ['t0', 't1', 't2']
tp_colors = ['#05009E', '#A19EFF', '#B83CFF']
tp_labels = ['t=0', 't=1', 't=2']
# Stack list-of-tensors into [B, T, D] numpy arrays
all_trajs_np = []
for traj in all_trajs:
if isinstance(traj, list):
traj = torch.stack(traj, dim=1)
if torch.is_tensor(traj):
traj = traj.cpu().detach().numpy()
all_trajs_np.append(traj)
all_trajs = all_trajs_np
# Move timepoint data to numpy
for key in list(timepoint_data.keys()):
if torch.is_tensor(timepoint_data[key]):
timepoint_data[key] = timepoint_data[key].cpu().numpy()
fig, ax = plt.subplots(figsize=(12, 10))
# Plot timepoint background
for idx, (t_key, color, label) in enumerate(zip(
tp_keys,
tp_colors,
tp_labels
)):
if t_key in timepoint_data:
coords = timepoint_data[t_key]
ax.scatter(coords[:, 0], coords[:, 1],
c=color, s=80, alpha=0.4, marker='x',
label=f'{label} cells', linewidth=1.5)
# Plot trajectories with color gradients
cmaps = [cmap1, cmap2]
for i, traj in enumerate(all_trajs):
traj_2d = traj[:, :, :2]
c_end = branch_colors[i]
cmap = cmaps[i]
n_time = traj_2d.shape[1]
color_vals = cmap(np.linspace(0, 1, n_time))
segments = []
seg_colors = []
for j in range(traj_2d.shape[0]):
pts = traj_2d[j]
segs = np.stack([pts[:-1], pts[1:]], axis=1)
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
ax.add_collection(lc)
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=30, marker='o',
label=f'{branch_names[i]} Start',
zorder=5, edgecolors='white', linewidth=1)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=30, marker='o',
label=f'{branch_names[i]} End',
zorder=5, edgecolors='white', linewidth=1)
ax.set_xlabel("PC1", fontsize=14)
ax.set_ylabel("PC2", fontsize=14)
ax.set_title("All Branch Trajectories with Timepoint Background",
fontsize=16, weight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=12, frameon=False)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
plt.close()
class FlowNetTestClonidine(BranchFlowNetTrainBase):
"""Test class for Clonidine perturbation experiment (1 or 2 branches)."""
def test_step(self, batch, batch_idx):
# Handle both dict and tuple batch formats from CombinedLoader
if isinstance(batch, dict) and "test_samples" in batch:
# New format: {"test_samples": {...}, "metric_samples": {...}}
main_batch = batch["test_samples"]
elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
# Old format with nested structure
test_samples = batch[0]
if isinstance(test_samples, dict) and "test_samples" in test_samples:
main_batch = test_samples["test_samples"][0]
else:
main_batch = test_samples
else:
# Fallback
main_batch = batch
# Get timepoint data
timepoint_data = self.trainer.datamodule.get_timepoint_data()
device = main_batch["x0"][0].device
# Use val x0 as initial conditions
x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
t_span = torch.linspace(0, 1, 100).to(device)
# Define color schemes for clonidine (2 branches)
custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)
all_trajs = []
all_endpoints = []
for i, flow_net in enumerate(self.flow_nets):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
)
with torch.no_grad():
traj = node.trajectory(x0, t_span).cpu() # [T, B, D]
traj = torch.transpose(traj, 0, 1) # [B, T, D]
all_trajs.append(traj)
all_endpoints.append(traj[:, -1, :])
# Run 5 trials with random subsampling for robust metrics
n_trials = 5
n_branches = len(self.flow_nets)
# Gather per-branch ground truth
gt_data_per_branch = []
for i in range(n_branches):
if n_branches == 1:
key = 't1'
else:
key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
gt_all = torch.cat(gt_data_per_branch, dim=0)
# Per-branch metrics (5 trials)
metrics_dict = {}
for i in range(n_branches):
w1_br, w2_br, mmd_br = [], [], []
pred = all_endpoints[i]
gt = gt_data_per_branch[i]
for trial in range(n_trials):
n_min = min(pred.shape[0], gt.shape[0])
perm_pred = torch.randperm(pred.shape[0])[:n_min]
perm_gt = torch.randperm(gt.shape[0])[:n_min]
m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
metrics_dict[f"branch_{i+1}"] = {
"W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
"W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
"MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
}
self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
# Combined metrics (5 trials)
pred_all = torch.cat(all_endpoints, dim=0)
w1_trials, w2_trials, mmd_trials = [], [], []
for trial in range(n_trials):
n_min = min(pred_all.shape[0], gt_all.shape[0])
perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])
w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
self.log("test/W1_t1_combined", w1_mean, on_epoch=True)
self.log("test/W2_t1_combined", w2_mean, on_epoch=True)
self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True)
metrics_dict['t1_combined'] = {
"W1_mean": float(w1_mean), "W1_std": float(w1_std),
"W2_mean": float(w2_mean), "W2_std": float(w2_std),
"MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
"n_trials": n_trials,
}
print(f"\n=== Combined @ t1 ===")
print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")
# Create results directory structure
run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
results_dir = os.path.join(self.args.working_dir, 'results', run_name)
figures_dir = f'{results_dir}/figures'
os.makedirs(figures_dir, exist_ok=True)
# Save metrics to JSON
metrics_path = f'{results_dir}/metrics.json'
with open(metrics_path, 'w') as f:
json.dump(metrics_dict, f, indent=2)
print(f"Metrics saved to {metrics_path}")
# Save detailed metrics to CSV
detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
with open(detailed_csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
for key in sorted(metrics_dict.keys()):
m = metrics_dict[key]
writer.writerow([key,
f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
print(f"Detailed metrics CSV saved to {detailed_csv_path}")
# ===== Plot branches =====
self._plot_clonidine_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
self._plot_clonidine_combined(all_trajs, timepoint_data, figures_dir)
print(f"Clonidine figures saved to {figures_dir}")
def _plot_clonidine_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
"""Plot each branch separately."""
branch_names = ['Branch 1', 'Branch 2']
branch_colors = ['#B83CFF', '#50B2D7']
cmaps = [cmap1, cmap2]
# Compute global axis limits – handle single vs multi branch keys
all_coords = []
if 't1_1' in timepoint_data:
tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))]
else:
tp_keys = ['t0', 't1']
for key in tp_keys:
all_coords.append(timepoint_data[key][:, :2])
for traj in all_trajs:
all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2])
all_coords = np.concatenate(all_coords, axis=0)
x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()
x_margin = 0.05 * (x_max - x_min)
y_margin = 0.05 * (y_max - y_min)
x_min -= x_margin
x_max += x_margin
y_min -= y_margin
y_max += y_margin
for i, traj in enumerate(all_trajs):
fig, ax = plt.subplots(figsize=(10, 8))
c_end = branch_colors[i]
# Plot timepoint background
t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
coords_list = [timepoint_data['t0'], timepoint_data[t1_key]]
tp_colors = ['#05009E', c_end]
t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"
tp_labels = ["t=0", t1_label]
for coords, color, label in zip(coords_list, tp_colors, tp_labels):
ax.scatter(coords[:, 0], coords[:, 1],
c=color, s=80, alpha=0.4, marker='x',
label=f'{label} cells', linewidth=1.5)
# Plot continuous trajectories with LineCollection for speed
traj_2d = traj[:, :, :2]
n_time = traj_2d.shape[1]
color_vals = cmaps[i](np.linspace(0, 1, n_time))
segments = []
seg_colors = []
for j in range(traj_2d.shape[0]):
pts = traj_2d[j]
segs = np.stack([pts[:-1], pts[1:]], axis=1)
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
ax.add_collection(lc)
# Start and end points
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=30, marker='o', label='Trajectory Start',
zorder=5, edgecolors='white', linewidth=1)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=30, marker='o', label='Trajectory End',
zorder=5, edgecolors='white', linewidth=1)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel("PC1", fontsize=12)
ax.set_ylabel("PC2", fontsize=12)
ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=16, frameon=False)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
plt.close()
def _plot_clonidine_combined(self, all_trajs, timepoint_data, save_dir):
"""Plot all branches together."""
branch_names = ['Branch 1', 'Branch 2']
branch_colors = ['#B83CFF', '#50B2D7']
fig, ax = plt.subplots(figsize=(12, 10))
# Build timepoint keys/colors/labels depending on single vs multi branch
if 't1_1' in timepoint_data:
tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))]
tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))]
else:
tp_keys = ['t0', 't1']
tp_labels_list = ['t=0', 't=1']
tp_colors = ['#05009E', '#B83CFF', '#50B2D7'][:len(tp_keys)]
# Plot timepoint background
for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list):
coords = timepoint_data[t_key]
ax.scatter(coords[:, 0], coords[:, 1],
c=color, s=80, alpha=0.4, marker='x',
label=f'{label} cells', linewidth=1.5)
# Plot trajectories with color gradients
custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
cmaps = [
LinearSegmentedColormap.from_list("clon_cmap1", custom_colors_1),
LinearSegmentedColormap.from_list("clon_cmap2", custom_colors_2),
]
for i, traj in enumerate(all_trajs):
traj_2d = traj[:, :, :2]
c_end = branch_colors[i]
cmap = cmaps[i]
n_time = traj_2d.shape[1]
color_vals = cmap(np.linspace(0, 1, n_time))
segments = []
seg_colors = []
for j in range(traj_2d.shape[0]):
pts = traj_2d[j]
segs = np.stack([pts[:-1], pts[1:]], axis=1)
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
ax.add_collection(lc)
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=30, marker='o',
label=f'{branch_names[i]} Start',
zorder=5, edgecolors='white', linewidth=1)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=30, marker='o',
label=f'{branch_names[i]} End',
zorder=5, edgecolors='white', linewidth=1)
ax.set_xlabel("PC1", fontsize=14)
ax.set_ylabel("PC2", fontsize=14)
ax.set_title("All Branch Trajectories with Timepoint Background",
fontsize=16, weight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=12, frameon=False)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
plt.close()
class FlowNetTestTrametinib(BranchFlowNetTrainBase):
"""Test class for Trametinib perturbation experiment (1 or 3 branches)."""
def test_step(self, batch, batch_idx):
# Handle both dict and tuple batch formats from CombinedLoader
if isinstance(batch, dict) and "test_samples" in batch:
# New format: {"test_samples": {...}, "metric_samples": {...}}
main_batch = batch["test_samples"]
elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
# Old format with nested structure
test_samples = batch[0]
if isinstance(test_samples, dict) and "test_samples" in test_samples:
main_batch = test_samples["test_samples"][0]
else:
main_batch = test_samples
else:
# Fallback
main_batch = batch
# Get timepoint data
timepoint_data = self.trainer.datamodule.get_timepoint_data()
device = main_batch["x0"][0].device
# Use val x0 as initial conditions
x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
t_span = torch.linspace(0, 1, 100).to(device)
# Define color schemes for trametinib (3 branches)
custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"]
custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
custom_colors_3 = ["#05009E", "#A19EFF", "#B83CFF"]
custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)
custom_cmap_3 = LinearSegmentedColormap.from_list("cmap3", custom_colors_3)
all_trajs = []
all_endpoints = []
for i, flow_net in enumerate(self.flow_nets):
node = NeuralODE(
flow_model_torch_wrapper(flow_net),
solver="euler",
sensitivity="adjoint",
)
with torch.no_grad():
traj = node.trajectory(x0, t_span).cpu() # [T, B, D]
traj = torch.transpose(traj, 0, 1) # [B, T, D]
all_trajs.append(traj)
all_endpoints.append(traj[:, -1, :])
# Run 5 trials with random subsampling for robust metrics
n_trials = 5
n_branches = len(self.flow_nets)
# Gather per-branch ground truth
gt_data_per_branch = []
for i in range(n_branches):
if n_branches == 1:
key = 't1'
else:
key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
gt_all = torch.cat(gt_data_per_branch, dim=0)
# Per-branch metrics (5 trials)
metrics_dict = {}
for i in range(n_branches):
w1_br, w2_br, mmd_br = [], [], []
pred = all_endpoints[i]
gt = gt_data_per_branch[i]
for trial in range(n_trials):
n_min = min(pred.shape[0], gt.shape[0])
perm_pred = torch.randperm(pred.shape[0])[:n_min]
perm_gt = torch.randperm(gt.shape[0])[:n_min]
m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
metrics_dict[f"branch_{i+1}"] = {
"W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
"W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
"MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
}
self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
# Combined metrics (5 trials)
pred_all = torch.cat(all_endpoints, dim=0)
w1_trials, w2_trials, mmd_trials = [], [], []
for trial in range(n_trials):
n_min = min(pred_all.shape[0], gt_all.shape[0])
perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])
w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
self.log("test/W1_t1_combined", w1_mean, on_epoch=True)
self.log("test/W2_t1_combined", w2_mean, on_epoch=True)
self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True)
metrics_dict['t1_combined'] = {
"W1_mean": float(w1_mean), "W1_std": float(w1_std),
"W2_mean": float(w2_mean), "W2_std": float(w2_std),
"MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
"n_trials": n_trials,
}
print(f"\n=== Combined @ t1 ===")
print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")
# Create results directory structure
run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
results_dir = os.path.join(self.args.working_dir, 'results', run_name)
figures_dir = f'{results_dir}/figures'
os.makedirs(figures_dir, exist_ok=True)
# Save metrics to JSON
metrics_path = f'{results_dir}/metrics.json'
with open(metrics_path, 'w') as f:
json.dump(metrics_dict, f, indent=2)
print(f"Metrics saved to {metrics_path}")
# Save detailed metrics to CSV
detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
with open(detailed_csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
for key in sorted(metrics_dict.keys()):
m = metrics_dict[key]
writer.writerow([key,
f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
print(f"Detailed metrics CSV saved to {detailed_csv_path}")
# ===== Plot branches =====
self._plot_trametinib_branches(all_trajs, timepoint_data, figures_dir,
custom_cmap_1, custom_cmap_2, custom_cmap_3)
self._plot_trametinib_combined(all_trajs, timepoint_data, figures_dir)
print(f"Trametinib figures saved to {figures_dir}")
def _plot_trametinib_branches(self, all_trajs, timepoint_data, save_dir,
cmap1, cmap2, cmap3):
"""Plot each branch separately."""
branch_names = ['Branch 1', 'Branch 2', 'Branch 3']
branch_colors = ['#9793F8', '#50B2D7', '#B83CFF']
cmaps = [cmap1, cmap2, cmap3]
# Compute global axis limits – handle single vs multi branch keys
all_coords = []
if 't1_1' in timepoint_data:
tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))]
else:
tp_keys = ['t0', 't1']
for key in tp_keys:
all_coords.append(timepoint_data[key][:, :2])
for traj in all_trajs:
all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2])
all_coords = np.concatenate(all_coords, axis=0)
x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()
x_margin = 0.05 * (x_max - x_min)
y_margin = 0.05 * (y_max - y_min)
x_min -= x_margin
x_max += x_margin
y_min -= y_margin
y_max += y_margin
for i, traj in enumerate(all_trajs):
fig, ax = plt.subplots(figsize=(10, 8))
c_end = branch_colors[i]
# Plot timepoint background
t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
coords_list = [timepoint_data['t0'], timepoint_data[t1_key]]
tp_colors = ['#05009E', c_end]
t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"
tp_labels = ["t=0", t1_label]
for coords, color, label in zip(coords_list, tp_colors, tp_labels):
ax.scatter(coords[:, 0], coords[:, 1],
c=color, s=80, alpha=0.4, marker='x',
label=f'{label} cells', linewidth=1.5)
# Plot continuous trajectories with LineCollection for speed
traj_2d = traj[:, :, :2]
n_time = traj_2d.shape[1]
color_vals = cmaps[i](np.linspace(0, 1, n_time))
segments = []
seg_colors = []
for j in range(traj_2d.shape[0]):
pts = traj_2d[j]
segs = np.stack([pts[:-1], pts[1:]], axis=1)
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
ax.add_collection(lc)
# Start and end points
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=30, marker='o', label='Trajectory Start',
zorder=5, edgecolors='white', linewidth=1)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=30, marker='o', label='Trajectory End',
zorder=5, edgecolors='white', linewidth=1)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel("PC1", fontsize=12)
ax.set_ylabel("PC2", fontsize=12)
ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=16, frameon=False)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
plt.close()
def _plot_trametinib_combined(self, all_trajs, timepoint_data, save_dir):
"""Plot all 3 branches together."""
branch_names = ['Branch 1', 'Branch 2', 'Branch 3']
branch_colors = ['#9793F8', '#50B2D7', '#B83CFF']
fig, ax = plt.subplots(figsize=(12, 10))
# Build timepoint keys/colors/labels depending on single vs multi branch
if 't1_1' in timepoint_data:
tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))]
tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))]
else:
tp_keys = ['t0', 't1']
tp_labels_list = ['t=0', 't=1']
tp_colors = ['#05009E', '#9793F8', '#50B2D7', '#B83CFF'][:len(tp_keys)]
# Plot timepoint background
for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list):
coords = timepoint_data[t_key]
ax.scatter(coords[:, 0], coords[:, 1],
c=color, s=80, alpha=0.4, marker='x',
label=f'{label} cells', linewidth=1.5)
# Plot trajectories with color gradients
custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"]
custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
custom_colors_3 = ["#05009E", "#A19EFF", "#D577FF"]
cmaps = [
LinearSegmentedColormap.from_list("tram_cmap1", custom_colors_1),
LinearSegmentedColormap.from_list("tram_cmap2", custom_colors_2),
LinearSegmentedColormap.from_list("tram_cmap3", custom_colors_3),
]
for i, traj in enumerate(all_trajs):
traj_2d = traj[:, :, :2]
c_end = branch_colors[i]
cmap = cmaps[i]
n_time = traj_2d.shape[1]
color_vals = cmap(np.linspace(0, 1, n_time))
segments = []
seg_colors = []
for j in range(traj_2d.shape[0]):
pts = traj_2d[j]
segs = np.stack([pts[:-1], pts[1:]], axis=1)
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
ax.add_collection(lc)
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=30, marker='o',
label=f'{branch_names[i]} Start',
zorder=5, edgecolors='white', linewidth=1)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=30, marker='o',
label=f'{branch_names[i]} End',
zorder=5, edgecolors='white', linewidth=1)
ax.set_xlabel("PC1", fontsize=14)
ax.set_ylabel("PC2", fontsize=14)
ax.set_title("All Branch Trajectories with Timepoint Background",
fontsize=16, weight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=12, frameon=False)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
plt.close()
class FlowNetTestVeres(GrowthNetTrain):
"""Test class for Veres pancreatic endocrinogenesis experiment (3 or 5 branches)."""
def test_step(self, batch, batch_idx):
# Handle both tuple and dict batch formats from CombinedLoader
if isinstance(batch, dict):
main_batch = batch["test_samples"][0]
metric_batch = batch["metric_samples"][0]
else:
# batch is a list/tuple
if isinstance(batch[0], dict):
# batch[0] contains the dict with test_samples and metric_samples
main_batch = batch[0]["test_samples"][0]
metric_batch = batch[0]["metric_samples"][0]
else:
# batch is a tuple: (test_samples, metric_samples)
main_batch = batch[0][0]
metric_batch = batch[1][0]
# Get timepoint data (full datasets, not just val split)
timepoint_data = self.trainer.datamodule.get_timepoint_data()
device = main_batch["x0"][0].device
# Use val x0 as initial conditions
x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
full_batch = {"x0": (x0_all, w0_all)}
time_points, all_endpoints, all_trajs, mass_over_time, energy_over_time, weights_over_time = self.get_mass_and_position(full_batch, metric_batch)
n_branches = len(self.flow_nets)
# trajectory time grid
t_span = torch.linspace(0, 1, 101).to(device)
# `all_trajs` returned from `get_mass_and_position` is expected to be a list where each
# element is a sequence of per-timepoint tensors for that branch (shape [B, D] each).
# Convert each branch to [T, B, D] then to [B, T, D] for downstream processing.
trajs_TBD = [torch.stack(branch_list, dim=0) for branch_list in all_trajs] # each is [T, B, D]
trajs_BTD = [t.permute(1, 0, 2) for t in trajs_TBD] # each -> [B, T, D]
all_trajs = []
all_endpoints = []
# will store per-branch intermediate frames: each entry -> tensor [B, n_intermediate, D]
all_intermediates = []
for traj in trajs_BTD:
# traj is [B, T, D]
# optionally inverse-transform if whitened
if self.whiten:
traj_np = traj.detach().cpu().numpy()
n_samples, n_time, n_dims = traj_np.shape
traj_flat = traj_np.reshape(-1, n_dims)
traj_inv_flat = self.trainer.datamodule.scaler.inverse_transform(traj_flat)
traj_inv = traj_inv_flat.reshape(n_samples, n_time, n_dims)
traj = torch.tensor(traj_inv, dtype=torch.float32)
all_trajs.append(traj)
# Collect six evenly spaced intermediate frames between t=0 and t=1 (exclude endpoints)
n_T = traj.shape[1]
# choose 8 points including endpoints -> take inner 6 as intermediates
inter_times = np.linspace(0.0, 1.0, 8)[1:-1] # 6 values
inter_indices = [int(round(t * (n_T - 1))) for t in inter_times]
# stack per-branch intermediate frames -> [B, 6, D]
intermediates = torch.stack([traj[:, idx, :] for idx in inter_indices], dim=1)
all_intermediates.append(intermediates)
# Final endpoints (t=1)
all_endpoints.append(traj[:, -1, :])
# Run 5 trials with random subsampling for robust metrics
n_trials = 5
metrics_dict = {}
# --- Intermediate timepoints (t1-t6) combined metrics ---
intermediate_keys = sorted([k for k in timepoint_data.keys()
if k.startswith('t') and '_' not in k and k != 't0'])
if intermediate_keys:
n_evals = min(6, len(intermediate_keys))
for j in range(n_evals):
intermediate_key = intermediate_keys[j]
true_data_intermediate = torch.tensor(timepoint_data[intermediate_key], dtype=torch.float32)
# Gather predicted intermediates across all branches
raw_intermediates = [branch[:, j, :] for branch in all_intermediates]
all_raw_concat = torch.cat(raw_intermediates, dim=0).cpu() # [n_branches*B, D]
w1_t, w2_t, mmd_t = [], [], []
w1_t_full, w2_t_full, mmd_t_full = [], [], []
for trial in range(n_trials):
n_min = min(all_raw_concat.shape[0], true_data_intermediate.shape[0])
perm_pred = torch.randperm(all_raw_concat.shape[0])[:n_min]
perm_gt = torch.randperm(true_data_intermediate.shape[0])[:n_min]
# 2D metrics (PC1-PC2)
m = compute_distribution_distances(
all_raw_concat[perm_pred, :2], true_data_intermediate[perm_gt, :2])
w1_t.append(m["W1"]); w2_t.append(m["W2"]); mmd_t.append(m["MMD"])
# Full-dimensional metrics (all PCs)
m_full = compute_distribution_distances(
all_raw_concat[perm_pred], true_data_intermediate[perm_gt])
w1_t_full.append(m_full["W1"]); w2_t_full.append(m_full["W2"]); mmd_t_full.append(m_full["MMD"])
metrics_dict[f'{intermediate_key}_combined'] = {
"W1_mean": float(np.mean(w1_t)), "W1_std": float(np.std(w1_t, ddof=1)),
"W2_mean": float(np.mean(w2_t)), "W2_std": float(np.std(w2_t, ddof=1)),
"MMD_mean": float(np.mean(mmd_t)), "MMD_std": float(np.std(mmd_t, ddof=1)),
"W1_full_mean": float(np.mean(w1_t_full)), "W1_full_std": float(np.std(w1_t_full, ddof=1)),
"W2_full_mean": float(np.mean(w2_t_full)), "W2_full_std": float(np.std(w2_t_full, ddof=1)),
"MMD_full_mean": float(np.mean(mmd_t_full)), "MMD_full_std": float(np.std(mmd_t_full, ddof=1)),
}
self.log(f"test/W1_{intermediate_key}_combined", np.mean(w1_t), on_epoch=True)
self.log(f"test/W1_full_{intermediate_key}_combined", np.mean(w1_t_full), on_epoch=True)
print(f"{intermediate_key} combined — W1: {np.mean(w1_t):.6f}±{np.std(w1_t, ddof=1):.6f}, "
f"W2: {np.mean(w2_t):.6f}±{np.std(w2_t, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_t):.6f}±{np.std(mmd_t, ddof=1):.6f}")
print(f"{intermediate_key} combined (full) — W1: {np.mean(w1_t_full):.6f}±{np.std(w1_t_full, ddof=1):.6f}, "
f"W2: {np.mean(w2_t_full):.6f}±{np.std(w2_t_full, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_t_full):.6f}±{np.std(mmd_t_full, ddof=1):.6f}")
# --- Final timepoint per-branch metrics ---
gt_keys = sorted([k for k in timepoint_data.keys() if k.startswith('t7_')])
for i, endpoints in enumerate(all_endpoints):
true_data_key = f"t7_{i}"
if true_data_key not in timepoint_data:
print(f"Warning: {true_data_key} not found in timepoint_data")
continue
gt = torch.tensor(timepoint_data[true_data_key], dtype=torch.float32)
pred = endpoints.cpu()
w1_br, w2_br, mmd_br = [], [], []
w1_br_full, w2_br_full, mmd_br_full = [], [], []
for trial in range(n_trials):
n_min = min(pred.shape[0], gt.shape[0])
perm_pred = torch.randperm(pred.shape[0])[:n_min]
perm_gt = torch.randperm(gt.shape[0])[:n_min]
# 2D metrics (PC1-PC2)
m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
# Full-dimensional metrics (all PCs)
m_full = compute_distribution_distances(pred[perm_pred], gt[perm_gt])
w1_br_full.append(m_full["W1"]); w2_br_full.append(m_full["W2"]); mmd_br_full.append(m_full["MMD"])
metrics_dict[f"branch_{i}"] = {
"W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
"W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
"MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
"W1_full_mean": float(np.mean(w1_br_full)), "W1_full_std": float(np.std(w1_br_full, ddof=1)),
"W2_full_mean": float(np.mean(w2_br_full)), "W2_full_std": float(np.std(w2_br_full, ddof=1)),
"MMD_full_mean": float(np.mean(mmd_br_full)), "MMD_full_std": float(np.std(mmd_br_full, ddof=1)),
}
self.log(f"test/W1_branch{i}", np.mean(w1_br), on_epoch=True)
self.log(f"test/W1_full_branch{i}", np.mean(w1_br_full), on_epoch=True)
print(f"Branch {i} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
print(f"Branch {i} (full) — W1: {np.mean(w1_br_full):.6f}±{np.std(w1_br_full, ddof=1):.6f}, "
f"W2: {np.mean(w2_br_full):.6f}±{np.std(w2_br_full, ddof=1):.6f}, "
f"MMD: {np.mean(mmd_br_full):.6f}±{np.std(mmd_br_full, ddof=1):.6f}")
# --- Final timepoint combined metrics ---
gt_list = [torch.tensor(timepoint_data[k], dtype=torch.float32) for k in gt_keys]
if len(gt_list) > 0 and len(all_endpoints) > 0:
gt_all = torch.cat(gt_list, dim=0)
pred_all = torch.cat([e.cpu() for e in all_endpoints], dim=0)
w1_trials, w2_trials, mmd_trials = [], [], []
w1_trials_full, w2_trials_full, mmd_trials_full = [], [], []
for trial in range(n_trials):
n_min = min(pred_all.shape[0], gt_all.shape[0])
perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
# 2D metrics (PC1-PC2)
m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])
# Full-dimensional metrics (all PCs)
m_full = compute_distribution_distances(pred_all[perm_pred], gt_all[perm_gt])
w1_trials_full.append(m_full["W1"]); w2_trials_full.append(m_full["W2"]); mmd_trials_full.append(m_full["MMD"])
w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
w1_mean_f, w1_std_f = np.mean(w1_trials_full), np.std(w1_trials_full, ddof=1)
w2_mean_f, w2_std_f = np.mean(w2_trials_full), np.std(w2_trials_full, ddof=1)
mmd_mean_f, mmd_std_f = np.mean(mmd_trials_full), np.std(mmd_trials_full, ddof=1)
self.log("test/W1_t7_combined", w1_mean, on_epoch=True)
self.log("test/W2_t7_combined", w2_mean, on_epoch=True)
self.log("test/MMD_t7_combined", mmd_mean, on_epoch=True)
self.log("test/W1_full_t7_combined", w1_mean_f, on_epoch=True)
self.log("test/W2_full_t7_combined", w2_mean_f, on_epoch=True)
self.log("test/MMD_full_t7_combined", mmd_mean_f, on_epoch=True)
metrics_dict['t7_combined'] = {
"W1_mean": float(w1_mean), "W1_std": float(w1_std),
"W2_mean": float(w2_mean), "W2_std": float(w2_std),
"MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
"W1_full_mean": float(w1_mean_f), "W1_full_std": float(w1_std_f),
"W2_full_mean": float(w2_mean_f), "W2_full_std": float(w2_std_f),
"MMD_full_mean": float(mmd_mean_f), "MMD_full_std": float(mmd_std_f),
"n_trials": n_trials,
}
print(f"\n=== Combined @ t7 ===")
print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")
print(f"W1 (full): {w1_mean_f:.6f} ± {w1_std_f:.6f}")
print(f"W2 (full): {w2_mean_f:.6f} ± {w2_std_f:.6f}")
print(f"MMD (full): {mmd_mean_f:.6f} ± {mmd_std_f:.6f}")
# Create results directory structure
run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
results_dir = os.path.join(self.args.working_dir, 'results', run_name)
figures_dir = f'{results_dir}/figures'
os.makedirs(figures_dir, exist_ok=True)
# Save metrics to JSON
metrics_path = f'{results_dir}/metrics.json'
with open(metrics_path, 'w') as f:
json.dump(metrics_dict, f, indent=2)
print(f"Metrics saved to {metrics_path}")
# Save detailed metrics to CSV
detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
with open(detailed_csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['Metric_Group',
'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std',
'W1_Full_Mean', 'W1_Full_Std', 'W2_Full_Mean', 'W2_Full_Std', 'MMD_Full_Mean', 'MMD_Full_Std'])
for key in sorted(metrics_dict.keys()):
m = metrics_dict[key]
writer.writerow([key,
f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}',
f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}',
f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}',
f'{m.get("W1_full_mean", 0):.6f}', f'{m.get("W1_full_std", 0):.6f}',
f'{m.get("W2_full_mean", 0):.6f}', f'{m.get("W2_full_std", 0):.6f}',
f'{m.get("MMD_full_mean", 0):.6f}', f'{m.get("MMD_full_std", 0):.6f}'])
print(f"Detailed metrics CSV saved to {detailed_csv_path}")
# ===== Plot branches =====
self._plot_veres_branches(all_trajs, timepoint_data, figures_dir, n_branches)
self._plot_veres_combined(all_trajs, timepoint_data, figures_dir, n_branches)
print(f"Veres figures saved to {figures_dir}")
def _plot_veres_branches(self, all_trajs, timepoint_data, save_dir, n_branches):
"""Plot each branch separately in PCA space (PC1 vs PC2)."""
branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9',
'#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8',
'#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8',
'#64B5F6', '#C5E1A5']
# Project to first 2 PCs (data is already in PCA space)
t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2]
t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)]
# Slice trajectories to first 2 PCs
trajs_2d = []
for traj in all_trajs:
trajs_2d.append(traj.cpu().numpy()[:, :, :2]) # [n_samples, n_time, 2]
# Compute global axis limits
all_coords = [t0_2d] + t7_2d
for traj_2d in trajs_2d:
all_coords.append(traj_2d.reshape(-1, 2))
all_coords = np.concatenate(all_coords, axis=0)
x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()
x_margin = 0.05 * (x_max - x_min)
y_margin = 0.05 * (y_max - y_min)
x_min -= x_margin
x_max += x_margin
y_min -= y_margin
y_max += y_margin
for i, traj_2d in enumerate(trajs_2d):
fig, ax = plt.subplots(figsize=(10, 8))
c_end = branch_colors[i % len(branch_colors)]
# Plot timepoint background
ax.scatter(t0_2d[:, 0], t0_2d[:, 1],
c='#05009E', s=80, alpha=0.4, marker='x',
label='t=0 cells', linewidth=1.5)
ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1],
c=c_end, s=80, alpha=0.4, marker='x',
label=f't=7 (branch {i+1}) cells', linewidth=1.5)
# Plot continuous trajectories with LineCollection for speed
cmap_colors = ["#05009E", "#A19EFF", c_end]
cmap = LinearSegmentedColormap.from_list(f"veres_cmap_{i}", cmap_colors)
n_time = traj_2d.shape[1]
segments = []
seg_colors = []
color_vals = cmap(np.linspace(0, 1, n_time))
for j in range(traj_2d.shape[0]):
pts = traj_2d[j] # [T, 2]
segs = np.stack([pts[:-1], pts[1:]], axis=1) # [T-1, 2, 2]
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
ax.add_collection(lc)
# Start and end points
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=30, marker='o', label='Trajectory start (t=0)',
zorder=5, edgecolors='white', linewidth=1)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=30, marker='o', label='Trajectory end (t=1)',
zorder=5, edgecolors='white', linewidth=1)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
ax.set_xlabel("PC 1", fontsize=12)
ax.set_ylabel("PC 2", fontsize=12)
ax.set_title(f"Branch {i+1}: Trajectories (PCA)", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=9, frameon=False)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
plt.close()
def _plot_veres_combined(self, all_trajs, timepoint_data, save_dir, n_branches):
"""Plot all branches together in PCA space (PC1 vs PC2)."""
branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9',
'#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8',
'#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8',
'#64B5F6', '#C5E1A5']
# Project to first 2 PCs (data is already in PCA space)
t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2]
t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)]
# Slice trajectories to first 2 PCs
trajs_2d = []
for traj in all_trajs:
trajs_2d.append(traj.cpu().numpy()[:, :, :2]) # [n_samples, n_time, 2]
# Compute axis limits from REAL CELLS ONLY
all_coords_real = [t0_2d] + t7_2d
all_coords_real = np.concatenate(all_coords_real, axis=0)
x_min, x_max = all_coords_real[:, 0].min(), all_coords_real[:, 0].max()
y_min, y_max = all_coords_real[:, 1].min(), all_coords_real[:, 1].max()
x_margin = 0.05 * (x_max - x_min)
y_margin = 0.05 * (y_max - y_min)
x_min -= x_margin
x_max += x_margin
y_min -= y_margin
y_max += y_margin
fig, ax = plt.subplots(figsize=(14, 12))
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
# Plot t=0 cells
ax.scatter(t0_2d[:, 0], t0_2d[:, 1],
c='#05009E', s=60, alpha=0.3, marker='x',
label='t=0 cells', linewidth=1.5)
# Plot each branch's cells and trajectories
for i, traj_2d in enumerate(trajs_2d):
c_end = branch_colors[i % len(branch_colors)]
# Plot t=7 cells for this branch
ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1],
c=c_end, s=60, alpha=0.3, marker='x',
label=f't=7 (branch {i+1})', linewidth=1.5)
# Plot continuous trajectories with LineCollection for speed
cmap_colors = ["#05009E", "#A19EFF", c_end]
cmap = LinearSegmentedColormap.from_list(f"veres_combined_cmap_{i}", cmap_colors)
n_time = traj_2d.shape[1]
segments = []
seg_colors = []
color_vals = cmap(np.linspace(0, 1, n_time))
for j in range(traj_2d.shape[0]):
pts = traj_2d[j] # [T, 2]
segs = np.stack([pts[:-1], pts[1:]], axis=1) # [T-1, 2, 2]
segments.append(segs)
seg_colors.append(color_vals[:-1])
segments = np.concatenate(segments, axis=0)
seg_colors = np.concatenate(seg_colors, axis=0)
lc = LineCollection(segments, colors=seg_colors, linewidths=1.5, alpha=0.6)
ax.add_collection(lc)
# Start and end points
ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
c='#05009E', s=20, marker='o',
zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7)
ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
c=c_end, s=20, marker='o',
zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7)
ax.set_xlabel("PC 1", fontsize=14)
ax.set_ylabel("PC 2", fontsize=14)
ax.set_title(f"All {n_branches} Branch Trajectories (Veres) - PCA Projection",
fontsize=16, weight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=10, frameon=False, ncol=2)
plt.tight_layout()
plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
plt.close()