# File: flow-matching/test/check_pearson.py
# Origin: uploaded by sabertoaster with huggingface_hub (commit 0254260, verified), 5.92 kB
import sys
import os
import glob
import re
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
# Add src to the system path so we can import architecture and metrics
src_dir = Path(__file__).resolve().parent.parent / "src"
sys.path.append(str(src_dir))
from training import make_data_loaders, evaluate_stage1, evaluate_stage2, SUBJECTS
from medarc_architecture import MultiSubjectConvLinearEncoder
from stage2.CFM import CFM
def plot_heatmap(acc_map, title, save_path):
    """Render a flat vector of per-voxel scores as a square heatmap image.

    The values are laid out row by row on the smallest square grid that can
    hold them (e.g. 1000 values -> 32 x 32); the unused trailing cells are
    filled with NaN.

    Args:
        acc_map: Per-voxel scores. May be a NumPy array, a sequence of
            numbers, or a ``torch.Tensor`` on any device; it is flattened
            to one dimension before plotting.
        title: Figure title, also echoed in the confirmation message.
        save_path: Destination handed to ``plt.savefig``.
    """
    # NumPy cannot read CUDA tensors or tensors that track gradients, so
    # bring tensors onto the host explicitly before padding.
    if isinstance(acc_map, torch.Tensor):
        acc_map = acc_map.detach().cpu().float().numpy()
    values = np.asarray(acc_map).reshape(-1)
    # NaN is only representable in floating dtypes; promote anything else.
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(float)

    V = values.shape[0]
    # Smallest square side length whose area covers all V values.
    side = int(np.ceil(np.sqrt(V)))
    # Fill the tail with NaN so the vector reshapes into a full square.
    padded = np.pad(values, (0, side * side - V), constant_values=np.nan)
    grid = padded.reshape(side, side)

    plt.figure(figsize=(8, 6))
    # No vmin/vmax are passed: the colour limits auto-scale to the data.
    im = plt.imshow(grid, cmap="viridis", aspect="auto")
    plt.colorbar(im, label="Pearson's r Correlation")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()
    print(f"Saved heatmap {title} to {save_path}")
# Matches names such as "stage2_epoch_12.pt"; the dot is escaped and the
# pattern is anchored at the end so only a real ".pt" suffix is accepted.
_STAGE2_CKPT_PATTERN = re.compile(r"stage2_epoch_(\d+)\.pt$")


def _checkpoint_epoch(ckpt_path):
    """Return the epoch number encoded in a Stage 2 checkpoint name, or -1."""
    found = _STAGE2_CKPT_PATTERN.search(ckpt_path.name)
    return int(found.group(1)) if found else -1


def _report_subject_maps(metrics, subjects, heatmap_dir, log_label, plot_label, file_stem):
    """Print the mean score and save one heatmap per subject from `metrics`."""
    for sub in subjects:
        acc_map = metrics[f"accmap_sub-{sub}"]
        if isinstance(acc_map, torch.Tensor):
            mean_r = acc_map.mean().item()
        else:
            mean_r = np.mean(acc_map)
        print(f"{log_label} - Sub {sub} Mean Pearson's r: {mean_r:.4f}")
        plot_heatmap(acc_map, f"{plot_label} - Sub {sub}", heatmap_dir / f"{file_stem}_sub{sub}.png")


