Spaces:
Sleeping
Sleeping
| """Push model artifacts and Gradio app to Hugging Face Hub. | |
| Uploads: | |
| 1. Model ensemble checkpoints to a model repo | |
| 2. Gradio app source to a Space repo | |
| Usage: | |
| python -m src.deploy.push_to_hub --config configs/deployment.yaml | |
| """ | |
| import argparse | |
| import logging | |
| import shutil | |
| from pathlib import Path | |
| import yaml | |
| from huggingface_hub import HfApi, create_repo | |
| logger = logging.getLogger(__name__) | |
def push_model(config: dict) -> None:
    """Push model checkpoints and normalization params to HF Hub.

    Creates the model repo if it does not exist yet (``exist_ok=True``) and
    uploads, when present on disk:

    * every ``*.pt`` file in ``artifacts/checkpoints/model_ensemble`` to
      ``model_ensemble/<name>``
    * ``artifacts/checkpoints/normalization_params.json``
    * ``artifacts/checkpoints/model_config.json`` as ``config.json``
    * ``docs/MODEL_CARD.md`` as ``README.md``

    Missing artifacts are skipped, but a warning is logged for each one so a
    partial push is visible instead of looking like a full success.

    Args:
        config: Parsed deployment config. ``config["huggingface"]`` must
            contain ``model_repo`` and may contain ``organization``; when
            that value is truthy it is prepended to the repo name as
            ``<organization>/<model_repo>``.

    Raises:
        KeyError: If ``huggingface`` or ``model_repo`` is missing from
            ``config``.
    """
    api = HfApi()

    hf_settings = config["huggingface"]
    repo_id = hf_settings["model_repo"]
    organization = hf_settings.get("organization")
    if organization:
        repo_id = f"{organization}/{repo_id}"

    create_repo(repo_id, repo_type="model", exist_ok=True)

    checkpoint_dir = Path("artifacts/checkpoints")

    # Upload model ensemble. Sorted so the upload order is reproducible
    # rather than depending on filesystem enumeration order.
    ensemble_dir = checkpoint_dir / "model_ensemble"
    ensemble_files = (
        sorted(ensemble_dir.glob("*.pt")) if ensemble_dir.exists() else []
    )
    if not ensemble_files:
        logger.warning(
            "No *.pt checkpoints found in %s; no ensemble members uploaded",
            ensemble_dir,
        )
    for pt_file in ensemble_files:
        api.upload_file(
            path_or_fileobj=str(pt_file),
            path_in_repo=f"model_ensemble/{pt_file.name}",
            repo_id=repo_id,
            repo_type="model",
        )
        logger.info("Uploaded %s", pt_file.name)

    # (local file, destination path inside the model repo), in upload order:
    # normalization params, model config, model card.
    single_files = (
        (checkpoint_dir / "normalization_params.json", "normalization_params.json"),
        (checkpoint_dir / "model_config.json", "config.json"),
        (Path("docs/MODEL_CARD.md"), "README.md"),
    )
    for local_path, path_in_repo in single_files:
        if not local_path.exists():
            logger.warning("Skipping %s: file not found", local_path)
            continue
        api.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="model",
        )

    logger.info("Model pushed to: https://huggingface.co/%s", repo_id)
