flow-matching / src /submission.py
sabertoaster's picture
Upload folder using huggingface_hub
4edc9aa verified
Raw
History Blame Contribute Delete
17.5 kB
"""
Submission script for 2-stage fMRI encoding with Flow Matching.
Generates predictions for friends-s7 (in-distribution) and ood (out-of-distribution)
test sets. Outputs are saved as .npy and .zip files matching the Algonauts 2025
challenge format.
Usage:
    python -m src.submission --checkpoint-dir output/two_stage_encoding
    python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set ood
    python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set all

Examples:
    # Both test sets (the default, equivalent to --test-set all)
    python -m src.submission --checkpoint-dir output/two_stage_encoding

    # Single test set
    python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set ood

    # Custom output dir and device
    python -m src.submission --checkpoint-dir output/two_stage_encoding --output-dir output/submission --device cpu

    # Specific stage2 checkpoint
    python -m src.submission --checkpoint-dir output/two_stage_encoding --stage2-ckpt stage2_epoch_9.pt
"""
import argparse
import warnings
import zipfile
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from .data import (
Algonauts2025Dataset,
load_sharded_features,
episode_filter,
)
from .stage1.medarc_architecture import MultiSubjectConvLinearEncoder
from .stage2.CFM import CFM
# Fallback dataset root, used when neither --datasets-root nor the config
# supplies one.
DEFAULT_DATA_DIR = Path("/raid/lttung05/fmri_encoder/data")

# Subject IDs used when the config has no "subjects" entry.
SUBJECTS = (1, 2, 3, 5)

# Movie stems that make up the out-of-distribution test set.
OOD_MOVIES = ["chaplin", "mononoke", "passepartout", "planetearth", "pulpfiction", "wot"]

# Exact expected second-layer keys for OOD submission (no extras allowed):
# each OOD movie contributes a "<movie>1" and a "<movie>2" key.
EXPECTED_OOD_KEYS = {f"{movie}{part}" for movie in OOD_MOVIES for part in (1, 2)}

# Exact expected first-layer (subject) keys, formatted as "sub-XX".
EXPECTED_SUBJECTS = {f"sub-{subject:02d}" for subject in SUBJECTS}
def main():
    """Command-line entry point: build the models and write submission files.

    Reads ``config.yaml`` from the checkpoint directory, loads the stimulus
    features and both model stages once, then for every requested test set
    runs inference, validates the result, prints a summary and saves the
    predictions under the output directory.
    """
    parser = argparse.ArgumentParser(description="Generate submission predictions")
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        required=True,
        help="Path to trained model output directory (contains config.yaml, stage1_best.pt, stage2_epoch_*.pt)",
    )
    parser.add_argument(
        "--test-set",
        type=str,
        default="all",
        choices=["friends-s7", "ood", "all"],
        help="Which test set(s) to generate predictions for",
    )
    parser.add_argument(
        "--stage2-ckpt",
        type=str,
        default=None,
        help="Stage 2 checkpoint filename (default: latest)",
    )
    parser.add_argument(
        "--n-timesteps",
        type=int,
        default=None,
        help="Override number of ODE steps for stage 2",
    )
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--datasets-root", type=str, default=None)
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: <checkpoint-dir>/submission)",
    )
    args = parser.parse_args()

    ckpt_dir = Path(args.checkpoint_dir)
    cfg = OmegaConf.load(ckpt_dir / "config.yaml")

    # Precedence for the dataset root: CLI flag, then config, then the default.
    datasets_root = Path(args.datasets_root or cfg.get("datasets_root") or DEFAULT_DATA_DIR)
    device = torch.device(args.device)
    subjects = cfg.get("subjects", list(SUBJECTS))

    # A falsy CLI value (None or 0) falls back to the config, then to 25.
    n_timesteps = args.n_timesteps
    if not n_timesteps:
        n_timesteps = cfg.stage2.get("n_timesteps", 25)

    if args.output_dir:
        out_dir = Path(args.output_dir)
    else:
        out_dir = ckpt_dir / "submission"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Warn if NumPy >= 2.0 (Codabench server uses NumPy 1.x)
    np_version = tuple(int(x) for x in np.__version__.split(".")[:2])
    if np_version[0] >= 2:
        warnings.warn(
            f"NumPy {np.__version__} detected. Codabench requires NumPy < 2.0. "
            f"Submissions saved with NumPy 2.x will cause a formatting error. "
            f"Install numpy<2.0 and re-run.",
            stacklevel=1,
        )

    print(f"Checkpoint dir: {ckpt_dir}")
    print(f"Output dir: {out_dir}")
    print(f"Device: {device}")
    print(f"Subjects: {subjects}")
    print(f"Stage 2 ODE timesteps: {n_timesteps}")

    # --- Load features ---
    print("Loading features...")
    all_features = load_all_features(cfg, datasets_root)

    # --- Build models ---
    print("Building models...")
    stage1_model, stage2_models = build_models(
        cfg, ckpt_dir, all_features, subjects, device, args.stage2_ckpt
    )

    # --- Generate predictions ---
    if args.test_set == "all":
        test_sets = ["friends-s7", "ood"]
    else:
        test_sets = [args.test_set]

    banner = "=" * 60
    for test_set_name in test_sets:
        print(f"\n{banner}")
        print(f"Generating predictions for: {test_set_name}")
        print(f"{banner}")

        fmri_num_samples = load_fmri_num_samples(datasets_root, test_set_name)
        test_loader = make_test_loader(cfg, all_features, fmri_num_samples, test_set_name)
        predictions = run_inference(
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            test_loader=test_loader,
            fmri_num_samples=fmri_num_samples,
            subjects=subjects,
            device=device,
            n_timesteps=n_timesteps,
        )

        # Validate before saving so a malformed submission is never written.
        validate_submission(predictions, test_set_name, fmri_num_samples)
        print_summary(predictions)
        save_predictions(predictions, test_set_name, out_dir)

    print("\nDone!")
