File size: 5,916 Bytes
4edc9aa 0254260 4edc9aa 0254260 4edc9aa 0254260 4edc9aa 0254260 4edc9aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | import sys
import os
import glob
import re
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
# Make the ``src`` directory that sits next to this script's parent folder
# importable, so the ``training``, ``medarc_architecture`` and ``stage2``
# imports below can be resolved when the script is run directly.
src_dir = Path(__file__).resolve().parent.parent.joinpath("src")
sys.path.append(os.fspath(src_dir))
from training import make_data_loaders, evaluate_stage1, evaluate_stage2, SUBJECTS
from medarc_architecture import MultiSubjectConvLinearEncoder
from stage2.CFM import CFM
def plot_heatmap(acc_map, title, save_path, vmin=None, vmax=None, cmap="viridis"):
    """
    Plot a flat per-voxel map as a square, row-wrapped 2D heatmap and save it.

    The values are laid out row-major into the smallest ``side x side`` grid
    that can hold them (e.g. 1000 values -> 32 x 32). Unused trailing cells
    are filled with NaN, so they are left blank in the image.

    Args:
        acc_map: Per-voxel values (e.g. Pearson's r) as an array-like or a
            ``torch.Tensor``. Tensors are detached, cast to float32 and moved
            to the CPU first, so CUDA / bf16 / grad-tracking tensors work.
            Multi-dimensional input is flattened row-major.
        title: Figure title; also echoed in the log line.
        save_path: Destination passed to ``Figure.savefig``.
        vmin, vmax: Optional colour-scale limits. The default ``None`` lets
            matplotlib auto-scale, which was the original behaviour.
        cmap: Matplotlib colormap name (default ``"viridis"``).

    Raises:
        ValueError: If ``acc_map`` contains no values.
    """
    if isinstance(acc_map, torch.Tensor):
        # np.pad cannot consume CUDA tensors, tensors requiring grad, or
        # bfloat16 tensors, so convert to a float32 CPU array first.
        acc_map = acc_map.detach().float().cpu().numpy()
    # A float dtype is required for NaN padding (integer input would fail).
    values = np.asarray(acc_map, dtype=np.float64).ravel()

    num_voxels = values.shape[0]
    if num_voxels == 0:
        raise ValueError(f"Cannot plot heatmap {title!r}: acc_map is empty")

    # Smallest square grid that holds every voxel.
    side = int(np.ceil(np.sqrt(num_voxels)))
    padded = np.pad(values, (0, side * side - num_voxels), constant_values=np.nan)
    grid = padded.reshape(side, side)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        im = ax.imshow(grid, cmap=cmap, aspect="auto", vmin=vmin, vmax=vmax)
        fig.colorbar(im, ax=ax, label="Pearson's r Correlation")
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # in a loop do not accumulate open figures.
        plt.close(fig)
    print(f"Saved heatmap {title} to {save_path}")
def _stage2_epoch(path):
    """
    Return the epoch encoded in a ``stage2_epoch_<N>.pt`` file name, or -1.

    The whole name must match and the dot before ``pt`` is escaped, so names
    such as ``stage2_epoch_best.pt`` yield -1 rather than a bogus epoch.
    """
    match = re.fullmatch(r"stage2_epoch_(\d+)\.pt", path.name)
    return int(match.group(1)) if match else -1


def _mean_r(acc_map):
    """Return the mean of a per-voxel map given as a torch tensor or array-like."""
    if isinstance(acc_map, torch.Tensor):
        return acc_map.mean().item()
    return np.mean(acc_map)


def _report_accmaps(metrics, subjects, label, title_prefix, file_prefix, heatmap_dir):
    """
    Print each subject's mean Pearson's r and save that subject's heatmap.

    Reads ``metrics[f"accmap_sub-{sub}"]`` for every subject and writes
    ``<heatmap_dir>/<file_prefix>_sub<sub>.png``.
    """
    for sub in subjects:
        acc_map = metrics[f"accmap_sub-{sub}"]
        print(f"{label} - Sub {sub} Mean Pearson's r: {_mean_r(acc_map):.4f}")
        plot_heatmap(
            acc_map,
            f"{title_prefix} - Sub {sub}",
            heatmap_dir / f"{file_prefix}_sub{sub}.png",
        )


