"""Visualise per-voxel Pearson correlation maps for a two-stage encoding run.

Loads the stage 1 checkpoint and every ``stage2_epoch_*.pt`` checkpoint found
in the run directory, evaluates them on the validation loader and writes one
heatmap PNG per subject (and per stage 2 epoch) under ``<run_dir>/heatmaps``.
"""

import glob
import os
import re
import sys
from pathlib import Path

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf

# Add src to the system path so we can import architecture and metrics
src_dir = Path(__file__).resolve().parent.parent / "src"
sys.path.append(str(src_dir))

from training import make_data_loaders, evaluate_stage1, evaluate_stage2, SUBJECTS
from medarc_architecture import MultiSubjectConvLinearEncoder
from stage2.CFM import CFM

# Matches the whole file name so that look-alikes such as
# "stage2_epoch_3xpt.pt" are not mistaken for epoch 3.
_STAGE2_CKPT_RE = re.compile(r"stage2_epoch_(\d+)\.pt")


def _as_flat_float_array(acc_map):
    """Return ``acc_map`` as a flat float NumPy array.

    Accepts either a ``torch.Tensor`` (on any device, with or without grad)
    or anything ``np.asarray`` understands. A float dtype is required because
    the heatmap padding uses NaN, which integer arrays cannot represent.
    """
    if isinstance(acc_map, torch.Tensor):
        # detach + cpu + float: NumPy cannot view CUDA, grad-tracking or
        # bfloat16 tensors directly.
        acc_map = acc_map.detach().cpu().float().numpy()
    return np.asarray(acc_map, dtype=float).reshape(-1)


def _mean_of(acc_map):
    """Return the scalar mean of a tensor or array accuracy map."""
    if isinstance(acc_map, torch.Tensor):
        return acc_map.mean().item()
    return np.mean(acc_map)


def _stage2_epoch(ckpt_path):
    """Return the epoch encoded in a stage 2 checkpoint name, or -1 if absent."""
    match = _STAGE2_CKPT_RE.fullmatch(ckpt_path.name)
    return int(match.group(1)) if match else -1


def _report_subject_maps(metrics, subjects, log_prefix, title_prefix, file_prefix, heatmap_dir):
    """Print the mean correlation and save a heatmap for every subject.

    Args:
        metrics: Mapping holding one ``"accmap_sub-<subject>"`` entry per subject.
        subjects: Iterable of subject identifiers.
        log_prefix: Text placed before ``" - Sub <subject> Mean Pearson's r"``.
        title_prefix: Text placed before ``" - Sub <subject>"`` in the plot title.
        file_prefix: Text placed before ``"_sub<subject>.png"`` in the file name.
        heatmap_dir: Directory (``Path``) receiving the PNG files.
    """
    for sub in subjects:
        acc_map = metrics[f"accmap_sub-{sub}"]
        mean_r = _mean_of(acc_map)
        print(f"{log_prefix} - Sub {sub} Mean Pearson's r: {mean_r:.4f}")
        plot_heatmap(
            acc_map,
            f"{title_prefix} - Sub {sub}",
            heatmap_dir / f"{file_prefix}_sub{sub}.png",
        )


def plot_heatmap(acc_map, title: str, save_path) -> None:
    """Plot a flat accuracy map as a 2D wrapped heatmap grid and save it.

    The values are laid out row by row on the smallest square grid that can
    hold them; unused trailing cells are filled with NaN.

    Args:
        acc_map: Per-voxel values as a ``torch.Tensor`` or array-like. It is
            flattened before plotting.
        title: Figure title, also echoed in the confirmation message.
        save_path: Destination passed to ``plt.savefig``.
    """
    values = _as_flat_float_array(acc_map)
    n_values = values.shape[0]
    if n_values == 0:
        # A 0x0 grid cannot be rendered; skip instead of aborting the whole run.
        print(f"Skipping heatmap {title}: accuracy map is empty")
        return

    # Smallest square that fits all values (e.g., 1000 -> 32 x 32 grid)
    side = int(np.ceil(np.sqrt(n_values)))
    padded = np.pad(values, (0, side * side - n_values), constant_values=np.nan)
    grid = padded.reshape(side, side)

    fig = plt.figure(figsize=(8, 6))
    try:
        # Colour limits are left to auto-scale to the data range.
        im = plt.imshow(grid, cmap="viridis", aspect="auto")
        plt.colorbar(im, label="Pearson's r Correlation")
        plt.title(title)
        plt.tight_layout()
        plt.savefig(save_path, dpi=150)
    finally:
        # Always release the figure: this runs once per subject per checkpoint.
        plt.close(fig)
    print(f"Saved heatmap {title} to {save_path}")


