EchoLVFM / space /scripts /build_samples.py
EngEmmanuel's picture
Initial commit
0f5513d
"""Build the bundled gallery for the EchoLVFM Space.
For each chosen patient, this script:
1. Copies the latent `.pt` from `sample_data/CAMUS_Latents_4f4/<patient>.pt`
into `space/samples/<patient>.pt`.
2. Copies the matching real frames from
`<source_frames_root>/<patient>/frame_*.png` into
`space/samples/<patient>/`. If the patient has more frames than the
model's `max_frames` (32), the frames are subsampled with the same
`linspace` indices that `EchoDataset.resample_sequence` would use, so
the real frames stay aligned 1:1 with the latent.
3. Writes `space/samples/manifest.json` (replacing any existing file) with
one entry per patient containing: `id`, `latent_path`, `real_frames_dir`,
`ef_true`, `t_real` (= min(NbFrame, 32)), `n_orig`, `fps_orig`, `view`.
Usage:
python space/scripts/build_samples.py \
--frames-root C:\\path\\to\\flow_matching\\data\\CAMUS_Processed_Frames
Defaults pick the 5 long-video patients added to `sample_data/` (NbFrame
30–37).
"""
from __future__ import annotations
import argparse
import json
import shutil
from pathlib import Path
import pandas as pd
import torch
from PIL import Image
# Third ancestor of this file; with the script at
# `space/scripts/build_samples.py` this resolves to the repository root.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Input directory: per-patient latent `.pt` files; `build` also reads
# `metadata.csv` from here.
SAMPLE_DATA_DIR = REPO_ROOT / "sample_data" / "CAMUS_Latents_4f4"
# Output directory: receives the copied latents, one frame folder per
# patient, and `manifest.json`.
SAMPLES_DIR = REPO_ROOT / "space" / "samples"
# Patients bundled when `--patients` is not given on the command line.
DEFAULT_PATIENTS = [
"patient0082_4CH",
"patient0106_4CH",
"patient0310_4CH",
"patient0326_4CH",
"patient0422_4CH",
]
# Upper bound on frames kept per patient; longer sequences are subsampled.
MAX_FRAMES = 32 # Matches `cfg.dataset.max_frames` for all three checkpoints.
def _resample_indices(t_orig: int, target: int) -> list[int]:
"""Mirror `EchoDataset.resample_sequence` index selection for T > target."""
return torch.linspace(0, t_orig - 1, target).round().long().tolist()
def _copy_real_frames(src_dir: Path, dst_dir: Path, t_real: int) -> None:
    """Copy a patient's `frame_*.png` files into `dst_dir` as RGB PNGs.

    Source frames are ordered by the integer after the last underscore in
    the file stem. If there are more than `MAX_FRAMES` of them, they are
    subsampled with `_resample_indices` so the kept frames line up with the
    (resampled) latent; otherwise all frames are kept. Output files are
    renumbered `frame_0.png`, `frame_1.png`, ... in that order.

    Args:
        src_dir: Directory containing the source `frame_*.png` files.
        dst_dir: Destination directory; created if missing.
        t_real: Number of frames the caller expects to end up with.

    Raises:
        FileNotFoundError: If `src_dir` contains no `frame_*.png` file.
        RuntimeError: If the number of selected frames differs from `t_real`.
    """
    # Sort numerically, not lexically, so frame_10 comes after frame_9.
    src_frames = sorted(
        src_dir.glob("frame_*.png"),
        key=lambda p: int(p.stem.split("_")[-1]),
    )
    if not src_frames:
        raise FileNotFoundError(f"No frame_*.png in {src_dir}")

    n_src = len(src_frames)
    if n_src > MAX_FRAMES:
        chosen = [src_frames[i] for i in _resample_indices(n_src, MAX_FRAMES)]
    else:
        chosen = src_frames  # T_real == n_src

    # Guard against metadata (NbFrame) disagreeing with the files on disk.
    if len(chosen) != t_real:
        raise RuntimeError(
            f"Expected {t_real} frames after resampling for {src_dir.name}, "
            f"got {len(chosen)} (n_src={n_src})"
        )

    dst_dir.mkdir(parents=True, exist_ok=True)
    for new_idx, p in enumerate(chosen):
        # Write 0-indexed to match `frame_to_mp4` ordering.
        # The context manager releases the source file handle right away
        # instead of leaving it to garbage collection.
        with Image.open(p) as img:
            img.convert("RGB").save(dst_dir / f"frame_{new_idx}.png")
