"""Build the bundled gallery for the EchoLVFM Space.

For each chosen patient, this script:

1. Copies the latent `.pt` from `sample_data/CAMUS_Latents_4f4/<pid>.pt` into
   `space/samples/<pid>.pt`.
2. Copies the matching real frames from `<frames_root>/<pid>/frame_*.png` into
   `space/samples/<pid>/`. If the patient has more frames than the model's
   `max_frames` (32), the frames are subsampled with the same `linspace`
   indices that `EchoDataset.resample_sequence` would use, so the real frames
   stay aligned 1:1 with the latent. Stale `frame_*.png` files left in
   `space/samples/<pid>/` by an earlier build are removed.
3. Writes `space/samples/manifest.json`, replacing any previous manifest, with
   one entry per patient containing: `id`, `latent_path`, `real_frames_dir`,
   `ef_true`, `t_real` (= min(NbFrame, 32)), `n_orig`, `fps_orig`, `view`.

Usage:
    python space/scripts/build_samples.py \\
        --frames-root C:\\path\\to\\flow_matching\\data\\CAMUS_Processed_Frames

Defaults pick the 5 long-video patients added to `sample_data/` (NbFrame 30–37).
"""
from __future__ import annotations

import argparse
import json
import shutil
from pathlib import Path

import pandas as pd
import torch
from PIL import Image

REPO_ROOT = Path(__file__).resolve().parents[2]
SAMPLE_DATA_DIR = REPO_ROOT / "sample_data" / "CAMUS_Latents_4f4"
SAMPLES_DIR = REPO_ROOT / "space" / "samples"

DEFAULT_PATIENTS = [
    "patient0082_4CH",
    "patient0106_4CH",
    "patient0310_4CH",
    "patient0326_4CH",
    "patient0422_4CH",
]

MAX_FRAMES = 32  # Matches `cfg.dataset.max_frames` for all three checkpoints.


def _resample_indices(t_orig: int, target: int) -> list[int]:
    """Mirror `EchoDataset.resample_sequence` index selection for T > target.

    Returns `target` indices evenly spaced over ``[0, t_orig - 1]`` and rounded
    to the nearest integer with `torch.Tensor.round` (ties follow torch's
    rounding rule).
    """
    return torch.linspace(0, t_orig - 1, target).round().long().tolist()


def _copy_real_frames(src_dir: Path, dst_dir: Path, t_real: int) -> None:
    """Copy `frame_*.png` files from `src_dir` into `dst_dir` as RGB PNGs.

    If the source has more than `MAX_FRAMES` frames, they are subsampled with
    `_resample_indices` so the frames align with the (resampled) latent. The
    output is renamed to a contiguous, 0-indexed `frame_<i>.png` sequence, and
    any other `frame_*.png` already present in `dst_dir` is deleted.

    Args:
        src_dir: Directory holding the patient's original `frame_<n>.png` files.
        dst_dir: Destination directory (created if missing).
        t_real: Number of frames expected after resampling.

    Raises:
        FileNotFoundError: `src_dir` contains no `frame_*.png` files.
        RuntimeError: The number of selected frames differs from `t_real`.
    """
    # Sort numerically on the trailing index (frame_2 before frame_10).
    src_frames = sorted(
        src_dir.glob("frame_*.png"),
        key=lambda p: int(p.stem.split("_")[-1]),
    )
    if not src_frames:
        raise FileNotFoundError(f"No frame_*.png in {src_dir}")

    n_src = len(src_frames)
    if n_src > MAX_FRAMES:
        chosen = [src_frames[i] for i in _resample_indices(n_src, MAX_FRAMES)]
    else:
        chosen = src_frames  # T_real == n_src

    if len(chosen) != t_real:
        raise RuntimeError(
            f"Expected {t_real} frames after resampling for {src_dir.name}, "
            f"got {len(chosen)} (n_src={n_src})"
        )

    dst_dir.mkdir(parents=True, exist_ok=True)
    written: set[str] = set()
    for new_idx, src in enumerate(chosen):
        # Write 0-indexed to match `frame_to_mp4` ordering.
        out_name = f"frame_{new_idx}.png"
        # The context manager closes the source file handle promptly.
        with Image.open(src) as im:
            im.convert("RGB").save(dst_dir / out_name)
        written.add(out_name)

    # An earlier build may have left extra frames here. Prune them so the
    # directory holds exactly the `t_real` frames written above.
    for stale in dst_dir.glob("frame_*.png"):
        if stale.name not in written:
            stale.unlink()


def build(patients: list[str], frames_root: Path) -> None:
    """Copy latents and real frames for `patients` and write the manifest.

    Args:
        patients: Patient ids; each must appear exactly once in the
            `video_name` column of `SAMPLE_DATA_DIR / "metadata.csv"`.
        frames_root: Directory containing one `<pid>/` frame folder per patient.

    Raises:
        FileNotFoundError: The sample data dir, a latent file, or a patient's
            frame dir is missing.
        KeyError: A patient id is not in the metadata CSV.
        ValueError: A patient id appears more than once in the metadata CSV.
    """
    if not SAMPLE_DATA_DIR.exists():
        raise FileNotFoundError(f"Sample data dir not found: {SAMPLE_DATA_DIR}")

    meta_csv = SAMPLE_DATA_DIR / "metadata.csv"
    df = pd.read_csv(meta_csv).set_index("video_name")

    SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
    manifest = {"samples": []}

    for pid in patients:
        if pid not in df.index:
            raise KeyError(f"{pid} not in {meta_csv}")
        row = df.loc[pid]
        # A duplicated index label makes `.loc` return a DataFrame. Its
        # columns would be Series, not scalars, and break int()/float() below.
        if isinstance(row, pd.DataFrame):
            raise ValueError(f"{pid} appears {len(row)} times in {meta_csv}")

        n_orig = int(row["NbFrame"])
        t_real = min(n_orig, MAX_FRAMES)

        latent_src = SAMPLE_DATA_DIR / f"{pid}.pt"
        if not latent_src.exists():
            raise FileNotFoundError(f"Latent not found: {latent_src}")
        shutil.copyfile(latent_src, SAMPLES_DIR / f"{pid}.pt")

        frames_src = frames_root / pid
        if not frames_src.is_dir():
            raise FileNotFoundError(f"Frames dir not found: {frames_src}")
        _copy_real_frames(frames_src, SAMPLES_DIR / pid, t_real)

        manifest["samples"].append({
            "id": pid,
            "latent_path": f"{pid}.pt",
            "real_frames_dir": pid,
            "ef_true": float(row["EF_AL"]),
            "t_real": t_real,
            "n_orig": n_orig,
            "fps_orig": float(row["FrameRate"]),
            "view": str(row["view"]),
        })
        print(f" + {pid}: t_real={t_real} (n_orig={n_orig}), "
              f"ef_true={row['EF_AL']}, fps={row['FrameRate']}")

    # The manifest is rebuilt from scratch: only `patients` are listed.
    manifest_path = SAMPLES_DIR / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2) + "\n", encoding="utf-8")
    print(f"\nWrote {manifest_path} with {len(manifest['samples'])} samples")


def _parse_args() -> argparse.Namespace:
    """Parse CLI arguments; the module docstring doubles as `--help` text."""
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--frames-root", required=True, type=Path,
                   help="Path to CAMUS_Processed_Frames root.")
    p.add_argument("--patients", nargs="+", default=DEFAULT_PATIENTS,
                   help="Patient ids to bundle.")
    return p.parse_args()


def main() -> None:
    """CLI entry point."""
    args = _parse_args()
    build(args.patients, args.frames_root)


if __name__ == "__main__":
    main()