"""
Submission script for 2-stage fMRI encoding with Flow Matching.

Generates predictions for the friends-s7 (in-distribution) and ood (out-of-distribution)
test sets. Outputs are saved as .npy and .zip files matching the Algonauts 2025
challenge format.

Usage:
    # Both test sets (the default; equivalent to --test-set all)
    python -m src.submission --checkpoint-dir output/two_stage_encoding
    python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set all

    # Single test set
    python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set ood

    # Custom output dir and device
    python -m src.submission --checkpoint-dir output/two_stage_encoding --output-dir output/submission --device cpu

    # Specific stage2 checkpoint
    python -m src.submission --checkpoint-dir output/two_stage_encoding --stage2-ckpt stage2_epoch_9.pt
"""
|
|
| import argparse |
| import warnings |
| import zipfile |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from omegaconf import OmegaConf |
|
|
| from .data import ( |
| Algonauts2025Dataset, |
| load_sharded_features, |
| episode_filter, |
| ) |
| from .stage1.medarc_architecture import MultiSubjectConvLinearEncoder |
| from .stage2.CFM import CFM |
|
|
# Fallback dataset location, used when neither --datasets-root nor the config provides one.
DEFAULT_DATA_DIR = Path("/raid/lttung05/fmri_encoder/data")

# Numeric subject ids; formatted as "sub-XX" wherever submission keys are built.
SUBJECTS = (1, 2, 3, 5)

# Movies that make up the out-of-distribution test set.
OOD_MOVIES = ["chaplin", "mononoke", "passepartout", "planetearth", "pulpfiction", "wot"]

# Episode keys required in an OOD submission: every OOD movie in two parts ("<movie>1", "<movie>2").
EXPECTED_OOD_KEYS = {f"{movie}{part}" for movie in OOD_MOVIES for part in (1, 2)}

# Subject keys required in every submission.
EXPECTED_SUBJECTS = {f"sub-{subject:02d}" for subject in SUBJECTS}
|
|
|
|
def _parse_cli_args() -> argparse.Namespace:
    """Declare the command-line options and parse them from ``sys.argv``."""
    cli = argparse.ArgumentParser(description="Generate submission predictions")
    cli.add_argument(
        "--checkpoint-dir",
        type=str,
        required=True,
        help="Path to trained model output directory (contains config.yaml, stage1_best.pt, stage2_epoch_*.pt)",
    )
    cli.add_argument(
        "--test-set",
        type=str,
        default="all",
        choices=["friends-s7", "ood", "all"],
        help="Which test set(s) to generate predictions for",
    )
    cli.add_argument("--stage2-ckpt", type=str, default=None, help="Stage 2 checkpoint filename (default: latest)")
    cli.add_argument("--n-timesteps", type=int, default=None, help="Override number of ODE steps for stage 2")
    cli.add_argument("--device", type=str, default="cuda")
    cli.add_argument("--datasets-root", type=str, default=None)
    cli.add_argument("--output-dir", type=str, default=None, help="Output directory (default: <checkpoint-dir>/submission)")
    return cli.parse_args()