def load_all_features(cfg, datasets_root: Path) -> list[dict[str, np.ndarray]]:
    """Load every feature set named in ``cfg.include_features``.

    Each entry of ``cfg.include_features`` has the form ``"<model>/<layer>"``
    and is resolved through ``cfg.features`` to the stored model and layer
    names. For each entry the "friends" and "ood" series are loaded and merged
    into a single episode-keyed dict.

    Args:
        cfg: Run configuration (``include_features``, ``features`` and
            ``stage1.model.global_pool`` are read).
        datasets_root: Dataset root; features are read from its ``features``
            sub-directory.

    Returns:
        One merged feature dict per entry of ``cfg.include_features``, in the
        same order.
    """
    features_dir = datasets_root / "features"
    collected = []

    for feat_name in cfg.include_features:
        model, layer = feat_name.split("/")
        feat_cfg = cfg.features[model]
        model_name = feat_cfg.model
        layer_name = feat_cfg.layers[layer]
        print(f" Loading {feat_name} ({model_name}/{layer_name})")

        # Friends (all seasons including s7) first, then the OOD movies; on a
        # key collision the OOD entry wins.
        merged = {}
        for series in ("friends", "ood"):
            merged.update(
                load_sharded_features(
                    features_dir, model=model_name, layer=layer_name, series=series
                )
            )

        if cfg.stage1.model.global_pool == "avg":
            merged = pool_features(merged)

        collected.append(merged)

    return collected
