Spaces:
Sleeping
Sleeping
| """Startup wrapper for HuggingFace Spaces deployment. | |
Downloads models from DTanzillo/panacea-models when they are not already
present locally, then starts the FastAPI application on the port given by
the PORT environment variable (default 7860).
| """ | |
| import os | |
| import sys | |
| import shutil | |
| from pathlib import Path | |
# Absolute directory containing this file. It is resolved (symlinks followed,
# made absolute) so that it does not depend on the current working directory
# when the script is launched with a relative path.
ROOT = Path(__file__).resolve().parent
# Put the app root first on the import path so that ``app.main`` (passed to
# uvicorn as an import string) can be imported.
sys.path.insert(0, str(ROOT))
def _copy_missing_files(src_dir, dst_dir, label):
    """Copy regular files from *src_dir* into *dst_dir*, never overwriting.

    Each file is first copied to a ``<name>.part`` sibling and then renamed
    into place with ``os.replace``. An interrupted copy therefore never leaves
    a truncated file under the final name, which the "already present" check
    in ``download_models`` would otherwise accept on the next start. A stale
    ``.part`` file from an earlier failure is simply overwritten.
    Subdirectories are skipped because ``shutil.copy2`` cannot copy them.

    Args:
        src_dir: Directory to copy from; nothing happens if it is not a dir.
        dst_dir: Existing directory to copy into.
        label: Text inserted into the progress message before the file name.
    """
    if not src_dir.is_dir():
        return
    for src_file in sorted(src_dir.iterdir()):
        dst_file = dst_dir / src_file.name
        # Keep existing local files; only fill in what is missing.
        if not src_file.is_file() or dst_file.exists():
            continue
        tmp_file = dst_file.with_name(dst_file.name + ".part")
        shutil.copy2(src_file, tmp_file)
        os.replace(tmp_file, dst_file)
        print(f" Copied {label}{src_file.name}")


def download_models(root=None, repo_id="DTanzillo/panacea-models"):
    """Download models from the Hugging Face Hub if not present locally.

    If every required model file already exists under ``<root>/models``,
    this returns without contacting the Hub. Otherwise it fetches the
    ``models/`` and ``results/`` folders of *repo_id* via
    ``huggingface_hub.snapshot_download``, authenticating with the
    ``HF_TOKEN`` environment variable if set. Files missing locally are
    copied in; existing local files are never overwritten.

    Failures are reported and swallowed on purpose so the API can still
    start; a warning is also printed if required files are missing after
    the download.

    Args:
        root: Application directory holding ``models/`` and ``results/``.
            Defaults to the module-level ``ROOT``.
        repo_id: Hub repository to download from.
    """
    root = ROOT if root is None else Path(root)
    model_dir = root / "models"
    results_dir = root / "results"
    model_dir.mkdir(parents=True, exist_ok=True)
    results_dir.mkdir(parents=True, exist_ok=True)

    needed_files = ["baseline.json", "xgboost.pkl", "transformer.pt"]
    if all((model_dir / name).exists() for name in needed_files):
        print("Models already present, skipping download.")
        return

    print(f"Downloading models from {repo_id} ...")
    try:
        # Imported lazily so a missing package degrades to the warning below.
        from huggingface_hub import snapshot_download

        local = Path(snapshot_download(
            repo_id,
            token=os.environ.get("HF_TOKEN"),
            allow_patterns=["models/*", "results/*"],
        ))
        _copy_missing_files(local / "models", model_dir, "")
        _copy_missing_files(local / "results", results_dir, "result: ")
    except Exception as e:
        # Deliberately best-effort: report and carry on so the API still starts.
        print(f"WARNING: Model download failed: {e}")
        print("The API will start but models may not be available.")
        return

    missing = [name for name in needed_files if not (model_dir / name).exists()]
    if missing:
        print(f"WARNING: Required model files still missing: {', '.join(missing)}")
    print("Model download complete.")
if __name__ == "__main__":
    # Make sure model artefacts are on disk before the API process starts.
    download_models()

    # The ASGI server is only needed when this file runs as a script.
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    print(f"Starting Panacea API on port {listen_port} ...")
    server_options = dict(host="0.0.0.0", port=listen_port, log_level="info")
    uvicorn.run("app.main:app", **server_options)