def main():
    """Command-line entry point: load models and write predictions for the requested test sets.

    Reads ``config.yaml`` from the checkpoint directory, loads the features and
    both model stages once, then for each selected test set runs inference,
    validates the result and saves it to the output directory.
    """
    args = _parse_cli_args()

    ckpt_dir = Path(args.checkpoint_dir)
    cfg = OmegaConf.load(ckpt_dir / "config.yaml")

    # Precedence for every setting: CLI flag, then config value, then built-in default.
    datasets_root = Path(args.datasets_root or cfg.get("datasets_root") or DEFAULT_DATA_DIR)
    device = torch.device(args.device)
    subjects = cfg.get("subjects", list(SUBJECTS))
    n_timesteps = args.n_timesteps or cfg.stage2.get("n_timesteps", 25)

    if args.output_dir:
        out_dir = Path(args.output_dir)
    else:
        out_dir = ckpt_dir / "submission"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Warn (do not abort) when the installed NumPy major version is 2 or newer.
    numpy_major, *_ = tuple(int(part) for part in np.__version__.split(".")[:2])
    if numpy_major >= 2:
        warnings.warn(
            f"NumPy {np.__version__} detected. Codabench requires NumPy < 2.0. "
            f"Submissions saved with NumPy 2.x will cause a formatting error. "
            f"Install numpy<2.0 and re-run.",
            stacklevel=1,
        )

    print(f"Checkpoint dir: {ckpt_dir}")
    print(f"Output dir: {out_dir}")
    print(f"Device: {device}")
    print(f"Subjects: {subjects}")
    print(f"Stage 2 ODE timesteps: {n_timesteps}")

    print("Loading features...")
    all_features = load_all_features(cfg, datasets_root)

    print("Building models...")
    stage1_model, stage2_models = build_models(
        cfg, ckpt_dir, all_features, subjects, device, args.stage2_ckpt
    )

    if args.test_set == "all":
        test_sets = ["friends-s7", "ood"]
    else:
        test_sets = [args.test_set]

    banner = "=" * 60
    for test_set_name in test_sets:
        print(f"\n{banner}")
        print(f"Generating predictions for: {test_set_name}")
        print(f"{banner}")

        fmri_num_samples = load_fmri_num_samples(datasets_root, test_set_name)
        test_loader = make_test_loader(cfg, all_features, fmri_num_samples, test_set_name)

        predictions = run_inference(
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            test_loader=test_loader,
            fmri_num_samples=fmri_num_samples,
            subjects=subjects,
            device=device,
            n_timesteps=n_timesteps,
        )

        # Validate before saving so a malformed submission never reaches disk.
        validate_submission(predictions, test_set_name, fmri_num_samples)
        print_summary(predictions)
        save_predictions(predictions, test_set_name, out_dir)

    print("\nDone!")
|
|
|
|
def load_all_features(cfg, datasets_root: Path) -> list[dict[str, np.ndarray]]:
    """Load every feature set named in ``cfg.include_features``.

    Each entry is a ``"<model>/<layer>"`` string that is resolved through
    ``cfg.features``. For every entry the "friends" and "ood" series are loaded
    from ``<datasets_root>/features`` and merged into one dict (an "ood" key
    would override a "friends" key of the same name).

    Returns:
        One merged features dict per configured feature, in config order.
        When ``cfg.stage1.model.global_pool == "avg"`` each dict is passed
        through ``pool_features`` first.
    """
    features_dir = datasets_root / "features"
    loaded_sets = []

    for feat_name in cfg.include_features:
        model, layer = feat_name.split("/")
        feat_cfg = cfg.features[model]
        model_name = feat_cfg.model
        layer_name = feat_cfg.layers[layer]
        print(f" Loading {feat_name} ({model_name}/{layer_name})")

        # Load order matters: "ood" is merged last, so it wins on key clashes.
        merged = {}
        for series in ("friends", "ood"):
            merged.update(
                load_sharded_features(
                    features_dir, model=model_name, layer=layer_name, series=series
                )
            )

        if cfg.stage1.model.global_pool == "avg":
            merged = pool_features(merged)

        loaded_sets.append(merged)

    return loaded_sets
