Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Upload a trained CheXVision checkpoint to HuggingFace Hub. | |
| Usage: | |
| python scripts/push_models.py --checkpoint checkpoints/CheXVision-ResNet_best.pth | |
| python scripts/push_models.py --checkpoint checkpoints/CheXVision-DenseNet_best.pth --repo-id arudaev/chexvision-densenet | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| from src.utils.hub import load_hf_token, upload_model_artifacts | |
| # Default repo IDs keyed by the model type stored in the checkpoint config. | |
| DEFAULT_REPOS = { | |
| "scratch": "arudaev/chexvision-scratch", | |
| "densenet": "arudaev/chexvision-densenet", | |
| } | |
| def _detect_repo(checkpoint: dict) -> str: | |
| """Infer the default HF repo from the checkpoint's config.""" | |
| config = checkpoint.get("config", {}) | |
| model_type = config.get("model", {}).get("type", "") | |
| repo = DEFAULT_REPOS.get(model_type) | |
| if repo: | |
| return repo | |
| model_name = config.get("model", {}).get("name", "") | |
| if "DenseNet" in model_name or "densenet" in model_name: | |
| return DEFAULT_REPOS["densenet"] | |
| return DEFAULT_REPOS["scratch"] | |
def _history_path_for(ckpt_path: Path) -> Path | None:
    """Return the ``*_history.json`` file that accompanies *ckpt_path*, if any.

    Maps ``<run>_best.pth`` to ``<run>_history.json`` in the same directory.
    Returns ``None`` when the checkpoint name does not end in ``_best.pth`` or the
    history file does not exist.
    """
    suffix = "_best.pth"
    if not ckpt_path.name.endswith(suffix):
        # Without this guard a plain str.replace() is a no-op and yields the
        # checkpoint path itself, which would then be uploaded as "history".
        return None
    history_path = ckpt_path.with_name(ckpt_path.name[: -len(suffix)] + "_history.json")
    return history_path if history_path.exists() else None


def main() -> None:
    """Parse CLI arguments, load a checkpoint, and upload it to HuggingFace Hub.

    The target repo is ``--repo-id`` when given, otherwise it is inferred from
    the checkpoint's config via ``_detect_repo``. A matching history JSON next to
    the checkpoint is uploaded as well when present.

    Exits with status 1 when the checkpoint file is missing, the loaded object is
    not a dict, or no HF token can be loaded.
    """
    parser = argparse.ArgumentParser(description="Push a CheXVision checkpoint to HuggingFace Hub")
    parser.add_argument(
        "--checkpoint",
        type=Path,
        required=True,
        help="Path to the .pth checkpoint file.",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        default=None,
        help="HuggingFace repo ID (e.g. arudaev/chexvision-scratch). "
        "Auto-detected from the checkpoint config if omitted.",
    )
    args = parser.parse_args()

    ckpt_path = args.checkpoint
    # is_file() rather than exists(): a directory would otherwise get past this
    # check and fail later inside torch.load with a less helpful error.
    if not ckpt_path.is_file():
        print(f"ERROR: Checkpoint not found: {ckpt_path}", file=sys.stderr)
        sys.exit(1)

    print(f"Loading checkpoint: {ckpt_path} ...")
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code on load. Only run this script on checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    if not isinstance(checkpoint, dict):
        print(
            f"ERROR: Expected a dict checkpoint, got {type(checkpoint).__name__}.",
            file=sys.stderr,
        )
        sys.exit(1)

    required_keys = {"model_state_dict", "config"}
    missing_keys = required_keys - checkpoint.keys()
    if missing_keys:
        print(
            f"WARNING: Checkpoint is missing expected keys {missing_keys}. "
            "Proceeding anyway.",
            file=sys.stderr,
        )

    epoch = checkpoint.get("epoch", "?")
    best_auc = checkpoint.get("best_auc", "?")
    print(f"  Epoch: {epoch}  |  Best AUC-ROC: {best_auc}")

    repo_id = args.repo_id or _detect_repo(checkpoint)
    print(f"  Target repo: {repo_id}")

    try:
        token = load_hf_token(required=True)
    except RuntimeError as exc:
        print(f"ERROR: {exc}", file=sys.stderr)
        sys.exit(1)

    print(f"Uploading {ckpt_path.name} to {repo_id} ...")
    upload_model_artifacts(
        checkpoint_path=ckpt_path,
        repo_id=repo_id,
        token=token,
        checkpoint=checkpoint,
        history_path=_history_path_for(ckpt_path),
    )
    print(f"Upload complete: https://huggingface.co/{repo_id}")


if __name__ == "__main__":
    main()