""" Submission script for 2-stage fMRI encoding with Flow Matching. Generates predictions for friends-s7 (in-distribution) and ood (out-of-distribution) test sets. Outputs are saved as .npy and .zip files matching the Algonauts 2025 challenge format. Usage: python -m src.submission --checkpoint-dir output/two_stage_encoding python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set ood python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set all # Both test sets python -m src.submission --checkpoint-dir output/two_stage_encoding # Single test set python -m src.submission --checkpoint-dir output/two_stage_encoding --test-set ood # Custom output dir and device python -m src.submission --checkpoint-dir output/two_stage_encoding --output-dir output/submission --device cpu # Specific stage2 checkpoint python -m src.submission --checkpoint-dir output/two_stage_encoding --stage2-ckpt stage2_epoch_9.pt """ import argparse import warnings import zipfile from pathlib import Path import numpy as np import torch import torch.nn as nn from omegaconf import OmegaConf from .data import ( Algonauts2025Dataset, load_sharded_features, episode_filter, ) from .stage1.medarc_architecture import MultiSubjectConvLinearEncoder from .stage2.CFM import CFM DEFAULT_DATA_DIR = Path("/raid/lttung05/fmri_encoder/data") SUBJECTS = (1, 2, 3, 5) OOD_MOVIES = ["chaplin", "mononoke", "passepartout", "planetearth", "pulpfiction", "wot"] # Exact expected second-layer keys for OOD submission (no extras allowed) EXPECTED_OOD_KEYS = { "chaplin1", "chaplin2", "mononoke1", "mononoke2", "passepartout1", "passepartout2", "planetearth1", "planetearth2", "pulpfiction1", "pulpfiction2", "wot1", "wot2", } EXPECTED_SUBJECTS = {"sub-01", "sub-02", "sub-03", "sub-05"} def main(): parser = argparse.ArgumentParser(description="Generate submission predictions") parser.add_argument( "--checkpoint-dir", type=str, required=True, help="Path to trained model output directory (contains config.yaml, stage1_best.pt, stage2_epoch_*.pt)", ) parser.add_argument( "--test-set", type=str, default="all", choices=["friends-s7", "ood", "all"], help="Which test set(s) to generate predictions for", ) parser.add_argument("--stage2-ckpt", type=str, default=None, help="Stage 2 checkpoint filename (default: latest)") parser.add_argument("--n-timesteps", type=int, default=None, help="Override number of ODE steps for stage 2") parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--datasets-root", type=str, default=None) parser.add_argument("--output-dir", type=str, default=None, help="Output directory (default: /submission)") args = parser.parse_args() ckpt_dir = Path(args.checkpoint_dir) cfg = OmegaConf.load(ckpt_dir / "config.yaml") datasets_root = Path(args.datasets_root or cfg.get("datasets_root") or DEFAULT_DATA_DIR) device = torch.device(args.device) subjects = cfg.get("subjects", list(SUBJECTS)) n_timesteps = args.n_timesteps or cfg.stage2.get("n_timesteps", 25) out_dir = Path(args.output_dir) if args.output_dir else ckpt_dir / "submission" out_dir.mkdir(parents=True, exist_ok=True) # Warn if NumPy >= 2.0 (Codabench server uses NumPy 1.x) np_version = tuple(int(x) for x in np.__version__.split(".")[:2]) if np_version[0] >= 2: warnings.warn( f"NumPy {np.__version__} detected. Codabench requires NumPy < 2.0. " f"Submissions saved with NumPy 2.x will cause a formatting error. " f"Install numpy<2.0 and re-run.", stacklevel=1, ) print(f"Checkpoint dir: {ckpt_dir}") print(f"Output dir: {out_dir}") print(f"Device: {device}") print(f"Subjects: {subjects}") print(f"Stage 2 ODE timesteps: {n_timesteps}") # --- Load features --- print("Loading features...") all_features = load_all_features(cfg, datasets_root) # --- Build models --- print("Building models...") stage1_model, stage2_models = build_models(cfg, ckpt_dir, all_features, subjects, device, args.stage2_ckpt) # --- Generate predictions --- test_sets = ["friends-s7", "ood"] if args.test_set == "all" else [args.test_set] for test_set_name in test_sets: print(f"\n{'='*60}") print(f"Generating predictions for: {test_set_name}") print(f"{'='*60}") fmri_num_samples = load_fmri_num_samples(datasets_root, test_set_name) test_loader = make_test_loader(cfg, all_features, fmri_num_samples, test_set_name) predictions = run_inference( stage1_model=stage1_model, stage2_models=stage2_models, test_loader=test_loader, fmri_num_samples=fmri_num_samples, subjects=subjects, device=device, n_timesteps=n_timesteps, ) validate_submission(predictions, test_set_name, fmri_num_samples) print_summary(predictions) save_predictions(predictions, test_set_name, out_dir) print("\nDone!") def load_all_features(cfg, datasets_root: Path) -> list[dict[str, np.ndarray]]: """Load all feature sets specified in config, for both friends and ood series.""" all_features = [] for feat_name in cfg.include_features: model, layer = feat_name.split("/") feat_cfg = cfg.features[model] model_name = feat_cfg.model layer_name = feat_cfg.layers[layer] print(f" Loading {feat_name} ({model_name}/{layer_name})") features_dir = datasets_root / "features" # Load friends (all seasons including s7) friends_features = load_sharded_features( features_dir, model=model_name, layer=layer_name, series="friends" ) # Load ood movies ood_features = load_sharded_features( features_dir, model=model_name, layer=layer_name, series="ood" ) features = {**friends_features, **ood_features} if cfg.stage1.model.global_pool == "avg": features = pool_features(features) all_features.append(features) return all_features def pool_features(features: dict[str, np.ndarray]) -> dict[str, np.ndarray]: pooled = {} for key, feat in features.items(): assert feat.ndim in {2, 3} if feat.ndim == 3: feat = feat.mean(axis=1) pooled[key] = feat return pooled def build_models(cfg, ckpt_dir: Path, all_features, subjects, device, stage2_ckpt_name=None): """Build and load stage 1 and stage 2 models from checkpoints.""" # Infer feature dimensions from an arbitrary episode sample_episode = next(iter(all_features[0])) feat_dims = [feats[sample_episode].shape[-1] for feats in all_features] print(f" Feature dims: {feat_dims}") # --- Stage 1 --- stage1_model = MultiSubjectConvLinearEncoder( num_subjects=len(subjects), feat_dims=feat_dims, **cfg.stage1.model, ).to(device) stage1_path = ckpt_dir / "stage1_best.pt" print(f" Loading stage 1: {stage1_path}") stage1_model.load_state_dict(torch.load(stage1_path, map_location=device, weights_only=True)) stage1_model.eval() # --- Stage 2 --- target_dim = 1000 # Schaefer atlas parcels cfm_params = cfg.stage2.cfm velocity_net_params = cfg.stage2.velocity_net source_ve_params = cfg.stage2.source_ve transport_params = cfg.stage2.transport stage2_models = nn.ModuleDict() for sub in subjects: sub_key = str(sub) cfm_model = CFM( feat_dim=target_dim, cfm_params=cfm_params, velocity_net_params=velocity_net_params, source_ve_params=source_ve_params, transport_params=transport_params, ) stage2_models[sub_key] = cfm_model # Find stage 2 checkpoint if stage2_ckpt_name: stage2_path = ckpt_dir / stage2_ckpt_name else: # Pick the latest stage2 checkpoint stage2_paths = sorted(ckpt_dir.glob("stage2_epoch_*.pt")) if not stage2_paths: raise FileNotFoundError(f"No stage2 checkpoints found in {ckpt_dir}") stage2_path = stage2_paths[-1] print(f" Loading stage 2: {stage2_path}") stage2_models.load_state_dict(torch.load(stage2_path, map_location=device, weights_only=True)) stage2_models = stage2_models.to(device) stage2_models.eval() return stage1_model, stage2_models def load_fmri_num_samples(datasets_root: Path, test_set_name: str) -> dict[str, dict[str, int]]: """Load per-subject, per-episode fMRI sample counts for the test set.""" if test_set_name == "friends-s7": file_pattern = "friends-s7" elif test_set_name == "ood": file_pattern = "ood" else: raise ValueError(f"Unknown test set: {test_set_name}") fmri_dir = datasets_root / "algonauts_2025.competitors" / "fmri" sample_paths = sorted(fmri_dir.rglob(f"*_{file_pattern}_fmri_samples.npy")) if not sample_paths: raise FileNotFoundError( f"No fmri_samples files found for {test_set_name} in {fmri_dir}. " f"Expected pattern: *_{file_pattern}_fmri_samples.npy" ) fmri_num_samples = {} for path in sample_paths: # path like: .../sub-01/target_sample_number/sub-01_friends-s7_fmri_samples.npy sub = path.parents[1].name # e.g. "sub-01" fmri_num_samples[sub] = np.load(path, allow_pickle=True).item() print(f" Loaded fmri_num_samples for {list(fmri_num_samples.keys())}") return fmri_num_samples def make_test_loader(cfg, all_features, fmri_num_samples, test_set_name): """Create a DataLoader for the specified test set.""" from torch.utils.data import DataLoader all_episodes = list(all_features[0]) if test_set_name == "friends-s7": filter_fn = episode_filter(seasons=[7], movies=[]) elif test_set_name == "ood": filter_fn = episode_filter(seasons=[], movies=OOD_MOVIES) else: raise ValueError(f"Unknown test set: {test_set_name}") ds_episodes = sorted(filter(filter_fn, all_episodes)) print(f" Episodes ({len(ds_episodes)}): {ds_episodes}") # Use fmri_num_samples as the authoritative episode list to ensure # we produce exactly the required keys (no extras, no missing) expected_episodes = set() for sub_samples in fmri_num_samples.values(): expected_episodes.update(sub_samples.keys()) # Only keep episodes that are both in features and in expected set ds_episodes = [ep for ep in ds_episodes if ep in expected_episodes] # Also add any expected episodes present in features but missed by filter feature_episodes = set(all_episodes) for ep in expected_episodes: if ep not in ds_episodes and ep in feature_episodes: ds_episodes.append(ep) ds_episodes = sorted(ds_episodes) # Build per-episode max sample count across subjects episode_num_samples = {} for ep in ds_episodes: max_samples = max( fmri_num_samples[sub].get(ep, 0) for sub in fmri_num_samples ) if max_samples == 0: print(f" Warning: no fmri_num_samples for episode {ep}, skipping") continue episode_num_samples[ep] = max_samples ds_episodes = [ep for ep in ds_episodes if ep in episode_num_samples] missing = expected_episodes - set(ds_episodes) if missing: raise ValueError( f"Missing episodes in features: {missing}. " f"Cannot produce complete submission." ) dataset = Algonauts2025Dataset( episode_list=ds_episodes, feat_data=all_features, fmri_num_samples=episode_num_samples, sample_length=None, shuffle=False, ) loader = DataLoader(dataset, batch_size=1) return loader @torch.no_grad() def run_inference( *, stage1_model: nn.Module, stage2_models: nn.ModuleDict, test_loader, fmri_num_samples: dict[str, dict[str, int]], subjects: list[int], device: torch.device, n_timesteps: int = 25, ) -> dict[str, dict[str, np.ndarray]]: """Run two-stage inference and collect per-subject, per-episode predictions.""" stage1_model.eval() stage2_models.eval() submission = {f"sub-{sub:02d}": {} for sub in subjects} for batch_idx, batch in enumerate(test_loader): feats = [f.to(device) for f in batch["features"]] episodes = batch["episode"] # Stage 1: mean anchor prediction (N, S, T, V) mu_anchor = stage1_model(feats) N, S, T, V = mu_anchor.shape assert N == 1, "Batch size must be 1 for submission" # Stage 2: per-subject flow matching refinement batch_preds = [] for i, sub in enumerate(subjects): sub_key = str(sub) cfm = stage2_models[sub_key] # mu for this subject: (N, T, V) -> (N, V, T) for CFM mu = mu_anchor[:, i].transpose(1, 2) # Flow matching inference pred = cfm(mu, n_timesteps=n_timesteps) # (N, V, T) pred = pred.transpose(1, 2) # (N, T, V) batch_preds.append(pred) # Store predictions per episode and subject for ii, episode in enumerate(episodes): for jj, sub_id in enumerate(subjects): sub = f"sub-{sub_id:02d}" pred = batch_preds[jj][ii].cpu().numpy() # (T, V) # Truncate to exact fmri_num_samples for this subject num_samples = fmri_num_samples[sub].get(episode, len(pred)) pred = pred[:num_samples].astype(np.float32) submission[sub][episode] = pred if (batch_idx + 1) % 10 == 0: print(f" Processed {batch_idx + 1} episodes...") return submission def validate_submission( predictions: dict[str, dict[str, np.ndarray]], test_set_name: str, fmri_num_samples: dict[str, dict[str, int]], ): """Validate submission keys, shapes, and dtypes match challenge requirements.""" # Check subject keys subject_keys = set(predictions.keys()) if subject_keys != EXPECTED_SUBJECTS: extra = subject_keys - EXPECTED_SUBJECTS missing = EXPECTED_SUBJECTS - subject_keys raise ValueError( f"Subject key mismatch. Extra: {extra}, Missing: {missing}. " f"Expected exactly: {EXPECTED_SUBJECTS}" ) # Determine expected episode keys from fmri_num_samples if test_set_name == "ood": expected_episodes = EXPECTED_OOD_KEYS elif test_set_name == "friends-s7": # For friends-s7, expected keys come from fmri_num_samples # (the ground-truth sample counts define the required episodes) all_episodes = set() for sub_samples in fmri_num_samples.values(): all_episodes.update(sub_samples.keys()) expected_episodes = all_episodes else: print(f" Warning: no key validation for test set '{test_set_name}'") return # Check episode keys per subject for sub, episodes_dict in predictions.items(): episode_keys = set(episodes_dict.keys()) extra = episode_keys - expected_episodes missing = expected_episodes - episode_keys if extra: raise ValueError( f"{sub}: extra episode keys {extra} — these will cause a formatting error" ) if missing: raise ValueError( f"{sub}: missing episode keys {missing} — submission is incomplete" ) # Validate shapes and dtype for ep, pred in episodes_dict.items(): expected_n = fmri_num_samples[sub].get(ep) if expected_n is not None and pred.shape[0] != expected_n: raise ValueError( f"{sub}/{ep}: shape {pred.shape} but expected N={expected_n}" ) if pred.shape[1] != 1000: raise ValueError( f"{sub}/{ep}: shape {pred.shape} but expected 1000 parcels" ) if pred.dtype != np.float32: raise ValueError( f"{sub}/{ep}: dtype {pred.dtype} but expected float32" ) print(f" Validation passed: {len(predictions)} subjects, " f"{len(expected_episodes)} episodes each") def print_summary(predictions: dict[str, dict[str, np.ndarray]]): """Print a summary of the generated predictions.""" for subject, episodes_dict in predictions.items(): print(f" {subject}: {len(episodes_dict)} episodes") for episode, pred in episodes_dict.items(): print(f" {episode}: {pred.shape} {pred.dtype}") def save_predictions( predictions: dict[str, dict[str, np.ndarray]], test_set_name: str, out_dir: Path, ): """Save predictions as .npy and .zip files.""" file_name = f"fmri_predictions_{test_set_name.replace('-', '_')}" npy_path = out_dir / f"{file_name}.npy" np.save(npy_path, predictions) print(f" Saved: {npy_path}") zip_path = out_dir / f"{file_name}.zip" with zipfile.ZipFile(zip_path, "w") as zipf: zipf.write(npy_path, npy_path.name) print(f" Saved: {zip_path}") if __name__ == "__main__": main()