def push_space(config: dict) -> None:
    """Push Gradio app to HF Spaces.

    Creates the Space repo if it does not exist yet (``space_sdk="gradio"``,
    ``exist_ok=True``) and uploads, in this order:

    1. The listed application source files, each to the same relative path
       inside the Space. Files missing locally are skipped with a warning.
    2. If ``artifacts/checkpoints`` exists: every ``*.pt`` file in its
       ``model_ensemble`` subdirectory plus ``normalization_params.json`` and
       ``model_config.json`` when present, under ``artifacts/checkpoints/``.
    3. ``pyproject.toml`` (uploaded unconditionally).

    Args:
        config: Parsed deployment config. ``config["huggingface"]`` must
            contain ``space_repo`` and may contain ``organization``; when
            that value is truthy it is prepended to the repo name as
            ``<organization>/<space_repo>``.

    Raises:
        KeyError: If ``huggingface`` or ``space_repo`` is missing from
            ``config``.
    """
    api = HfApi()

    hf_settings = config["huggingface"]
    repo_id = hf_settings["space_repo"]
    organization = hf_settings.get("organization")
    if organization:
        repo_id = f"{organization}/{repo_id}"

    create_repo(repo_id, repo_type="space", space_sdk="gradio", exist_ok=True)

    # Files needed for the Space; each keeps its repository-relative path.
    source_files = (
        "src/app/app.py",
        "src/app/materials.py",
        "src/app/visualizations.py",
        "src/app/__init__.py",
        "src/__init__.py",
        "src/data/solvers/base.py",
        "src/data/solvers/beam.py",
        "src/data/solvers/plate.py",
        "src/data/solvers/vessel.py",
        "src/data/solvers/__init__.py",
        "src/data/__init__.py",
        "src/data/schema.py",
        "src/models/architecture.py",
        "src/models/ensemble.py",
        "src/models/normalization.py",
        "src/models/physics_loss.py",
        "src/models/__init__.py",
    )
    for source_path in source_files:
        if not Path(source_path).exists():
            # Still best-effort, but make the gap visible: a Space missing
            # one of these modules may fail at import time.
            logger.warning("Skipping %s: file not found", source_path)
            continue
        api.upload_file(
            path_or_fileobj=source_path,
            path_in_repo=source_path,
            repo_id=repo_id,
            repo_type="space",
        )

    # Upload artifacts if they exist
    checkpoint_dir = Path("artifacts/checkpoints")
    if checkpoint_dir.exists():
        # Sorted so the upload order is reproducible across runs.
        for checkpoint in sorted((checkpoint_dir / "model_ensemble").glob("*.pt")):
            api.upload_file(
                path_or_fileobj=str(checkpoint),
                path_in_repo=f"artifacts/checkpoints/model_ensemble/{checkpoint.name}",
                repo_id=repo_id,
                repo_type="space",
            )
        for json_name in ("normalization_params.json", "model_config.json"):
            json_path = checkpoint_dir / json_name
            if json_path.exists():
                api.upload_file(
                    path_or_fileobj=str(json_path),
                    path_in_repo=f"artifacts/checkpoints/{json_name}",
                    repo_id=repo_id,
                    repo_type="space",
                )
    else:
        logger.info(
            "%s not found; no model artifacts uploaded to the Space",
            checkpoint_dir,
        )

    # Upload requirements. There is deliberately no existence check here,
    # unlike the files above, so a missing pyproject.toml surfaces as an
    # error from upload_file instead of being skipped.
    # NOTE(review): the Spaces docs describe installing Python dependencies
    # from a root-level requirements.txt - confirm the Space build actually
    # picks up dependencies from pyproject.toml, or ship requirements.txt too.
    api.upload_file(
        path_or_fileobj="pyproject.toml",
        path_in_repo="pyproject.toml",
        repo_id=repo_id,
        repo_type="space",
    )

    logger.info("Space pushed to: https://huggingface.co/spaces/%s", repo_id)
def main() -> None:
    """Command-line entry point: load the config and run the requested pushes.

    Options:
        --config PATH   YAML deployment config (default
                        ``configs/deployment.yaml``).
        --model-only    Push only the model repo.
        --space-only    Push only the Space repo.

    With neither flag, the model is pushed first and then the Space.
    ``--model-only`` and ``--space-only`` are mutually exclusive: passing
    both used to skip both pushes silently, so argparse now rejects the
    combination with a usage error.
    """
    parser = argparse.ArgumentParser(description="Push to HF Hub")
    parser.add_argument("--config", default="configs/deployment.yaml")
    target = parser.add_mutually_exclusive_group()
    target.add_argument("--model-only", action="store_true")
    target.add_argument("--space-only", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file; fail here with a clear CLI
    # error rather than an opaque TypeError/KeyError inside the push helpers.
    if not isinstance(config, dict) or "huggingface" not in config:
        parser.error(f"{args.config} must define a top-level 'huggingface' section")

    if not args.space_only:
        push_model(config)
    if not args.model_only:
        push_space(config)


if __name__ == "__main__":
    main()