kaurm43 commited on
Commit
d6c1ff7
·
verified ·
1 Parent(s): 9c5022f

Update PolyAgent/orchestrator.py

Browse files
Files changed (1) hide show
  1. PolyAgent/orchestrator.py +58 -12
PolyAgent/orchestrator.py CHANGED
@@ -20,6 +20,7 @@ import sys
20
  from pathlib import Path
21
  from typing import Dict, Any, List, Optional, Tuple
22
  from urllib.parse import urlparse
 
23
 
24
  import numpy as np
25
  import torch
@@ -76,24 +77,69 @@ SELFIES_AVAILABLE = sf is not None
76
  # =============================================================================
77
class PathsConfig:
    """
    Centralized path placeholders. Replace these with your local paths.

    Every attribute is a plain string constant; nothing is validated here —
    downstream loaders will fail at open() time if a path is wrong.
    """

    # CL (contrastive-learning) pretrained weights checkpoint
    cl_weights_path = "/path/to/multimodal_output_5M/best/pytorch_model.bin"

    # Chroma DB (local RAG vectorstore persist dir)
    chroma_db_path = "/path/to/chroma_polymer_db_big"

    # SentencePiece model + vocab
    spm_model_path = "/path/to/spm_5M.model"
    spm_vocab_path = "/path/to/spm_5M.vocab"

    # Downstream bestweights directory produced by your 5M downstream script
    downstream_bestweights_5m_dir = "/path/to/multimodal_downstream_bestweights_5M"

    # Inverse-design generator artifacts directory produced by your 5M inverse design run
    inverse_design_5m_dir = "/path/to/multimodal_inverse_design_output_5M_polybart_style/best_models"
 
 
96
 
 
 
 
 
97
 
98
  # =============================================================================
99
  # DOI NORMALIZATION / RESOLUTION HELPERS
 
20
  from pathlib import Path
21
  from typing import Dict, Any, List, Optional, Tuple
22
  from urllib.parse import urlparse
23
+ from huggingface_hub import snapshot_download
24
 
25
  import numpy as np
26
  import torch
 
77
  # =============================================================================
78
class PathsConfig:
    """
    Centralized paths for Spaces/local runs.

    On Hugging Face Spaces:
      - Downloads required artifacts from a HF Model repo (weights) into a
        local cache dir via ``snapshot_download``
      - Exposes stable local filesystem paths used by the rest of orchestrator.py

    Raises:
        FileNotFoundError: if any required artifact is missing after download.
    """

    def __init__(self):
        # 1) HF model repo holding the staged weights bundle.
        #    Example: "kaurm43/PolyFusionAgent-weights-5m" (override via env).
        self.hf_repo_id = os.getenv("POLYFUSION_WEIGHTS_REPO", "kaurm43/PolyFusionAgent-weights-5m")
        self.hf_repo_type = os.getenv("POLYFUSION_WEIGHTS_REPO_TYPE", "model")  # usually "model"

        # 2) Where to store downloaded files.
        #    Prefer /data on Spaces with persistent storage; else a home cache dir.
        default_root = (
            "/data/polyfusion_cache"
            if os.path.isdir("/data")
            else os.path.expanduser("~/.cache/polyfusion_cache")
        )
        self.local_weights_root = os.getenv("POLYFUSION_WEIGHTS_DIR", default_root)

        # 3) Optional token (only needed if the weights repo is private).
        self.hf_token = os.getenv("HF_TOKEN", None)

        # 4) Download (cached) + get local folder path.
        #    allow_patterns keeps the download smaller/faster — only what the
        #    orchestrator actually needs.
        allow = [
            "tokenizer_spm_5m/**",
            "polyfusion_cl_5m/**",
            "downstream_heads_5m/**",
            "inverse_design_5m/**",
            "MANIFEST.txt",
        ]

        # NOTE: `local_dir_use_symlinks` is deprecated/ignored in current
        # huggingface_hub (>=0.23): with `local_dir` set, real files are
        # always materialized, so the flag is dropped here.
        self._weights_dir = snapshot_download(
            repo_id=self.hf_repo_id,
            repo_type=self.hf_repo_type,
            local_dir=self.local_weights_root,
            token=self.hf_token,
            allow_patterns=allow,
        )

        # 5) Map to the exact files the existing code expects
        #    (only path wiring changes; no behavior changes elsewhere).
        self.cl_weights_path = os.path.join(self._weights_dir, "polyfusion_cl_5m", "pytorch_model.bin")

        # If the Space also bundles a local Chroma DB folder in the Space repo,
        # keep this default. Otherwise host the Chroma DB as a dataset/model
        # repo and point CHROMA_DB_PATH at its local copy.
        self.chroma_db_path = os.getenv("CHROMA_DB_PATH", "chroma_polymer_db_big")

        self.spm_model_path = os.path.join(self._weights_dir, "tokenizer_spm_5m", "spm.model")
        self.spm_vocab_path = os.path.join(self._weights_dir, "tokenizer_spm_5m", "spm.vocab")

        self.downstream_bestweights_5m_dir = os.path.join(self._weights_dir, "downstream_heads_5m")
        self.inverse_design_5m_dir = os.path.join(self._weights_dir, "inverse_design_5m")

        # 6) Sanity-check required files so a bad/partial download fails early
        #    with a clear message instead of a cryptic loader error later.
        self._assert_exists(self.cl_weights_path, "CL weights")
        self._assert_exists(self.spm_model_path, "SentencePiece model")
        self._assert_exists(self.spm_vocab_path, "SentencePiece vocab")

    @staticmethod
    def _assert_exists(p: str, label: str) -> None:
        """Raise FileNotFoundError naming *label* if path *p* does not exist."""
        if not os.path.exists(p):
            raise FileNotFoundError(f"{label} not found at: {p}")
143
 
144
  # =============================================================================
145
  # DOI NORMALIZATION / RESOLUTION HELPERS