def pool_features(features: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Average-pool 3-D feature arrays over axis 1, leaving 2-D arrays as is.

    Args:
        features: Mapping from episode key to a 2-D or 3-D feature array.

    Returns:
        A new dict with the same keys. 3-D arrays are replaced by their mean
        over axis 1; 2-D arrays are passed through unchanged (same object).

    Raises:
        ValueError: If an array is neither 2-D nor 3-D. This is an explicit
            check rather than an ``assert`` so that it still runs under
            ``python -O``.
    """
    pooled = {}
    for key, feat in features.items():
        if feat.ndim == 3:
            feat = feat.mean(axis=1)
        elif feat.ndim != 2:
            raise ValueError(
                f"Feature '{key}' has shape {feat.shape}; expected a 2-D or 3-D array"
            )
        pooled[key] = feat
    return pooled
def _stage2_epoch_sort_key(path: Path) -> tuple[int, str]:
    """Sort key ordering stage-2 checkpoints by their numeric epoch.

    Files named ``stage2_epoch_<digits>.pt`` sort by the integer epoch, so
    epoch 10 ranks above epoch 9. Any other name matched by the glob ranks
    below every numeric checkpoint and falls back to name order.
    """
    import re

    match = re.fullmatch(r"stage2_epoch_(\d+)\.pt", path.name)
    epoch = int(match.group(1)) if match else -1
    return epoch, path.name


def build_models(cfg, ckpt_dir: Path, all_features, subjects, device, stage2_ckpt_name=None):
    """Build and load stage 1 and stage 2 models from checkpoints.

    Args:
        cfg: Run configuration; ``stage1.model`` and the ``stage2`` sub-configs
            (``cfm``, ``velocity_net``, ``source_ve``, ``transport``) are read.
        ckpt_dir: Directory holding ``stage1_best.pt`` and the stage-2
            checkpoints.
        all_features: List of episode-keyed feature dicts; the last dimension
            of one arbitrary episode gives each feature set's input size.
        subjects: Subject IDs; one stage-2 model is created per subject and
            keyed by ``str(subject)``.
        device: Device the models are moved to and checkpoints are mapped to.
        stage2_ckpt_name: Stage-2 checkpoint filename inside ``ckpt_dir``.
            When omitted, the ``stage2_epoch_*.pt`` file with the highest
            epoch number is used.

    Returns:
        ``(stage1_model, stage2_models)``, both in eval mode on ``device``.

    Raises:
        FileNotFoundError: If no stage-2 checkpoint name is given and none is
            found in ``ckpt_dir``.
    """
    # Infer feature dimensions from an arbitrary episode
    sample_episode = next(iter(all_features[0]))
    feat_dims = [feats[sample_episode].shape[-1] for feats in all_features]
    print(f" Feature dims: {feat_dims}")

    # --- Stage 1 ---
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)
    stage1_path = ckpt_dir / "stage1_best.pt"
    print(f" Loading stage 1: {stage1_path}")
    stage1_model.load_state_dict(torch.load(stage1_path, map_location=device, weights_only=True))
    stage1_model.eval()

    # --- Stage 2 ---
    target_dim = 1000  # Schaefer atlas parcels
    cfm_params = cfg.stage2.cfm
    velocity_net_params = cfg.stage2.velocity_net
    source_ve_params = cfg.stage2.source_ve
    transport_params = cfg.stage2.transport

    stage2_models = nn.ModuleDict()
    for sub in subjects:
        stage2_models[str(sub)] = CFM(
            feat_dim=target_dim,
            cfm_params=cfm_params,
            velocity_net_params=velocity_net_params,
            source_ve_params=source_ve_params,
            transport_params=transport_params,
        )

    # Find stage 2 checkpoint
    if stage2_ckpt_name:
        stage2_path = ckpt_dir / stage2_ckpt_name
    else:
        # Pick the latest checkpoint by numeric epoch. A plain lexicographic
        # sort would rank "stage2_epoch_9.pt" above "stage2_epoch_10.pt".
        stage2_path = max(
            ckpt_dir.glob("stage2_epoch_*.pt"),
            key=_stage2_epoch_sort_key,
            default=None,
        )
        if stage2_path is None:
            raise FileNotFoundError(f"No stage2 checkpoints found in {ckpt_dir}")

    print(f" Loading stage 2: {stage2_path}")
    stage2_models.load_state_dict(torch.load(stage2_path, map_location=device, weights_only=True))
    stage2_models = stage2_models.to(device)
    stage2_models.eval()

    return stage1_model, stage2_models
def load_fmri_num_samples(datasets_root: Path, test_set_name: str) -> dict[str, dict[str, int]]:
    """Load per-subject, per-episode fMRI sample counts for the test set.

    Searches ``<datasets_root>/algonauts_2025.competitors/fmri`` recursively
    for ``*_<test_set_name>_fmri_samples.npy`` files. The subject key is the
    name of the directory two levels above each file.

    Args:
        datasets_root: Dataset root directory.
        test_set_name: Either ``"friends-s7"`` or ``"ood"``.

    Returns:
        Mapping from subject key to the dict stored in that subject's file.

    Raises:
        ValueError: If ``test_set_name`` is not a known test set.
        FileNotFoundError: If no matching sample-count file is found.
    """
    if test_set_name not in ("friends-s7", "ood"):
        raise ValueError(f"Unknown test set: {test_set_name}")
    # The file-name pattern is the test set name itself.
    file_pattern = test_set_name

    fmri_dir = datasets_root / "algonauts_2025.competitors" / "fmri"
    glob_pattern = f"*_{file_pattern}_fmri_samples.npy"
    sample_paths = sorted(fmri_dir.rglob(glob_pattern))
    if not sample_paths:
        raise FileNotFoundError(
            f"No fmri_samples files found for {test_set_name} in {fmri_dir}. "
            f"Expected pattern: *_{file_pattern}_fmri_samples.npy"
        )

    # path like: .../sub-01/target_sample_number/sub-01_friends-s7_fmri_samples.npy
    # so parents[1] is the subject directory, e.g. "sub-01".
    # NOTE: allow_pickle=True unpickles the file contents; only point this at
    # trusted dataset files.
    fmri_num_samples = {
        path.parents[1].name: np.load(path, allow_pickle=True).item()
        for path in sample_paths
    }

    print(f" Loaded fmri_num_samples for {list(fmri_num_samples.keys())}")
    return fmri_num_samples
def make_test_loader(cfg, all_features, fmri_num_samples, test_set_name):
    """Create a DataLoader for the specified test set.

    The episode list is driven by ``fmri_num_samples``: every episode listed
    for any subject must be present in the features, otherwise a
    ``ValueError`` is raised.

    Args:
        cfg: Run configuration (not read here).
        all_features: List of episode-keyed feature dicts; the first one
            defines which episodes are available.
        fmri_num_samples: Per-subject, per-episode sample counts.
        test_set_name: Either ``"friends-s7"`` or ``"ood"``.

    Returns:
        A batch-size-1 ``DataLoader`` over the selected episodes.

    Raises:
        ValueError: On an unknown test set name, or when an expected episode
            has no features or no positive sample count.
    """
    from torch.utils.data import DataLoader

    all_episodes = list(all_features[0])

    if test_set_name == "friends-s7":
        seasons, movies = [7], []
    elif test_set_name == "ood":
        seasons, movies = [], OOD_MOVIES
    else:
        raise ValueError(f"Unknown test set: {test_set_name}")
    filter_fn = episode_filter(seasons=seasons, movies=movies)

    ds_episodes = sorted(filter(filter_fn, all_episodes))
    print(f" Episodes ({len(ds_episodes)}): {ds_episodes}")

    # Use fmri_num_samples as the authoritative episode list to ensure
    # we produce exactly the required keys (no extras, no missing)
    expected_episodes = set()
    for sub_samples in fmri_num_samples.values():
        expected_episodes.update(sub_samples.keys())

    # Keep filtered episodes that are expected, then add any expected episode
    # that has features but was missed by the filter.
    selected = {ep for ep in ds_episodes if ep in expected_episodes}
    selected |= expected_episodes & set(all_episodes)
    ds_episodes = sorted(selected)

    # Build per-episode max sample count across subjects
    episode_num_samples = {}
    for ep in ds_episodes:
        max_samples = max(counts.get(ep, 0) for counts in fmri_num_samples.values())
        if max_samples != 0:
            episode_num_samples[ep] = max_samples
        else:
            print(f" Warning: no fmri_num_samples for episode {ep}, skipping")
    ds_episodes = [ep for ep in ds_episodes if ep in episode_num_samples]

    missing = expected_episodes - set(ds_episodes)
    if missing:
        raise ValueError(
            f"Missing episodes in features: {missing}. "
            f"Cannot produce complete submission."
        )

    dataset = Algonauts2025Dataset(
        episode_list=ds_episodes,
        feat_data=all_features,
        fmri_num_samples=episode_num_samples,
        sample_length=None,
        shuffle=False,
    )
    return DataLoader(dataset, batch_size=1)
@torch.no_grad()
def run_inference(
    *,
    stage1_model: nn.Module,
    stage2_models: nn.ModuleDict,
    test_loader,
    fmri_num_samples: dict[str, dict[str, int]],
    subjects: list[int],
    device: torch.device,
    n_timesteps: int = 25,
) -> dict[str, dict[str, np.ndarray]]:
    """Run two-stage inference and collect per-subject, per-episode predictions.

    Args:
        stage1_model: Module mapping the list of feature tensors to a 4-D
            tensor whose second axis indexes ``subjects``.
        stage2_models: Per-subject refinement modules keyed by ``str(subject)``;
            each is called as ``model(mu, n_timesteps=...)``.
        test_loader: Iterable of batches with ``"features"`` and ``"episode"``
            entries; the batch size must be 1.
        fmri_num_samples: Per-subject, per-episode sample counts used to
            truncate each prediction.
        subjects: Subject IDs, in the order of stage 1's subject axis.
        device: Device the input features are moved to.
        n_timesteps: Number of steps passed to each stage-2 model.

    Returns:
        ``{"sub-XX": {episode: float32 array}}`` with one entry per subject
        and episode.
    """
    stage1_model.eval()
    stage2_models.eval()

    subject_keys = [f"sub-{sub:02d}" for sub in subjects]
    submission = {key: {} for key in subject_keys}

    for processed, batch in enumerate(test_loader, start=1):
        feats = [f.to(device) for f in batch["features"]]
        episodes = batch["episode"]

        # Stage 1: mean anchor prediction (N, S, T, V)
        mu_anchor = stage1_model(feats)
        batch_size, _, _, _ = mu_anchor.shape
        assert batch_size == 1, "Batch size must be 1 for submission"

        # Stage 2: per-subject flow matching refinement
        refined = []
        for sub_idx, sub in enumerate(subjects):
            cfm = stage2_models[str(sub)]
            # mu for this subject: (N, T, V) -> (N, V, T) for CFM
            mu = mu_anchor[:, sub_idx].transpose(1, 2)
            # Flow matching inference returns (N, V, T); store as (N, T, V)
            refined.append(cfm(mu, n_timesteps=n_timesteps).transpose(1, 2))

        # Store predictions per episode and subject
        for ep_idx, episode in enumerate(episodes):
            for sub_name, sub_pred in zip(subject_keys, refined):
                pred = sub_pred[ep_idx].cpu().numpy()  # (T, V)
                # Truncate to exact fmri_num_samples for this subject; keep
                # the full length when the episode has no recorded count.
                limit = fmri_num_samples[sub_name].get(episode, len(pred))
                submission[sub_name][episode] = pred[:limit].astype(np.float32)

        if processed % 10 == 0:
            print(f" Processed {processed} episodes...")

    return submission
def validate_submission(
    predictions: dict[str, dict[str, np.ndarray]],
    test_set_name: str,
    fmri_num_samples: dict[str, dict[str, int]],
):
    """Validate submission keys, shapes, and dtypes match challenge requirements.

    Args:
        predictions: ``{subject: {episode: array}}`` as produced by inference.
        test_set_name: ``"ood"`` or ``"friends-s7"``. For any other name a
            warning is printed and episode/shape validation is skipped.
        fmri_num_samples: Per-subject, per-episode expected sample counts. A
            subject or episode without a recorded count skips only the
            sample-count check.

    Raises:
        ValueError: If the subject keys differ from ``EXPECTED_SUBJECTS``, if
            any subject has extra or missing episode keys, or if a prediction
            is not a 2-D float32 array with the expected number of rows and
            1000 columns.
    """
    n_parcels = 1000

    # Check subject keys
    subject_keys = set(predictions.keys())
    if subject_keys != EXPECTED_SUBJECTS:
        extra = subject_keys - EXPECTED_SUBJECTS
        missing = EXPECTED_SUBJECTS - subject_keys
        raise ValueError(
            f"Subject key mismatch. Extra: {extra}, Missing: {missing}. "
            f"Expected exactly: {EXPECTED_SUBJECTS}"
        )

    # Determine expected episode keys
    if test_set_name == "ood":
        expected_episodes = EXPECTED_OOD_KEYS
    elif test_set_name == "friends-s7":
        # For friends-s7, expected keys come from fmri_num_samples
        # (the ground-truth sample counts define the required episodes)
        expected_episodes = set()
        for sub_samples in fmri_num_samples.values():
            expected_episodes.update(sub_samples.keys())
    else:
        print(f" Warning: no key validation for test set '{test_set_name}'")
        return

    # Check episode keys per subject
    for sub, episodes_dict in predictions.items():
        episode_keys = set(episodes_dict.keys())
        extra = episode_keys - expected_episodes
        missing = expected_episodes - episode_keys
        if extra:
            raise ValueError(
                f"{sub}: extra episode keys {extra} — these will cause a formatting error"
            )
        if missing:
            raise ValueError(
                f"{sub}: missing episode keys {missing} — submission is incomplete"
            )

        # A subject without recorded counts must not raise a bare KeyError.
        sub_counts = fmri_num_samples.get(sub, {})

        # Validate shapes and dtype
        for ep, pred in episodes_dict.items():
            expected_n = sub_counts.get(ep)
            if expected_n is not None and pred.shape[0] != expected_n:
                raise ValueError(
                    f"{sub}/{ep}: shape {pred.shape} but expected N={expected_n}"
                )
            # Check the rank first so a non-2-D array is reported as a
            # validation error instead of an IndexError on shape[1].
            if pred.ndim != 2 or pred.shape[1] != n_parcels:
                raise ValueError(
                    f"{sub}/{ep}: shape {pred.shape} but expected 1000 parcels"
                )
            if pred.dtype != np.float32:
                raise ValueError(
                    f"{sub}/{ep}: dtype {pred.dtype} but expected float32"
                )

    print(f" Validation passed: {len(predictions)} subjects, "
          f"{len(expected_episodes)} episodes each")
def print_summary(predictions: dict[str, dict[str, np.ndarray]]):
    """Print one line per subject, followed by one line per episode.

    Args:
        predictions: ``{subject: {episode: array}}``; each episode line shows
            the array's shape and dtype.
    """
    for sub_name, per_episode in predictions.items():
        episode_count = len(per_episode)
        print(f" {sub_name}: {episode_count} episodes")
        for ep_name, array in per_episode.items():
            shape, dtype = array.shape, array.dtype
            print(f" {ep_name}: {shape} {dtype}")
def save_predictions(
    predictions: dict[str, dict[str, np.ndarray]],
    test_set_name: str,
    out_dir: Path,
):
    """Save predictions as .npy and .zip files.

    The base file name is ``fmri_predictions_<test_set_name>`` with hyphens
    replaced by underscores. The zip archive contains the .npy file alone,
    stored under its bare file name.

    Args:
        predictions: Nested prediction dict, saved as a single object.
        test_set_name: Test set name used to build the file name.
        out_dir: Existing directory that receives both files.
    """
    stem = "fmri_predictions_" + test_set_name.replace("-", "_")
    npy_path = out_dir / (stem + ".npy")
    zip_path = out_dir / (stem + ".zip")

    np.save(npy_path, predictions)
    print(f" Saved: {npy_path}")

    with zipfile.ZipFile(zip_path, mode="w") as archive:
        archive.write(npy_path, arcname=npy_path.name)
    print(f" Saved: {zip_path}")
# Run the CLI only when executed as a script or via `python -m src.submission`.
if __name__ == "__main__":
    main()