def main(run_dir=None):
    """
    Evaluate a two-stage encoding run and save per-subject Pearson's r heatmaps.

    The run is evaluated as follows:

    - The Stage 1 model is loaded from ``stage1_best.pt`` and evaluated.
    - Every ``stage2_epoch_<N>.pt`` checkpoint is loaded into a per-subject
      ``CFM`` ModuleDict and evaluated, in ascending epoch order.
    - All heatmaps are written to ``<run_dir>/heatmaps``.

    Args:
        run_dir: Training output directory (str or Path) containing
            ``config.yaml`` and the checkpoints. Defaults to
            ``/workspace/code/flow_matching/output/two_stage_encoding``.

    Returns:
        None. Prints a message and returns early if any of these is missing:
        the config file, the Stage 1 checkpoint, or every Stage 2 checkpoint
        (in the last case, only after the Stage 1 heatmaps have been written).
    """
    if run_dir is None:
        # Default location where the training script drops its outputs.
        run_dir = Path("/workspace/code/flow_matching") / "output" / "two_stage_encoding"
    run_dir = Path(run_dir)

    config_path = run_dir / "config.yaml"
    if not config_path.exists():
        print(f"Configuration file not found at {config_path}!")
        return
    # Fail fast, before building data loaders and models, if the mandatory
    # Stage 1 checkpoint is absent.
    stage1_ckpt = run_dir / "stage1_best.pt"
    if not stage1_ckpt.exists():
        print(f"Cannot find Stage 1 checkpoint at {stage1_ckpt}")
        return

    cfg = OmegaConf.load(config_path)
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Building DataLoaders...")
    data_loaders = make_data_loaders(cfg)
    val_loader = data_loaders[cfg.val_set_name]
    # Peek at one validation batch only to infer the model input/output sizes.
    sample_batch = next(iter(val_loader))
    feat_dims = [f.shape[-1] for f in sample_batch["features"]]
    target_dim = sample_batch["fmri"].shape[-1]
    subjects_list = cfg.get("subjects", SUBJECTS)
    print(f"Found target Voxel bounds (V): {target_dim}")

    heatmap_dir = run_dir / "heatmaps"
    heatmap_dir.mkdir(exist_ok=True, parents=True)

    # ---------------- Stage 1 ----------------
    print("--- Stage 1 Checkout ---")
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)
    print(f"Loading Stage 1 Mean Anchor from {stage1_ckpt.name}...")
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. torch>=1.13 supports weights_only=True for plain state dicts.
    stage1_model.load_state_dict(torch.load(stage1_ckpt, map_location=device))
    acc_s1, metrics_s1 = evaluate_stage1(
        epoch=0,
        model=stage1_model,
        val_loader=val_loader,
        device=device,
        subjects=subjects_list,
        ds_name=cfg.val_set_name,
    )
    print(f"Stage 1 Overall Pearson's r: {acc_s1:.4f}")
    _report_accmaps(
        metrics_s1,
        subjects_list,
        "Stage 1",
        "Stage 1 Best (Mean Anchor)",
        "stage1",
        heatmap_dir,
    )

    # ---------------- Stage 2 ----------------
    print("\n--- Stage 2 Checkout ---")
    stage2_ckpts = []
    for ckpt in run_dir.glob("stage2_epoch_*.pt"):
        if _stage2_epoch(ckpt) < 0:
            # Non-numeric names would all map to epoch -1 and overwrite each
            # other's heatmaps, so leave them out.
            print(f"Skipping {ckpt.name}: file name has no numeric epoch")
            continue
        stage2_ckpts.append(ckpt)
    # Sort numerically; lexicographic order would put epoch 10 before epoch 2.
    stage2_ckpts.sort(key=_stage2_epoch)
    if not stage2_ckpts:
        print("No Stage 2 configurations found to visualize!")
        return

    # One CFM per subject, keyed by str(subject), matching the checkpoint layout
    # the original script loaded into.
    stage2_models = nn.ModuleDict()
    for sub in subjects_list:
        stage2_models[str(sub)] = CFM(
            feat_dim=target_dim,
            cfm_params=cfg.stage2.cfm,
            velocity_net_params=cfg.stage2.velocity_net,
            source_ve_params=cfg.stage2.source_ve,
            transport_params=cfg.stage2.transport,
        ).to(device)
    n_timesteps = cfg.stage2.get("n_timesteps", 25)

    for ckpt in stage2_ckpts:
        ep = _stage2_epoch(ckpt)
        print(f"\nProcessing Vector Field {ckpt.name}...")
        stage2_models.load_state_dict(torch.load(ckpt, map_location=device))
        acc_s2, metrics_s2 = evaluate_stage2(
            epoch=ep,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            val_loader=val_loader,
            device=device,
            subjects=subjects_list,
            ds_name=cfg.val_set_name,
            n_timesteps=n_timesteps,
        )
        print(f"Stage 2 Epoch {ep} Overall Pearson's r: {acc_s2:.4f}")
        _report_accmaps(
            metrics_s2,
            subjects_list,
            f"Stage 2 Epoch {ep}",
            f"Stage 2 Epoch {ep}",
            f"stage2_ep{ep}",
            heatmap_dir,
        )

    print("\nVisualizations effectively correlated & processed! 🧠")
if __name__ == "__main__":
    # Script entry point: run the full evaluation and heatmap pipeline.
    main()
|