|
|
|
|
def pool_features(features: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Average 3-D feature arrays over axis 1; pass 2-D arrays through untouched.

    Every array must be 2-D or 3-D (checked with ``assert``). The returned dict
    has the same keys in the same order as ``features``.
    """

    def _pool_one(array: np.ndarray) -> np.ndarray:
        assert array.ndim in {2, 3}
        return array.mean(axis=1) if array.ndim == 3 else array

    return {name: _pool_one(array) for name, array in features.items()}
|
|
|
|
def _stage2_epoch_key(path: Path) -> tuple[int, str]:
    """Sort key ordering ``stage2_epoch_<N>.pt`` files by numeric epoch, then by name."""
    suffix = path.stem.rsplit("_", 1)[-1]
    # Non-numeric suffixes (e.g. "stage2_epoch_best") rank below every numbered epoch.
    epoch = int(suffix) if suffix.isdecimal() else -1
    return epoch, path.name


def _resolve_stage2_checkpoint(ckpt_dir: Path, stage2_ckpt_name=None) -> Path:
    """Return the requested stage-2 checkpoint path, or the highest-epoch one in ``ckpt_dir``."""
    if stage2_ckpt_name:
        return ckpt_dir / stage2_ckpt_name

    candidates = list(ckpt_dir.glob("stage2_epoch_*.pt"))
    if not candidates:
        raise FileNotFoundError(f"No stage2 checkpoints found in {ckpt_dir}")
    # Compare epochs numerically: a plain string sort would rank
    # "stage2_epoch_9.pt" after "stage2_epoch_10.pt" and pick a stale checkpoint.
    return max(candidates, key=_stage2_epoch_key)


def build_models(cfg, ckpt_dir: Path, all_features, subjects, device, stage2_ckpt_name=None):
    """Build the stage-1 encoder and per-subject stage-2 CFM models and load their weights.

    Args:
        cfg: Config providing ``stage1.model`` and ``stage2.{cfm, velocity_net,
            source_ve, transport}``.
        ckpt_dir: Directory containing ``stage1_best.pt`` and the stage-2 checkpoints.
        all_features: List of per-feature dicts; the last dimension of one
            episode's array in each dict gives that feature's input width.
        subjects: Subject ids; one CFM model is created per subject, keyed by ``str(id)``.
        device: Device the models and checkpoint tensors are placed on.
        stage2_ckpt_name: Stage-2 checkpoint filename inside ``ckpt_dir``. When
            omitted, the ``stage2_epoch_*.pt`` file with the highest epoch number is used.

    Returns:
        ``(stage1_model, stage2_models)``, both in eval mode; ``stage2_models``
        is an ``nn.ModuleDict``.

    Raises:
        FileNotFoundError: No ``stage2_ckpt_name`` was given and ``ckpt_dir``
            holds no ``stage2_epoch_*.pt`` file.
    """
    # Resolve the stage-2 checkpoint first so a missing file fails before any model is built.
    stage2_path = _resolve_stage2_checkpoint(ckpt_dir, stage2_ckpt_name)

    # Any episode works for reading the feature widths; take the first key.
    sample_episode = next(iter(all_features[0]))
    feat_dims = [feats[sample_episode].shape[-1] for feats in all_features]
    print(f" Feature dims: {feat_dims}")

    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)

    stage1_path = ckpt_dir / "stage1_best.pt"
    print(f" Loading stage 1: {stage1_path}")
    stage1_model.load_state_dict(torch.load(stage1_path, map_location=device, weights_only=True))
    stage1_model.eval()

    # Width handed to each CFM; the same 1000 is enforced as the parcel count
    # when submissions are validated.
    target_dim = 1000
    stage2_models = nn.ModuleDict()
    for sub in subjects:
        stage2_models[str(sub)] = CFM(
            feat_dim=target_dim,
            cfm_params=cfg.stage2.cfm,
            velocity_net_params=cfg.stage2.velocity_net,
            source_ve_params=cfg.stage2.source_ve,
            transport_params=cfg.stage2.transport,
        )

    print(f" Loading stage 2: {stage2_path}")
    stage2_models.load_state_dict(torch.load(stage2_path, map_location=device, weights_only=True))
    stage2_models = stage2_models.to(device)
    stage2_models.eval()

    return stage1_model, stage2_models
|
|
|
|
def load_fmri_num_samples(datasets_root: Path, test_set_name: str) -> dict[str, dict[str, int]]:
    """Read the per-subject, per-episode fMRI sample counts for one test set.

    Searches ``<datasets_root>/algonauts_2025.competitors/fmri`` recursively for
    files named ``*_<test_set_name>_fmri_samples.npy``. Each file holds a pickled
    dict; the subject key is the name of the file's grandparent directory
    (presumably ``sub-XX`` — verify against the dataset layout).

    Raises:
        ValueError: ``test_set_name`` is neither "friends-s7" nor "ood".
        FileNotFoundError: No matching sample-count file exists.
    """
    if test_set_name not in ("friends-s7", "ood"):
        raise ValueError(f"Unknown test set: {test_set_name}")
    file_pattern = test_set_name

    fmri_dir = datasets_root / "algonauts_2025.competitors" / "fmri"
    sample_paths = sorted(fmri_dir.rglob(f"*_{file_pattern}_fmri_samples.npy"))
    if not sample_paths:
        raise FileNotFoundError(
            f"No fmri_samples files found for {test_set_name} in {fmri_dir}. "
            f"Expected pattern: *_{file_pattern}_fmri_samples.npy"
        )

    # NOTE(review): allow_pickle=True executes pickled content; only point this
    # at trusted dataset files.
    fmri_num_samples = {
        sample_path.parents[1].name: np.load(sample_path, allow_pickle=True).item()
        for sample_path in sample_paths
    }

    print(f" Loaded fmri_num_samples for {list(fmri_num_samples.keys())}")
    return fmri_num_samples
|
|
|
|
def make_test_loader(cfg, all_features, fmri_num_samples, test_set_name):
    """Create a batch-size-1 ``DataLoader`` over the episodes of one test set.

    The episodes served are exactly those that appear in ``fmri_num_samples``
    (for any subject) and also have features. The name-based filter for the
    test set is only used for the informational "Episodes" print.

    Args:
        cfg: Unused here; kept so the call signature stays stable.
        all_features: List of per-feature dicts; the keys of the first one
            define which episodes have features.
        fmri_num_samples: ``{subject: {episode: num_samples}}``.
        test_set_name: "friends-s7" or "ood".

    Raises:
        ValueError: Unknown ``test_set_name``, or an expected episode cannot be
            served (it has no features, or its sample count is 0 for all subjects).
    """
    from torch.utils.data import DataLoader

    feature_episodes = set(all_features[0])

    if test_set_name == "friends-s7":
        filter_fn = episode_filter(seasons=[7], movies=[])
    elif test_set_name == "ood":
        filter_fn = episode_filter(seasons=[], movies=OOD_MOVIES)
    else:
        raise ValueError(f"Unknown test set: {test_set_name}")

    filtered_episodes = sorted(filter(filter_fn, all_features[0]))
    print(f" Episodes ({len(filtered_episodes)}): {filtered_episodes}")

    # Every episode any subject needs a prediction for.
    expected_episodes = set()
    for sub_samples in fmri_num_samples.values():
        expected_episodes.update(sub_samples)

    # One set intersection gives the same list the earlier filter-then-re-add
    # passes built, without the quadratic list-membership checks.
    ds_episodes = sorted(expected_episodes & feature_episodes)

    # Use the largest sample count across subjects per episode; run_inference
    # trims each subject's prediction back to its own count.
    episode_num_samples = {}
    for ep in ds_episodes:
        max_samples = max(
            sub_samples.get(ep, 0) for sub_samples in fmri_num_samples.values()
        )
        if max_samples == 0:
            print(f" Warning: no fmri_num_samples for episode {ep}, skipping")
            continue
        episode_num_samples[ep] = max_samples

    ds_episodes = [ep for ep in ds_episodes if ep in episode_num_samples]

    missing = expected_episodes - set(ds_episodes)
    if missing:
        raise ValueError(
            f"Missing episodes in features: {missing}. "
            f"Cannot produce complete submission."
        )

    dataset = Algonauts2025Dataset(
        episode_list=ds_episodes,
        feat_data=all_features,
        fmri_num_samples=episode_num_samples,
        sample_length=None,
        shuffle=False,
    )

    return DataLoader(dataset, batch_size=1)
|
|
|
|
@torch.no_grad()
def run_inference(
    *,
    stage1_model: nn.Module,
    stage2_models: nn.ModuleDict,
    test_loader,
    fmri_num_samples: dict[str, dict[str, int]],
    subjects: list[int],
    device: torch.device,
    n_timesteps: int = 25,
) -> dict[str, dict[str, np.ndarray]]:
    """Run stage 1 then stage 2 on every batch and gather float32 predictions.

    For each batch the stage-1 model is applied to the feature list; its output
    must be 4-D with a leading batch size of 1. Index ``i`` along axis 1 is
    passed to the stage-2 model of the ``i``-th subject, with the last two axes
    swapped before and after the call (presumably time/parcel layout — confirm
    against the CFM implementation).

    Returns:
        ``{"sub-XX": {episode: array}}``. Each array is trimmed to the count in
        ``fmri_num_samples["sub-XX"][episode]`` when present, otherwise kept at
        full length.
    """
    stage1_model.eval()
    stage2_models.eval()

    submission = {}
    for subject_id in subjects:
        submission[f"sub-{subject_id:02d}"] = {}

    for done, batch in enumerate(test_loader, start=1):
        feats = [feature.to(device) for feature in batch["features"]]
        episodes = batch["episode"]

        mu_anchor = stage1_model(feats)

        # Unpacking also enforces that the stage-1 output is exactly 4-D.
        batch_size, _, _, _ = mu_anchor.shape
        assert batch_size == 1, "Batch size must be 1 for submission"

        # One stage-2 prediction per subject, in the same order as `subjects`.
        per_subject_preds = [
            stage2_models[str(subject_id)](
                mu_anchor[:, idx].transpose(1, 2), n_timesteps=n_timesteps
            ).transpose(1, 2)
            for idx, subject_id in enumerate(subjects)
        ]

        for ep_idx, episode in enumerate(episodes):
            for subject_id, subject_pred in zip(subjects, per_subject_preds):
                sub = f"sub-{subject_id:02d}"
                pred = subject_pred[ep_idx].cpu().numpy()

                # Trim to this subject's own sample count when one is recorded.
                num_samples = fmri_num_samples[sub].get(episode, len(pred))
                submission[sub][episode] = pred[:num_samples].astype(np.float32)

        if done % 10 == 0:
            print(f" Processed {done} episodes...")

    return submission
|
|
|
|
def validate_submission(
    predictions: dict[str, dict[str, np.ndarray]],
    test_set_name: str,
    fmri_num_samples: dict[str, dict[str, int]],
):
    """Check that a submission has the required keys, shapes and dtype.

    Checks, in order: subject keys equal ``EXPECTED_SUBJECTS``; each subject's
    episode keys equal the expected set (``EXPECTED_OOD_KEYS`` for "ood", the
    union of all episodes in ``fmri_num_samples`` for "friends-s7"); each array
    has the recorded number of rows (when recorded), 1000 columns and dtype
    float32. For any other test set name only the subject keys are checked.

    Raises:
        ValueError: On the first violated check.
    """
    found_subjects = set(predictions)
    if found_subjects != EXPECTED_SUBJECTS:
        unexpected = found_subjects - EXPECTED_SUBJECTS
        absent = EXPECTED_SUBJECTS - found_subjects
        raise ValueError(
            f"Subject key mismatch. Extra: {unexpected}, Missing: {absent}. "
            f"Expected exactly: {EXPECTED_SUBJECTS}"
        )

    if test_set_name == "ood":
        expected_episodes = EXPECTED_OOD_KEYS
    elif test_set_name == "friends-s7":
        # No fixed key list for this set: expect every episode any subject has counts for.
        expected_episodes = set().union(
            *(sub_samples.keys() for sub_samples in fmri_num_samples.values())
        )
    else:
        print(f" Warning: no key validation for test set '{test_set_name}'")
        return

    for sub, episodes_dict in predictions.items():
        present = set(episodes_dict)

        extra = present - expected_episodes
        if extra:
            raise ValueError(
                f"{sub}: extra episode keys {extra} — these will cause a formatting error"
            )

        missing = expected_episodes - present
        if missing:
            raise ValueError(
                f"{sub}: missing episode keys {missing} — submission is incomplete"
            )

        for ep, pred in episodes_dict.items():
            expected_n = fmri_num_samples[sub].get(ep)
            # The row count is only checked when a count is recorded for this subject/episode.
            if expected_n is not None and pred.shape[0] != expected_n:
                raise ValueError(
                    f"{sub}/{ep}: shape {pred.shape} but expected N={expected_n}"
                )
            if pred.shape[1] != 1000:
                raise ValueError(
                    f"{sub}/{ep}: shape {pred.shape} but expected 1000 parcels"
                )
            if pred.dtype != np.float32:
                raise ValueError(
                    f"{sub}/{ep}: dtype {pred.dtype} but expected float32"
                )

    print(
        f" Validation passed: {len(predictions)} subjects, "
        f"{len(expected_episodes)} episodes each"
    )
|
|
|
|
def print_summary(predictions: dict[str, dict[str, np.ndarray]]):
    """Print one line per subject (episode count) and one per episode (shape and dtype)."""
    for subject_id, per_episode in predictions.items():
        print(f" {subject_id}: {len(per_episode)} episodes")
        for name in per_episode:
            array = per_episode[name]
            print(f" {name}: {array.shape} {array.dtype}")
|
|
|
|
def save_predictions(
    predictions: dict[str, dict[str, np.ndarray]],
    test_set_name: str,
    out_dir: Path,
):
    """Write the predictions dict to ``<out_dir>/fmri_predictions_<set>.npy`` and zip it.

    ``<set>`` is ``test_set_name`` with dashes replaced by underscores. The zip
    archive sits next to the ``.npy`` file and contains only that file, stored
    under its bare filename.
    """
    stem = "fmri_predictions_" + test_set_name.replace("-", "_")
    npy_path = out_dir / (stem + ".npy")
    zip_path = out_dir / (stem + ".zip")

    # The nested dict is stored as a pickled object array inside the .npy file.
    np.save(npy_path, predictions)
    print(f" Saved: {npy_path}")

    with zipfile.ZipFile(zip_path, "w") as archive:
        archive.write(npy_path, arcname=npy_path.name)
    print(f" Saved: {zip_path}")
|
|
|
|
# Run the CLI only when this module is executed as a script
# (e.g. `python -m src.submission`), not when it is imported.
if __name__ == "__main__":
    main()
|
|