| import sys |
| import os |
| import glob |
| import re |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from omegaconf import OmegaConf |
|
|
| |
# The project modules imported below (training, medarc_architecture, stage2)
# live in the repository's ``src`` directory, one level above this script's
# folder; append it to the import search path.
src_dir = Path(__file__).resolve().parent.parent.joinpath("src")
sys.path.append(f"{src_dir}")
|
|
| from training import make_data_loaders, evaluate_stage1, evaluate_stage2, SUBJECTS |
| from medarc_architecture import MultiSubjectConvLinearEncoder |
| from stage2.CFM import CFM |
|
|
def plot_heatmap(acc_map, title, save_path):
    """
    Plot a flat per-voxel accuracy map as a square, row-wrapped heatmap.

    The values are laid out row by row on the smallest ``side x side`` grid
    that holds them. Unused trailing cells are filled with NaN so they render
    as blank cells.

    Args:
        acc_map: Per-voxel values (plotted with a "Pearson's r" colorbar).
            Accepts a torch tensor on any device/dtype, or anything
            ``np.asarray`` understands. Multi-dimensional input is flattened
            in row-major order.
        title: Figure title; also echoed in the log message.
        save_path: Destination passed to ``Figure.savefig``.
    """
    if isinstance(acc_map, torch.Tensor):
        # np.asarray cannot read CUDA tensors, tensors that require grad, or
        # bfloat16 tensors, so detach, move to host, and upcast first.
        acc_map = acc_map.detach().cpu().float().numpy()
    # Cast to float: integer/bool arrays cannot hold the NaN padding below.
    values = np.asarray(acc_map, dtype=float).ravel()
    n_voxels = values.size

    # Use at least a 1x1 grid so an empty map still yields a valid image.
    side = max(1, int(np.ceil(np.sqrt(n_voxels))))
    padded = np.pad(values, (0, side * side - n_voxels), constant_values=np.nan)
    grid = padded.reshape(side, side)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        im = ax.imshow(grid, cmap="viridis", aspect="auto")
        fig.colorbar(im, ax=ax, label="Pearson's r Correlation")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Always release the figure, even if saving fails. This stops callers
        # that plot in a loop from accumulating open figures.
        plt.close(fig)
    print(f"Saved heatmap {title} to {save_path}")
|
|
# Stage 2 checkpoint filenames look like ``stage2_epoch_<N>.pt``. The dot is
# escaped so that only a literal ".pt" suffix matches.
_STAGE2_CKPT_RE = re.compile(r"stage2_epoch_(\d+)\.pt$")

# Run directory used when main() is called without an explicit one.
_DEFAULT_RUN_DIR = Path("/workspace/code/flow_matching") / "output" / "two_stage_encoding"


def _checkpoint_epoch(path):
    """Return the epoch encoded in a Stage 2 checkpoint filename, or -1 if none."""
    match = _STAGE2_CKPT_RE.search(path.name)
    return int(match.group(1)) if match else -1


def _mean_r(acc_map):
    """Return the mean of a per-voxel accuracy map (torch tensor or array-like)."""
    if isinstance(acc_map, torch.Tensor):
        return acc_map.mean().item()
    return np.mean(acc_map)


def _report_subject_maps(metrics, subjects, label, title_prefix, file_prefix, heatmap_dir):
    """
    Print each subject's mean Pearson's r and save that subject's heatmap.

    Args:
        metrics: Mapping holding one ``accmap_sub-<sub>`` entry per subject.
        subjects: Subject identifiers to report.
        label: Prefix for the printed line, e.g. ``"Stage 1"``.
        title_prefix: Prefix for the heatmap title.
        file_prefix: Prefix for the image filename inside ``heatmap_dir``.
        heatmap_dir: Directory that receives ``<file_prefix>_sub<sub>.png``.
    """
    for sub in subjects:
        acc_map = metrics[f"accmap_sub-{sub}"]
        print(f"{label} - Sub {sub} Mean Pearson's r: {_mean_r(acc_map):.4f}")
        plot_heatmap(
            acc_map,
            f"{title_prefix} - Sub {sub}",
            heatmap_dir / f"{file_prefix}_sub{sub}.png",
        )