def main(run_dir=None):
    """Evaluate the saved Stage 1 and Stage 2 checkpoints and plot heatmaps.

    Loads ``config.yaml`` from the run directory, builds the validation
    loader, evaluates ``stage1_best.pt`` and then every
    ``stage2_epoch_*.pt`` checkpoint in ascending epoch order. For each
    evaluation the overall score is printed and one heatmap per subject is
    written to ``<run_dir>/heatmaps``.

    Args:
        run_dir: Directory holding the config and checkpoints. Defaults to
            the location the training script is assumed to write to,
            ``/workspace/code/flow_matching/output/two_stage_encoding``.

    Returns:
        None. Missing config or checkpoints are reported on stdout and end
        the run early.
    """
    if run_dir is None:
        # Assumed standard directory where the training script drops outputs.
        run_dir = Path("/workspace/code/flow_matching") / "output" / "two_stage_encoding"
    run_dir = Path(run_dir)

    config_path = run_dir / "config.yaml"
    if not config_path.exists():
        print(f"Configuration file not found at {config_path}!")
        return

    cfg = OmegaConf.load(config_path)
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Building DataLoaders...")
    data_loaders = make_data_loaders(cfg)
    val_loader = data_loaders[cfg.val_set_name]

    # One batch is enough to read the feature and target widths the models need.
    sample_batch = next(iter(val_loader))
    feat_dims = [f.shape[-1] for f in sample_batch["features"]]
    target_dim = sample_batch["fmri"].shape[-1]
    subjects_list = cfg.get("subjects", SUBJECTS)
    print(f"Found target Voxel bounds (V): {target_dim}")

    print("--- Stage 1 Checkout ---")
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model
    ).to(device)

    stage1_ckpt = run_dir / "stage1_best.pt"
    if not stage1_ckpt.exists():
        print(f"Cannot find Stage 1 checkpoint at {stage1_ckpt}")
        return

    print(f"Loading Stage 1 Mean Anchor from {stage1_ckpt.name}...")
    # NOTE(review): torch.load unpickles the file; only point this script at
    # checkpoints from a trusted training run.
    stage1_model.load_state_dict(torch.load(stage1_ckpt, map_location=device))

    # Stage 1 evaluation yields the overall score plus per-subject maps.
    acc_s1, metrics_s1 = evaluate_stage1(
        epoch=0,
        model=stage1_model,
        val_loader=val_loader,
        device=device,
        subjects=subjects_list,
        ds_name=cfg.val_set_name
    )

    heatmap_dir = run_dir / "heatmaps"
    heatmap_dir.mkdir(exist_ok=True, parents=True)

    print(f"Stage 1 Overall Pearson's r: {acc_s1:.4f}")
    _report_subject_maps(
        metrics_s1,
        subjects_list,
        heatmap_dir,
        log_label="Stage 1",
        plot_label="Stage 1 Best (Mean Anchor)",
        file_stem="stage1",
    )

    print("\n--- Stage 2 Checkout ---")
    cfm_params = cfg.stage2.cfm
    velocity_net_params = cfg.stage2.velocity_net
    source_ve_params = cfg.stage2.source_ve
    transport_params = cfg.stage2.transport

    # One CFM per subject, keyed by the subject id as a string.
    stage2_models = nn.ModuleDict()
    for sub in subjects_list:
        stage2_models[str(sub)] = CFM(
            feat_dim=target_dim,
            cfm_params=cfm_params,
            velocity_net_params=velocity_net_params,
            source_ve_params=source_ve_params,
            transport_params=transport_params,
        ).to(device)

    # Collect every Stage 2 checkpoint and walk them in epoch order.
    stage2_ckpts = sorted(run_dir.glob("stage2_epoch_*.pt"), key=_checkpoint_epoch)
    if not stage2_ckpts:
        print("No Stage 2 configurations found to visualize!")
        return

    for ckpt in stage2_ckpts:
        ep = _checkpoint_epoch(ckpt)
        print(f"\nProcessing Vector Field {ckpt.name}...")
        # The same module container is reused; each checkpoint overwrites its weights.
        stage2_models.load_state_dict(torch.load(ckpt, map_location=device))

        acc_s2, metrics_s2 = evaluate_stage2(
            epoch=ep,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            val_loader=val_loader,
            device=device,
            subjects=subjects_list,
            ds_name=cfg.val_set_name,
            n_timesteps=cfg.stage2.get("n_timesteps", 25)
        )

        print(f"Stage 2 Epoch {ep} Overall Pearson's r: {acc_s2:.4f}")
        _report_subject_maps(
            metrics_s2,
            subjects_list,
            heatmap_dir,
            log_label=f"Stage 2 Epoch {ep}",
            plot_label=f"Stage 2 Epoch {ep}",
            file_stem=f"stage2_ep{ep}",
        )

    print("\nVisualizations effectively correlated & processed! 🧠")
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()