fea-surrogate / src /deploy /push_to_hub.py
WolfDavid's picture
Upload folder using huggingface_hub
8e5ba9e verified
"""Push model artifacts and Gradio app to Hugging Face Hub.
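Uploads:
    1. Model ensemble checkpoints, normalization params, model config and
       model card to a model repo
    2. Gradio app source, model artifacts and pyproject.toml to a Space repo

Usage:
    python -m src.deploy.push_to_hub --config configs/deployment.yaml
    python -m src.deploy.push_to_hub --model-only   # skip the Space push
    python -m src.deploy.push_to_hub --space-only   # skip the model push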
"""
import argparse
import logging
import shutil
from pathlib import Path
import yaml
from huggingface_hub import HfApi, create_repo
# Module-level logger named after this module; its level/handlers are set up
# in main() through logging.basicConfig.
logger = logging.getLogger(__name__)
def push_model(config: dict) -> None:
    """Push model checkpoints and normalization params to HF Hub.

    Creates the model repo if it does not exist yet, then uploads (each file
    through its own ``upload_file`` call):

    * every ``*.pt`` file in ``artifacts/checkpoints/model_ensemble`` to
      ``model_ensemble/<name>``
    * ``artifacts/checkpoints/normalization_params.json``
    * ``artifacts/checkpoints/model_config.json`` (stored as ``config.json``)
    * ``docs/MODEL_CARD.md`` (stored as ``README.md``)

    All local paths are relative to the current working directory. Files
    that are absent are skipped, and a warning is logged for each so that an
    incomplete push is visible in the output.

    Args:
        config: Deployment configuration. Reads
            ``config["huggingface"]["model_repo"]`` and the optional
            ``config["huggingface"]["organization"]``; when the latter is
            set, the repo id becomes ``"<organization>/<model_repo>"``.

    Raises:
        KeyError: If ``"huggingface"`` or ``"model_repo"`` is missing from
            ``config``.
    """
    api = HfApi()
    hf_config = config["huggingface"]
    repo_id = hf_config["model_repo"]
    organization = hf_config.get("organization")
    if organization:
        repo_id = f"{organization}/{repo_id}"
    create_repo(repo_id, repo_type="model", exist_ok=True)

    checkpoint_dir = Path("artifacts/checkpoints")

    # Upload model ensemble. Sorted so the upload order is deterministic
    # (glob order is filesystem-dependent).
    ensemble_dir = checkpoint_dir / "model_ensemble"
    pt_files = sorted(ensemble_dir.glob("*.pt")) if ensemble_dir.exists() else []
    if not pt_files:
        logger.warning("No ensemble checkpoints (*.pt) found in %s", ensemble_dir)
    for pt_file in pt_files:
        api.upload_file(
            path_or_fileobj=str(pt_file),
            path_in_repo=f"model_ensemble/{pt_file.name}",
            repo_id=repo_id,
            repo_type="model",
        )
        logger.info("Uploaded %s", pt_file.name)

    # (local file, destination path inside the repo). Note the renames:
    # model_config.json -> config.json and MODEL_CARD.md -> README.md.
    single_files = (
        (checkpoint_dir / "normalization_params.json", "normalization_params.json"),
        (checkpoint_dir / "model_config.json", "config.json"),
        (Path("docs/MODEL_CARD.md"), "README.md"),
    )
    for local_path, path_in_repo in single_files:
        if not local_path.exists():
            logger.warning("Skipping missing file: %s", local_path)
            continue
        api.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="model",
        )

    logger.info("Model pushed to: https://huggingface.co/%s", repo_id)
def push_space(config: dict) -> None:
    """Push Gradio app to HF Spaces.

    Creates the Space repo (``space_sdk="gradio"``) if it does not exist
    yet, then uploads, each file through its own ``upload_file`` call:

    * the application, solver and model source files listed below, keeping
      their repository-relative paths
    * ``*.pt`` checkpoints and the two JSON side files from
      ``artifacts/checkpoints`` (only when that directory exists)
    * ``pyproject.toml``

    All local paths are relative to the current working directory. Source
    files that are absent are skipped with a warning, so a Space that would
    be missing a module is visible in the output.

    Args:
        config: Deployment configuration. Reads
            ``config["huggingface"]["space_repo"]`` and the optional
            ``config["huggingface"]["organization"]``; when the latter is
            set, the repo id becomes ``"<organization>/<space_repo>"``.

    Raises:
        KeyError: If ``"huggingface"`` or ``"space_repo"`` is missing from
            ``config``.
    """
    api = HfApi()
    hf_config = config["huggingface"]
    repo_id = hf_config["space_repo"]
    organization = hf_config.get("organization")
    if organization:
        repo_id = f"{organization}/{repo_id}"
    create_repo(repo_id, repo_type="space", space_sdk="gradio", exist_ok=True)

    # Files needed for the Space. Each one is stored in the Space under the
    # same relative path it has locally.
    source_files = [
        "src/app/app.py",
        "src/app/materials.py",
        "src/app/visualizations.py",
        "src/app/__init__.py",
        "src/__init__.py",
        "src/data/solvers/base.py",
        "src/data/solvers/beam.py",
        "src/data/solvers/plate.py",
        "src/data/solvers/vessel.py",
        "src/data/solvers/__init__.py",
        "src/data/__init__.py",
        "src/data/schema.py",
        "src/models/architecture.py",
        "src/models/ensemble.py",
        "src/models/normalization.py",
        "src/models/physics_loss.py",
        "src/models/__init__.py",
    ]
    for source_file in source_files:
        if not Path(source_file).exists():
            logger.warning("Space source file not found, skipping: %s", source_file)
            continue
        api.upload_file(
            path_or_fileobj=source_file,
            path_in_repo=source_file,
            repo_id=repo_id,
            repo_type="space",
        )

    # Upload artifacts if they exist
    checkpoint_dir = Path("artifacts/checkpoints")
    if checkpoint_dir.exists():
        # Sorted so the upload order is deterministic (glob order is
        # filesystem-dependent).
        for pt_file in sorted((checkpoint_dir / "model_ensemble").glob("*.pt")):
            api.upload_file(
                path_or_fileobj=str(pt_file),
                path_in_repo=f"artifacts/checkpoints/model_ensemble/{pt_file.name}",
                repo_id=repo_id,
                repo_type="space",
            )
        for json_name in ("normalization_params.json", "model_config.json"):
            json_path = checkpoint_dir / json_name
            if json_path.exists():
                api.upload_file(
                    path_or_fileobj=str(json_path),
                    path_in_repo=f"artifacts/checkpoints/{json_name}",
                    repo_id=repo_id,
                    repo_type="space",
                )
    else:
        logger.warning(
            "Checkpoint directory %s not found; no model artifacts uploaded",
            checkpoint_dir,
        )

    # Upload requirements. Unlike the files above, this upload is not guarded
    # by an existence check.
    # NOTE(review): confirm whether a missing pyproject.toml should abort the
    # push, and whether the Space build actually reads pyproject.toml.
    api.upload_file(
        path_or_fileobj="pyproject.toml",
        path_in_repo="pyproject.toml",
        repo_id=repo_id,
        repo_type="space",
    )
    logger.info("Space pushed to: https://huggingface.co/spaces/%s", repo_id)
def main() -> None:
    """Parse command-line arguments and run the requested pushes.

    Options:
        --config PATH   YAML deployment config
                        (default: ``configs/deployment.yaml``).
        --model-only    Push only the model repo.
        --space-only    Push only the Space repo.

    With neither flag, the model is pushed first and the Space second.
    ``--model-only`` and ``--space-only`` cannot be combined; argparse
    rejects the combination with a usage error (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Push to HF Hub")
    parser.add_argument("--config", default="configs/deployment.yaml")
    # The two flags are mutually exclusive: given together they would disable
    # both pushes and the script would exit successfully having done nothing.
    target = parser.add_mutually_exclusive_group()
    target.add_argument("--model-only", action="store_true")
    target.add_argument("--space-only", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    # Explicit encoding so parsing does not depend on the platform locale.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    if not args.space_only:
        push_model(config)
    if not args.model_only:
        push_space(config)
# Script entry point: run the CLI only when this module is executed directly
# (e.g. `python -m src.deploy.push_to_hub`), not when it is imported.
if __name__ == "__main__":
    main()