# panacea-api / app_wrapper.py
# Uploaded by DTanzillo with huggingface_hub (commit a4b5ecb, verified).
"""Startup wrapper for HuggingFace Spaces deployment.

Downloads models from DTanzillo/panacea-models when they are not already
present locally, then starts the FastAPI application on the port given by
the PORT environment variable (default 7860).
"""
import os
import sys
import shutil
from pathlib import Path
# Directory containing this wrapper; it is put at the front of the import
# path so that the "app.main" module handed to uvicorn is looked up here first.
ROOT = Path(__file__).parents[0]
sys.path.insert(0, os.fspath(ROOT))
def download_models():
    """Ensure model weights and result files exist under ``ROOT``.

    If any required file is missing from ``ROOT / "models"``, fetch the
    ``models/*`` and ``results/*`` entries of the ``DTanzillo/panacea-models``
    Hub repo and copy each regular file that is not already present locally
    into ``ROOT / "models"`` and ``ROOT / "results"``. The download
    authenticates with the ``HF_TOKEN`` environment variable when it is set.
    Existing local files are never overwritten.

    Download and copy failures are reported on stdout and swallowed on
    purpose, so the API still starts (possibly without models).
    """
    model_dir = ROOT / "models"
    results_dir = ROOT / "results"
    model_dir.mkdir(exist_ok=True)
    results_dir.mkdir(exist_ok=True)

    # Files whose presence means "models are installed"; results/ is not checked.
    needed_files = ["baseline.json", "xgboost.pkl", "transformer.pt"]
    if all((model_dir / name).exists() for name in needed_files):
        print("Models already present, skipping download.")
        return

    print("Downloading models from DTanzillo/panacea-models ...")
    try:
        # Imported here so a missing huggingface_hub package is reported as a
        # download failure instead of preventing the API from starting.
        from huggingface_hub import snapshot_download

        local = Path(snapshot_download(
            "DTanzillo/panacea-models",
            token=os.environ.get("HF_TOKEN"),
            allow_patterns=["models/*", "results/*"],
        ))
        _copy_missing(local / "models", model_dir, " Copied ")
        _copy_missing(local / "results", results_dir, " Copied result: ")
    except Exception as e:
        # Best-effort by design: never block server startup on the download.
        print(f"WARNING: Model download failed: {e}")
        print("The API will start but models may not be available.")
        return

    # The snapshot may not contain every required file; say so explicitly
    # instead of reporting success.
    missing = [name for name in needed_files if not (model_dir / name).exists()]
    if missing:
        print(f"WARNING: Required model files still missing: {', '.join(missing)}")
        print("The API will start but models may not be available.")
    else:
        print("Model download complete.")


def _copy_missing(src_dir: Path, dst_dir: Path, label: str) -> None:
    """Copy regular files from *src_dir* into *dst_dir* without overwriting.

    Does nothing if *src_dir* is not a directory. Sub-directories are skipped,
    because ``shutil.copy2`` cannot copy them. Each file is first written to a
    ``.partial`` sibling and then moved into place with ``os.replace``. An
    interrupted copy therefore never leaves a truncated file under the final
    name, where a later run would mistake it for a complete one. A stale
    ``.partial`` left by a crash is simply overwritten on the next attempt.
    Prints ``label`` followed by the file name for every file copied.
    """
    if not src_dir.is_dir():
        return
    for src_file in src_dir.iterdir():
        dst_file = dst_dir / src_file.name
        # is_file() follows symlinks, so Hub-cache symlinks to blobs are kept.
        if not src_file.is_file() or dst_file.exists():
            continue
        tmp_file = dst_file.with_name(dst_file.name + ".partial")
        shutil.copy2(src_file, tmp_file)
        os.replace(tmp_file, dst_file)
        print(f"{label}{src_file.name}")
if __name__ == "__main__":
    # Make sure model artifacts are on disk before the server is launched.
    download_models()

    # Importing uvicorn inside this guard means that importing the module
    # (e.g. just to call download_models) does not require uvicorn.
    import uvicorn

    port = int(os.environ.get("PORT", "7860"))
    print(f"Starting Panacea API on port {port} ...")
    server_options = {
        "host": "0.0.0.0",
        "port": port,
        "log_level": "info",
    }
    uvicorn.run("app.main:app", **server_options)