File size: 2,453 Bytes
a4b5ecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""Startup wrapper for HuggingFace Spaces deployment.

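Downloads models from DTanzillo/panacea-models on first run (i.e. when any
required model file is missing locally), then starts the FastAPI application
on port 7860 (overridable via the PORT environment variable).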
"""

import os
import sys
import shutil
from pathlib import Path

# Directory that contains this script. The local models/ and results/
# folders live under it, and it is prepended to sys.path so packages sitting
# next to this file (such as the ``app`` package served below) are importable.
ROOT = Path(__file__).parent
sys.path[:0] = [str(ROOT)]


def _copy_missing(src_dir, dst_dir, label=""):
    """Copy top-level regular files from *src_dir* into *dst_dir* if absent.

    Each file is first copied to a ``.part`` sibling and then moved into place
    with ``os.replace``. An interrupted copy therefore never leaves a truncated
    file under its final name, which the skip-if-present checks would
    otherwise accept on every later start.

    Args:
        src_dir: Directory inside the downloaded snapshot; silently ignored
            if it is not a directory.
        dst_dir: Existing local destination directory.
        label: Text inserted after "Copied " in the progress message.
    """
    if not src_dir.is_dir():
        return
    for src_file in sorted(src_dir.iterdir()):
        if not src_file.is_file():
            # The download globs may also match files in nested folders, and
            # shutil.copy2 raises on a directory, which would abort all
            # remaining copies. Only flat files are copied.
            print(f"  Skipped non-file entry {src_file.name}")
            continue
        dst_file = dst_dir / src_file.name
        if dst_file.exists():
            continue  # never overwrite local files
        tmp_file = dst_file.with_name(dst_file.name + ".part")
        shutil.copy2(src_file, tmp_file)
        os.replace(tmp_file, dst_file)
        print(f"  Copied {label}{src_file.name}")


def download_models(*, repo_id="DTanzillo/panacea-models", root=None):
    """Download model and result artifacts from the HuggingFace Hub if needed.

    The download is skipped when every required model file already exists
    under ``<root>/models``. Otherwise only the ``models/`` and ``results/``
    folders of the Hub repository are fetched with
    ``huggingface_hub.snapshot_download``. Their files are then copied into
    ``<root>/models`` and ``<root>/results``; files already present locally
    are never overwritten.

    Errors are reported on stdout but never raised, so the API can still be
    started (possibly without models).

    Args:
        repo_id: Hub repository to download from.
        root: Directory holding ``models/`` and ``results/``. Defaults to the
            module-level ``ROOT`` (this script's directory). A ``str`` is
            accepted and converted to ``Path``.
    """
    root = ROOT if root is None else Path(root)
    model_dir = root / "models"
    results_dir = root / "results"
    model_dir.mkdir(parents=True, exist_ok=True)
    results_dir.mkdir(parents=True, exist_ok=True)

    needed_files = ["baseline.json", "xgboost.pkl", "transformer.pt"]

    def missing_files():
        # Required model files not yet present locally.
        return [name for name in needed_files if not (model_dir / name).exists()]

    if not missing_files():
        print("Models already present, skipping download.")
        return

    print(f"Downloading models from {repo_id} ...")
    try:
        # Imported here so a missing package is handled by the except below
        # instead of preventing the server from starting.
        from huggingface_hub import snapshot_download

        snapshot = Path(snapshot_download(
            repo_id,
            token=os.environ.get("HF_TOKEN"),
            allow_patterns=["models/*", "results/*"],
        ))
        _copy_missing(snapshot / "models", model_dir)
        _copy_missing(snapshot / "results", results_dir, "result: ")
    except Exception as e:  # best-effort boundary: never block server start-up
        print(f"WARNING: Model download failed: {e}")
        print("The API will start but models may not be available.")
        return

    still_missing = missing_files()
    if still_missing:
        # The snapshot did not contain every required file.
        print("WARNING: Model download finished but these files are missing: "
              + ", ".join(still_missing))
        print("The API will start but models may not be available.")
    else:
        print("Model download complete.")


if __name__ == "__main__":
    # Phase 1: make sure the model artifacts are on disk before uvicorn
    # loads the application.
    download_models()

    # Phase 2: serve the FastAPI app; the PORT env var overrides 7860.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("PORT", 7860)),
        "log_level": "info",
    }
    print(f"Starting Panacea API on port {server_options['port']} ...")
    uvicorn.run("app.main:app", **server_options)