def main(run_dir=None):
    """
    Evaluate two-stage encoder checkpoints and save per-subject accuracy heatmaps.

    The pipeline runs in this order:

    1. Load ``config.yaml`` and ``stage1_best.pt`` from the run directory.
    2. Evaluate Stage 1 on the configured validation set.
    3. Evaluate every ``stage2_epoch_<N>.pt`` checkpoint, in ascending epoch
       order, together with the Stage 1 model.

    Each evaluation prints the overall and per-subject mean Pearson's r and
    writes heatmaps to ``<run_dir>/heatmaps``.

    Args:
        run_dir: Training output directory. Defaults to
            ``/workspace/code/flow_matching/output/two_stage_encoding``.

    If the config, the Stage 1 checkpoint, the validation batches, or the
    Stage 2 checkpoints are missing, a message is printed and the function
    returns early.
    """
    run_dir = _DEFAULT_RUN_DIR if run_dir is None else Path(run_dir)
    config_path = run_dir / "config.yaml"

    if not config_path.exists():
        print(f"Configuration file not found at {config_path}!")
        return

    cfg = OmegaConf.load(config_path)
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Building DataLoaders...")
    data_loaders = make_data_loaders(cfg)
    val_loader = data_loaders[cfg.val_set_name]

    # One batch is enough to infer the feature widths and the voxel count.
    sample_batch = next(iter(val_loader), None)
    if sample_batch is None:
        print(f"Validation loader '{cfg.val_set_name}' yielded no batches!")
        return
    feat_dims = [f.shape[-1] for f in sample_batch["features"]]
    target_dim = sample_batch["fmri"].shape[-1]
    subjects_list = cfg.get("subjects", SUBJECTS)

    print(f"Found target Voxel bounds (V): {target_dim}")
    print("--- Stage 1 Checkout ---")
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)

    stage1_ckpt = run_dir / "stage1_best.pt"
    if not stage1_ckpt.exists():
        print(f"Cannot find Stage 1 checkpoint at {stage1_ckpt}")
        return

    print(f"Loading Stage 1 Mean Anchor from {stage1_ckpt.name}...")
    stage1_model.load_state_dict(torch.load(stage1_ckpt, map_location=device))

    acc_s1, metrics_s1 = evaluate_stage1(
        epoch=0,
        model=stage1_model,
        val_loader=val_loader,
        device=device,
        subjects=subjects_list,
        ds_name=cfg.val_set_name,
    )

    heatmap_dir = run_dir / "heatmaps"
    heatmap_dir.mkdir(exist_ok=True, parents=True)

    print(f"Stage 1 Overall Pearson's r: {acc_s1:.4f}")
    _report_subject_maps(
        metrics_s1, subjects_list,
        label="Stage 1",
        title_prefix="Stage 1 Best (Mean Anchor)",
        file_prefix="stage1",
        heatmap_dir=heatmap_dir,
    )

    print("\n--- Stage 2 Checkout ---")
    # Discover checkpoints before building the per-subject models, so no work
    # is wasted when there is nothing to evaluate.
    stage2_ckpts = sorted(run_dir.glob("stage2_epoch_*.pt"), key=_checkpoint_epoch)
    if not stage2_ckpts:
        print("No Stage 2 configurations found to visualize!")
        return

    # One CFM per subject, keyed by the subject id as a string. Every
    # checkpoint below is loaded into this same ModuleDict.
    stage2_models = nn.ModuleDict()
    for sub in subjects_list:
        stage2_models[str(sub)] = CFM(
            feat_dim=target_dim,
            cfm_params=cfg.stage2.cfm,
            velocity_net_params=cfg.stage2.velocity_net,
            source_ve_params=cfg.stage2.source_ve,
            transport_params=cfg.stage2.transport,
        ).to(device)

    n_timesteps = cfg.stage2.get("n_timesteps", 25)
    for ckpt in stage2_ckpts:
        ep = _checkpoint_epoch(ckpt)
        print(f"\nProcessing Vector Field {ckpt.name}...")

        stage2_models.load_state_dict(torch.load(ckpt, map_location=device))

        acc_s2, metrics_s2 = evaluate_stage2(
            epoch=ep,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            val_loader=val_loader,
            device=device,
            subjects=subjects_list,
            ds_name=cfg.val_set_name,
            n_timesteps=n_timesteps,
        )

        print(f"Stage 2 Epoch {ep} Overall Pearson's r: {acc_s2:.4f}")
        _report_subject_maps(
            metrics_s2, subjects_list,
            label=f"Stage 2 Epoch {ep}",
            title_prefix=f"Stage 2 Epoch {ep}",
            file_prefix=f"stage2_ep{ep}",
            heatmap_dir=heatmap_dir,
        )

    print("\nVisualizations effectively correlated & processed! 🧠")
|
|
if __name__ == "__main__":
    # Run the evaluation/visualization pipeline only when executed as a
    # script, not when this module is imported.
    main()
|
|