Spaces:
Running
Running
Update PolyAgent/orchestrator.py
Browse files- PolyAgent/orchestrator.py +58 -12
PolyAgent/orchestrator.py
CHANGED
|
@@ -20,6 +20,7 @@ import sys
|
|
| 20 |
from pathlib import Path
|
| 21 |
from typing import Dict, Any, List, Optional, Tuple
|
| 22 |
from urllib.parse import urlparse
|
|
|
|
| 23 |
|
| 24 |
import numpy as np
|
| 25 |
import torch
|
|
@@ -76,24 +77,69 @@ SELFIES_AVAILABLE = sf is not None
|
|
| 76 |
# =============================================================================
|
| 77 |
class PathsConfig:
|
| 78 |
"""
|
| 79 |
-
Centralized
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
"""
|
| 81 |
-
# CL weights
|
| 82 |
-
cl_weights_path = "/path/to/multimodal_output_5M/best/pytorch_model.bin"
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
spm_vocab_path = "/path/to/spm_5M.vocab"
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
# =============================================================================
|
| 99 |
# DOI NORMALIZATION / RESOLUTION HELPERS
|
|
|
|
| 20 |
from pathlib import Path
|
| 21 |
from typing import Dict, Any, List, Optional, Tuple
|
| 22 |
from urllib.parse import urlparse
|
| 23 |
+
from huggingface_hub import snapshot_download
|
| 24 |
|
| 25 |
import numpy as np
|
| 26 |
import torch
|
|
|
|
| 77 |
# =============================================================================
|
| 78 |
class PathsConfig:
    """
    Centralized paths for Spaces/local runs.

    On Hugging Face Spaces:
      - Downloads required artifacts from a HF Model repo (weights) into a local cache dir
      - Exposes stable local filesystem paths used by the rest of orchestrator.py

    Attributes set by __init__:
      hf_repo_id / hf_repo_type       -- source weights repo (env-overridable)
      local_weights_root              -- download target directory
      hf_token                        -- optional HF token (private repos only)
      cl_weights_path                 -- CL checkpoint (pytorch_model.bin)
      spm_model_path / spm_vocab_path -- SentencePiece tokenizer files
      downstream_bestweights_5m_dir   -- downstream heads directory
      inverse_design_5m_dir           -- inverse-design weights directory
      chroma_db_path                  -- local Chroma DB folder (env-overridable)

    Raises:
      RuntimeError       -- if the snapshot download fails (network/auth/repo-id).
      FileNotFoundError  -- if a required file is missing after download.
    """

    def __init__(self):
        # 1) HF model repo where the staged weight bundle was uploaded.
        #    Example: "kaurm43/PolyFusionAgent-weights-5m" (override via env).
        self.hf_repo_id = os.getenv("POLYFUSION_WEIGHTS_REPO", "kaurm43/PolyFusionAgent-weights-5m")
        self.hf_repo_type = os.getenv("POLYFUSION_WEIGHTS_REPO_TYPE", "model")  # usually "model"

        # 2) Where to store downloaded files.
        #    Prefer /data on Spaces with persistent storage; else a user cache folder.
        default_root = (
            "/data/polyfusion_cache"
            if os.path.isdir("/data")
            else os.path.expanduser("~/.cache/polyfusion_cache")
        )
        self.local_weights_root = os.getenv("POLYFUSION_WEIGHTS_DIR", default_root)

        # 3) Optional token (only needed if the weights repo is private).
        self.hf_token = os.getenv("HF_TOKEN", None)

        # 4) Download (cached) and resolve the local folder path.
        #    allow_patterns keeps the download smaller/faster (only what orchestrator needs).
        allow = [
            "tokenizer_spm_5m/**",
            "polyfusion_cl_5m/**",
            "downstream_heads_5m/**",
            "inverse_design_5m/**",
            "MANIFEST.txt",
        ]

        # NOTE: `local_dir_use_symlinks` is deprecated in huggingface_hub >= 0.23
        # and removed in 1.x (passing it raises TypeError there). With `local_dir`
        # set, real files in that directory are the default behavior, so the
        # argument is intentionally omitted.
        try:
            self._weights_dir = snapshot_download(
                repo_id=self.hf_repo_id,
                repo_type=self.hf_repo_type,
                local_dir=self.local_weights_root,
                token=self.hf_token,
                allow_patterns=allow,
            )
        except Exception as err:
            # Fail early with actionable guidance instead of a raw hub traceback.
            raise RuntimeError(
                f"Failed to download weights from '{self.hf_repo_id}' "
                f"(repo_type='{self.hf_repo_type}'). Check POLYFUSION_WEIGHTS_REPO, "
                f"network access, and HF_TOKEN (if the repo is private)."
            ) from err

        # 5) Map to the exact files the existing code expects.
        #    (Only path wiring changes; no behavior changes elsewhere.)
        self.cl_weights_path = os.path.join(self._weights_dir, "polyfusion_cl_5m", "pytorch_model.bin")

        # If the Space also includes a local Chroma DB folder in the Space repo,
        # keep this as-is. Otherwise, Chroma DB can also be hosted as a dataset/model repo.
        self.chroma_db_path = os.getenv("CHROMA_DB_PATH", "chroma_polymer_db_big")

        self.spm_model_path = os.path.join(self._weights_dir, "tokenizer_spm_5m", "spm.model")
        self.spm_vocab_path = os.path.join(self._weights_dir, "tokenizer_spm_5m", "spm.vocab")

        self.downstream_bestweights_5m_dir = os.path.join(self._weights_dir, "downstream_heads_5m")
        self.inverse_design_5m_dir = os.path.join(self._weights_dir, "inverse_design_5m")

        # 6) Sanity-check required files (fail early with a clear message).
        self._assert_exists(self.cl_weights_path, "CL weights")
        self._assert_exists(self.spm_model_path, "SentencePiece model")
        self._assert_exists(self.spm_vocab_path, "SentencePiece vocab")

    @staticmethod
    def _assert_exists(p: str, label: str) -> None:
        """Raise FileNotFoundError with a labeled message if path *p* does not exist."""
        if not os.path.exists(p):
            raise FileNotFoundError(f"{label} not found at: {p}")
|
| 143 |
|
| 144 |
# =============================================================================
|
| 145 |
# DOI NORMALIZATION / RESOLUTION HELPERS
|