def main():
    """Evaluate the stage 1 and all stage 2 checkpoints and save their heatmaps.

    Reads ``config.yaml`` from the hard-coded run directory, rebuilds the
    models from that configuration and evaluates them on the validation
    loader named by ``cfg.val_set_name``. Returns early (after printing a
    message) when the config, the stage 1 checkpoint or the stage 2
    checkpoints are missing.
    """
    root_dir = Path("/workspace/code/flow_matching")
    # Assuming standard directory where training script drops them
    run_dir = root_dir / "output" / "two_stage_encoding"
    config_path = run_dir / "config.yaml"

    if not config_path.exists():
        print(f"Configuration file not found at {config_path}!")
        return

    cfg = OmegaConf.load(config_path)
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Building DataLoaders...")
    data_loaders = make_data_loaders(cfg)
    val_loader = data_loaders[cfg.val_set_name]

    # One batch is enough to discover the feature and target dimensions.
    sample_batch = next(iter(val_loader))
    feat_dims = [f.shape[-1] for f in sample_batch["features"]]
    target_dim = sample_batch["fmri"].shape[-1]
    subjects_list = cfg.get("subjects", SUBJECTS)
    print(f"Found target Voxel bounds (V): {target_dim}")

    print("--- Stage 1 Checkout ---")
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model
    ).to(device)

    stage1_ckpt = run_dir / "stage1_best.pt"
    if not stage1_ckpt.exists():
        print(f"Cannot find Stage 1 checkpoint at {stage1_ckpt}")
        return

    print(f"Loading Stage 1 Mean Anchor from {stage1_ckpt.name}...")
    stage1_model.load_state_dict(torch.load(stage1_ckpt, map_location=device))
    # New modules start in training mode; switch explicitly so dropout/norm
    # layers behave deterministically. Harmless if evaluate_stage1 does it too.
    stage1_model.eval()

    # Execute Stage 1 evaluate to compute Pearson map
    acc_s1, metrics_s1 = evaluate_stage1(
        epoch=0,
        model=stage1_model,
        val_loader=val_loader,
        device=device,
        subjects=subjects_list,
        ds_name=cfg.val_set_name
    )

    heatmap_dir = run_dir / "heatmaps"
    heatmap_dir.mkdir(exist_ok=True, parents=True)

    print(f"Stage 1 Overall Pearson's r: {acc_s1:.4f}")

    # Visualize stage 1 mean anchors
    _report_subject_maps(
        metrics_s1,
        subjects_list,
        log_prefix="Stage 1",
        title_prefix="Stage 1 Best (Mean Anchor)",
        file_prefix="stage1",
        heatmap_dir=heatmap_dir,
    )

    print("\n--- Stage 2 Checkout ---")
    cfm_params = cfg.stage2.cfm
    velocity_net_params = cfg.stage2.velocity_net
    source_ve_params = cfg.stage2.source_ve
    transport_params = cfg.stage2.transport

    # One CFM per subject, keyed by the subject id as a string.
    stage2_models = nn.ModuleDict()
    for sub in subjects_list:
        stage2_models[str(sub)] = CFM(
            feat_dim=target_dim,
            cfm_params=cfm_params,
            velocity_net_params=velocity_net_params,
            source_ve_params=source_ve_params,
            transport_params=transport_params,
        ).to(device)
    # load_state_dict does not change the mode, so setting it once suffices.
    stage2_models.eval()

    # Locate and sort all consecutive stage2 weights; names without a numeric
    # epoch sort first with epoch -1.
    stage2_ckpts = sorted(run_dir.glob("stage2_epoch_*.pt"), key=_stage2_epoch)

    if not stage2_ckpts:
        print("No Stage 2 configurations found to visualize!")
        return

    # Evaluate each stage 2 checkpoint consecutively and map visualizations
    for ckpt in stage2_ckpts:
        ep = _stage2_epoch(ckpt)
        print(f"\nProcessing Vector Field {ckpt.name}...")
        stage2_models.load_state_dict(torch.load(ckpt, map_location=device))

        acc_s2, metrics_s2 = evaluate_stage2(
            epoch=ep,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            val_loader=val_loader,
            device=device,
            subjects=subjects_list,
            ds_name=cfg.val_set_name,
            n_timesteps=cfg.stage2.get("n_timesteps", 25)
        )

        print(f"Stage 2 Epoch {ep} Overall Pearson's r: {acc_s2:.4f}")

        _report_subject_maps(
            metrics_s2,
            subjects_list,
            log_prefix=f"Stage 2 Epoch {ep}",
            title_prefix=f"Stage 2 Epoch {ep}",
            file_prefix=f"stage2_ep{ep}",
            heatmap_dir=heatmap_dir,
        )

    print("\nVisualizations effectively correlated & processed! 🧠")


if __name__ == "__main__":
    main()