def build(patients: list[str], frames_root: Path) -> None:
    """Assemble the bundled gallery under `SAMPLES_DIR`.

    For every patient id, in order: look up its row in `metadata.csv`, copy
    its latent `.pt` file, copy (and possibly subsample) its real frames,
    and record a manifest entry. After all patients are processed the
    manifest is written to `SAMPLES_DIR / "manifest.json"`.

    Args:
        patients: Patient ids; each must appear in the `video_name` column
            of `metadata.csv`.
        frames_root: Directory holding one sub-directory of frames per
            patient id.

    Raises:
        FileNotFoundError: If the sample-data directory, a latent file, or
            a patient's frame directory is missing.
        KeyError: If a patient id is not present in the metadata.
    """
    if not SAMPLE_DATA_DIR.exists():
        raise FileNotFoundError(f"Sample data dir not found: {SAMPLE_DATA_DIR}")

    meta_csv = SAMPLE_DATA_DIR / "metadata.csv"
    metadata = pd.read_csv(meta_csv).set_index("video_name")

    SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
    manifest = {"samples": []}

    for pid in patients:
        if pid not in metadata.index:
            raise KeyError(f"{pid} not in {meta_csv}")
        row = metadata.loc[pid]

        # Frame count after capping at the model's maximum sequence length.
        n_orig = int(row["NbFrame"])
        t_real = n_orig if n_orig < MAX_FRAMES else MAX_FRAMES

        # Latent tensor: copied as-is next to the manifest.
        latent_file = SAMPLE_DATA_DIR / f"{pid}.pt"
        if not latent_file.exists():
            raise FileNotFoundError(latent_file)
        shutil.copyfile(latent_file, SAMPLES_DIR / f"{pid}.pt")

        # Real frames: copied into a per-patient folder.
        patient_frames_dir = frames_root / pid
        if not patient_frames_dir.is_dir():
            raise FileNotFoundError(patient_frames_dir)
        _copy_real_frames(patient_frames_dir, SAMPLES_DIR / pid, t_real)

        # Paths in the entry are relative to SAMPLES_DIR.
        entry = {
            "id": pid,
            "latent_path": f"{pid}.pt",
            "real_frames_dir": pid,
            "ef_true": float(row["EF_AL"]),
            "t_real": t_real,
            "n_orig": n_orig,
            "fps_orig": float(row["FrameRate"]),
            "view": str(row["view"]),
        }
        manifest["samples"].append(entry)
        print(f" + {pid}: t_real={t_real} (n_orig={n_orig}), "
              f"ef_true={row['EF_AL']}, fps={row['FrameRate']}")

    manifest_path = SAMPLES_DIR / "manifest.json"
    payload = json.dumps(manifest, indent=2)
    manifest_path.write_text(payload)
    print(f"\nWrote {manifest_path} with {len(manifest['samples'])} samples")
def _parse_args() -> argparse.Namespace:
    """Parse the command-line arguments for this script.

    Returns:
        A namespace with `frames_root` (required `Path`) and `patients`
        (list of ids, defaulting to `DEFAULT_PATIENTS`).
    """
    # The raw formatter keeps the module docstring's line breaks in --help.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    parser.add_argument(
        "--frames-root",
        type=Path,
        required=True,
        help="Path to CAMUS_Processed_Frames root.",
    )
    parser.add_argument(
        "--patients",
        default=DEFAULT_PATIENTS,
        nargs="+",
        help="Patient ids to bundle.",
    )
    return parser.parse_args()
def main() -> None:
    """Command-line entry point: parse arguments, then build the gallery."""
    cli = _parse_args()
    build(patients=cli.patients, frames_root=cli.frames_root)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()