Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import pandas as pd | |
# Directory that contains this script; package sources live in its "src" child.
PROJECT_ROOT = Path(__file__).resolve().parents[0]
SRC_ROOT = PROJECT_ROOT.joinpath("src")
# Prepend src/ to the import path so the `demo_backend` imports that follow
# resolve without the package being installed.
sys.path.insert(0, f"{SRC_ROOT}")
| from demo_backend.assets import ensure_assets_from_hub | |
| from demo_backend.paths import ( | |
| DEFAULT_OUTPUT_DIR, | |
| EXAMPLE_EHR, | |
| resolve_default_avra_script, | |
| resolve_default_checkpoint, | |
| resolve_default_stats, | |
| ) | |
| from demo_backend.pipeline import run_full_inference | |
def _find_sample_nifti() -> Path:
    """Locate the sample NIfTI volume used by the smoke test.

    If the ``SAMPLE_NIFTI_PATH`` environment variable is set (non-blank), that
    path is used exclusively. Otherwise the first existing path among these
    is returned, where ``<parent>`` is ``PROJECT_ROOT.parent``:

    * ``<parent>/avra_public-master/inference/Oasis4/<file>``
    * ``<parent>/inference/Oasis4/<file>``

    Raises:
        FileNotFoundError: if the override path does not exist, or if none of
            the default candidates exists. The message lists the paths checked.
    """
    override = os.getenv("SAMPLE_NIFTI_PATH", "").strip()
    if override:
        # An explicit override that points nowhere is a user error; falling
        # back to the defaults silently would hide it.
        override_path = Path(override).expanduser()
        if not override_path.exists():
            raise FileNotFoundError(f"SAMPLE_NIFTI_PATH does not exist: {override_path}")
        return override_path

    file_name = "sub-OAS42000_sess-d3016_run-01_T1w_mni_dof_6.nii"
    candidates = [
        PROJECT_ROOT.parent / "avra_public-master" / "inference" / "Oasis4" / file_name,
        PROJECT_ROOT.parent / "inference" / "Oasis4" / file_name,
    ]
    for candidate in candidates:
        if candidate.exists():
            return candidate

    searched = "\n  ".join(str(p) for p in candidates)
    raise FileNotFoundError(
        "No sample NIfTI found. Set up a local test NIfTI and update run_dummy.py path "
        f"(or set SAMPLE_NIFTI_PATH). Searched:\n  {searched}"
    )


def main() -> None:
    """Run one end-to-end demo inference on the sample NIfTI and sample EHR row.

    Steps:
      1. Locate the sample NIfTI and check that ``EXAMPLE_EHR`` exists.
      2. If ``HF_ASSETS_REPO_ID`` is set, call ``ensure_assets_from_hub`` with
         that repo ID and ``HF_ASSETS_REVISION`` (default ``"main"``).
      3. Take the first row of ``EXAMPLE_EHR`` as the EHR input.
      4. Call ``run_full_inference`` with the default checkpoint, stats and
         AVRA script paths. Registration is reused, and FSL, foundation
         backbones and the remote MedGemma report are all disabled.
      5. Print the predicted class, its maximum class probability and the
         output JSON path.
      6. Write the whole result dict to
         ``DEFAULT_OUTPUT_DIR / "dummy_cli" / "dummy_cli_summary.json"``.

    Raises:
        FileNotFoundError: if the sample NIfTI or the sample EHR CSV is missing.
        ValueError: if the sample EHR CSV has a header but no data rows.
    """
    sample_nifti = _find_sample_nifti()
    if not EXAMPLE_EHR.exists():
        raise FileNotFoundError(f"Sample EHR CSV not found: {EXAMPLE_EHR}")

    # Strip so a whitespace-only value doesn't trigger a download with a
    # blank repo id.
    assets_repo_id = os.getenv("HF_ASSETS_REPO_ID", "").strip()
    if assets_repo_id:
        ensure_assets_from_hub(repo_id=assets_repo_id, revision=os.getenv("HF_ASSETS_REVISION", "main"))

    ehr_table = pd.read_csv(EXAMPLE_EHR)
    if ehr_table.empty:
        # .iloc[0] would otherwise fail with an opaque IndexError.
        raise ValueError(f"Sample EHR CSV has no data rows: {EXAMPLE_EHR}")
    ehr_row = ehr_table.iloc[0].to_dict()

    out_dir = DEFAULT_OUTPUT_DIR / "dummy_cli"
    out_dir.mkdir(parents=True, exist_ok=True)

    result = run_full_inference(
        nifti_path=sample_nifti,
        ehr_row=ehr_row,
        output_dir=out_dir,
        checkpoint_path=resolve_default_checkpoint(),
        stats_path=resolve_default_stats(),
        avra_script_path=resolve_default_avra_script(),
        device="auto",
        allow_foundation_backbones=False,
        reuse_registration=True,
        use_fsl=False,
        enable_remote_medgemma_report=False,
    )

    prediction = result["prediction"]
    print("Predicted:", prediction["predicted_class_name"])
    # default=nan keeps an empty probability map from aborting the run before
    # the summary file below is written.
    print("Confidence:", max(prediction["class_probabilities"].values(), default=float("nan")))
    print("Output JSON:", result["output_json"])

    pretty = out_dir / "dummy_cli_summary.json"
    # NOTE(review): this assumes every value in `result` is JSON-serializable
    # (e.g. no Path or numpy scalars); confirm against run_full_inference.
    pretty.write_text(json.dumps(result, indent=2), encoding="utf-8")
    print("Saved summary:", pretty)
# Run the demo only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()