bobbypaton committed on
Commit
5a434bd
·
1 Parent(s): 899f17e

Fix hf_hub_download import

Browse files
Files changed (1) hide show
  1. worker.py +5 -244
worker.py CHANGED
@@ -1,79 +1,3 @@
1
- """
2
- CASCADE worker process
3
- """
4
-
5
- import os
6
- import sys
7
- import warnings
8
- warnings.filterwarnings("ignore", category=UserWarning)
9
- warnings.filterwarnings("ignore", category=DeprecationWarning)
10
- warnings.filterwarnings("ignore", category=FutureWarning)
11
-
12
- import json
13
- import math
14
- import pickle
15
- import datetime
16
- from io import StringIO
17
-
18
- import redis
19
- import numpy as np
20
-
21
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
22
-
23
- import keras
24
- from keras.models import load_model
25
- from NMR_Prediction.apply import (
26
- preprocess_C, preprocess_H,
27
- evaluate_C, evaluate_H,
28
- RBFSequence,
29
- )
30
- from nfp.layers import (
31
- MessageLayer, GRUStep, Squeeze, EdgeNetwork,
32
- ReduceBondToPro, ReduceBondToAtom,
33
- GatherAtomToBond, ReduceAtomToPro,
34
- )
35
- from nfp.models import GraphModel
36
-
37
- import pandas as pd
38
- from rdkit import Chem
39
- from rdkit.Chem import AllChem
40
- from rdkit.Chem import SDWriter
41
- from NMR_Prediction.genConf import genConf
42
-
43
- MODEL_PATH_C = os.path.join("NMR_Prediction", "schnet_edgeupdate", "best_model.hdf5")
44
- MODEL_PATH_H = os.path.join("NMR_Prediction", "schnet_edgeupdate_H", "best_model.hdf5")
45
- PREPROCESSOR_PATH = os.path.join("NMR_Prediction", "preprocessor.p")
46
-
47
- custom_objects = {
48
- "MessageLayer": MessageLayer,
49
- "GRUStep": GRUStep,
50
- "Squeeze": Squeeze,
51
- "EdgeNetwork": EdgeNetwork,
52
- "ReduceBondToPro": ReduceBondToPro,
53
- "ReduceBondToAtom": ReduceBondToAtom,
54
- "GatherAtomToBond": GatherAtomToBond,
55
- "ReduceAtomToPro": ReduceAtomToPro,
56
- "GraphModel": GraphModel,
57
- }
58
-
59
- print("Loading 13C model...", flush=True)
60
- model_C = load_model(MODEL_PATH_C, custom_objects=custom_objects)
61
- print("Loading 1H model...", flush=True)
62
- model_H = load_model(MODEL_PATH_H, custom_objects=custom_objects)
63
- print("Both models loaded.", flush=True)
64
-
65
- with open(PREPROCESSOR_PATH, "rb") as f:
66
- preprocessor = pickle.load(f)["preprocessor"]
67
-
68
- redis_client = redis.StrictRedis(
69
- host="localhost", port=6379, db=0, decode_responses=True
70
- )
71
-
72
- # ── Analytics logging to HF Dataset ──────────────────────────────────────────
73
- _HF_TOKEN = os.environ.get("HF_TOKEN", "")
74
- _ANALYTICS_REPO = "patonlab/analytics"
75
- _ANALYTICS_FILE = "data.csv"
76
-
77
  def _log_prediction():
78
  """Append one row to the existing patonlab/analytics data.csv.
79
  Format matches the alfabet log: space,timestamp
@@ -81,19 +5,19 @@ def _log_prediction():
81
  if not _HF_TOKEN:
82
  return
83
  try:
84
- from huggingface_hub import HfApi
85
  import tempfile
86
 
87
  api = HfApi(token=_HF_TOKEN)
88
  timestamp = datetime.datetime.utcnow().isoformat()
89
 
90
- # Download the current CSV, append a row, re-upload
91
  with tempfile.TemporaryDirectory() as tmpdir:
92
- local_path = os.path.join(tmpdir, "data.csv")
93
- api.hf_hub_download(
94
  repo_id=_ANALYTICS_REPO,
95
  filename=_ANALYTICS_FILE,
96
  repo_type="dataset",
 
97
  local_dir=tmpdir,
98
  )
99
  with open(local_path, "a") as f:
@@ -107,167 +31,4 @@ def _log_prediction():
107
  commit_message=f"log: cascade prediction {timestamp[:10]}",
108
  )
109
  except Exception as e:
110
- print(f"Analytics logging failed (non-fatal): {e}", flush=True)
111
-
112
-
113
- def _mol_to_sdf(mol, conf_id=0):
114
- sio = StringIO()
115
- w = SDWriter(sio)
116
- w.write(mol, confId=conf_id)
117
- w.close()
118
- return sio.getvalue()
119
-
120
-
121
- def _build_sdfs_from_genconf(mol_with_confs, ids):
122
- """
123
- Build SDF strings directly from the genConf mol using real conformer IDs.
124
- ids is a list of (energy, conf_id) tuples sorted by energy (lowest first).
125
- Returns (sdfs, energy_order).
126
- """
127
- sdfs = []
128
- energy_order = []
129
- for energy, conf_id in ids:
130
- try:
131
- sdf = _mol_to_sdf(mol_with_confs, conf_id=int(conf_id))
132
- if sdf.strip():
133
- sdfs.append(sdf)
134
- energy_order.append(int(conf_id))
135
- except Exception as e:
136
- print(f"SDF error for conf_id={conf_id}: {e}", flush=True)
137
- return sdfs, energy_order
138
-
139
-
140
- def _boltzmann_average(spread_df):
141
- spread_df["b_weight"] = spread_df["relative_E"].apply(
142
- lambda x: math.exp(-x / (0.001987 * 298.15))
143
- )
144
- df_group = spread_df.set_index(["mol_id", "atom_index", "cf_id"]).groupby(level=[0, 1])
145
- final = []
146
- for (m_id, a_id), df in df_group:
147
- ws = (df["b_weight"] * df["predicted"]).sum() / df["b_weight"].sum()
148
- final.append([m_id, a_id, ws])
149
- final = pd.DataFrame(final, columns=["mol_id", "atom_index", "Shift"])
150
- final["atom_index"] = final["atom_index"].apply(lambda x: x + 1)
151
- return final.round(2).astype(dtype={"atom_index": "int"})
152
-
153
-
154
- def _fmt_weighted(final_df):
155
- return "".join(f"{int(r['atom_index'])},{r['Shift']:.2f};" for _, r in final_df.iterrows())
156
-
157
-
158
- def _fmt_conf_shifts(spread_df, energy_order):
159
- parts = []
160
- for cf_id in energy_order:
161
- sub = spread_df[spread_df["cf_id"] == cf_id]
162
- if len(sub) == 0:
163
- continue
164
- parts.append("".join(f"{int(r['atom_index'])},{r['predicted']:.2f};" for _, r in sub.iterrows()))
165
- return "!".join(parts)
166
-
167
-
168
- def _fmt_relative_E(spread_df, energy_order):
169
- total_bw = spread_df.groupby("cf_id")["b_weight"].first().sum()
170
- parts = []
171
- for cf_id in energy_order:
172
- sub = spread_df[spread_df["cf_id"] == cf_id]
173
- if len(sub) == 0:
174
- continue
175
- e = round(sub["relative_E"].iloc[0], 2)
176
- bw = round(sub["b_weight"].iloc[0] / total_bw, 4)
177
- parts.append(f"{e},{bw},")
178
- return "!".join(parts)
179
-
180
-
181
- def run_job(task_id, smiles, type_):
182
- result_key = f"task_result_{task_id}"
183
- try:
184
- mol = Chem.MolFromSmiles(smiles)
185
- AllChem.EmbedMolecule(mol, useRandomCoords=True)
186
- mol_with_h = Chem.AddHs(mol, addCoords=True)
187
-
188
- # Single conformer search
189
- mol_with_confs, ids, nr = genConf(mol_with_h, rms=-1, nc=200, efilter=10.0, rmspost=0.5)
190
- print(f"genConf: {len(ids)} conformers", flush=True)
191
-
192
- conf_sdfs, energy_order = _build_sdfs_from_genconf(mol_with_confs, ids)
193
-
194
- mols = [Chem.MolFromSmiles(smiles)]
195
- for m in mols:
196
- AllChem.EmbedMolecule(m, useRandomCoords=True)
197
- mols = [Chem.AddHs(m, addCoords=True) for m in mols]
198
-
199
- # Suppress duplicate genConf stdout during preprocess
200
- _stdout, _stderr = sys.stdout, sys.stderr
201
- sys.stdout = sys.stderr = open(os.devnull, 'w')
202
- try:
203
- if type_ == "C":
204
- inputs, df, conf_mols = preprocess_C(mols, preprocessor, keep_all_cf=True)
205
- else:
206
- inputs, df, conf_mols = preprocess_H(mols, preprocessor, keep_all_cf=True)
207
- finally:
208
- sys.stdout.close()
209
- sys.stdout, sys.stderr = _stdout, _stderr
210
-
211
- if type_ == "C":
212
- predicted = evaluate_C(inputs, preprocessor, model_C)
213
- else:
214
- predicted = evaluate_H(inputs, preprocessor, model_H)
215
-
216
- if len(inputs) == 0:
217
- raise RuntimeError("No conformers generated")
218
-
219
- spread_df = pd.DataFrame(columns=["mol_id", "atom_index", "relative_E", "cf_id"])
220
- for _, r in df.iterrows():
221
- n = len(r["atom_index"])
222
- tmp = pd.DataFrame({
223
- "mol_id": [r["mol_id"]] * n,
224
- "atom_index": r["atom_index"],
225
- "relative_E": [r["relative_E"]] * n,
226
- "cf_id": [r["cf_id"]] * n,
227
- })
228
- spread_df = pd.concat([spread_df, tmp], sort=True)
229
-
230
- spread_df["predicted"] = predicted
231
- spread_df["b_weight"] = spread_df["relative_E"].apply(
232
- lambda x: math.exp(-x / (0.001987 * 298.15))
233
- )
234
- spread_df["atom_index"] = spread_df["atom_index"].apply(lambda x: x + 1)
235
- spread_df = spread_df.round(2)
236
-
237
- final_df = _boltzmann_average(
238
- spread_df.copy().assign(
239
- atom_index=spread_df["atom_index"].apply(lambda x: x - 1)
240
- )
241
- )
242
-
243
- result = {
244
- "smiles": smiles,
245
- "type_": type_,
246
- "conf_sdfs": conf_sdfs,
247
- "weightedShiftTxt": _fmt_weighted(final_df),
248
- "confShiftTxt": _fmt_conf_shifts(spread_df, energy_order),
249
- "relative_E": _fmt_relative_E(spread_df, energy_order),
250
- }
251
- redis_client.set(result_key, json.dumps(result), ex=3600)
252
- print(f"Task {task_id} complete — {len(conf_sdfs)} conformers", flush=True)
253
-
254
- # Log to analytics dataset (non-blocking)
255
- _log_prediction()
256
-
257
- except Exception as e:
258
- import traceback; traceback.print_exc()
259
- redis_client.set(result_key, json.dumps({"errMessage": str(e)}), ex=3600)
260
-
261
-
262
- print("Worker ready, waiting for jobs...", flush=True)
263
- while True:
264
- item = redis_client.blpop("task_queue", timeout=5)
265
- if item is None:
266
- continue
267
- _, task_id = item
268
- detail = redis_client.get(f"task_detail_{task_id}")
269
- if not detail:
270
- continue
271
- detail = json.loads(detail)
272
- print(f"Processing task {task_id} smiles={detail['smiles']} type={detail['type_']}", flush=True)
273
- run_job(task_id, detail["smiles"], detail["type_"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def _log_prediction():
2
  """Append one row to the existing patonlab/analytics data.csv.
3
  Format matches the alfabet log: space,timestamp
 
5
  if not _HF_TOKEN:
6
  return
7
  try:
8
+ from huggingface_hub import HfApi, hf_hub_download
9
  import tempfile
10
 
11
  api = HfApi(token=_HF_TOKEN)
12
  timestamp = datetime.datetime.utcnow().isoformat()
13
 
14
+ # Download current CSV, append a row, re-upload
15
  with tempfile.TemporaryDirectory() as tmpdir:
16
+ local_path = hf_hub_download(
 
17
  repo_id=_ANALYTICS_REPO,
18
  filename=_ANALYTICS_FILE,
19
  repo_type="dataset",
20
+ token=_HF_TOKEN,
21
  local_dir=tmpdir,
22
  )
23
  with open(local_path, "a") as f:
 
31
  commit_message=f"log: cascade prediction {timestamp[:10]}",
32
  )
33
  except Exception as e:
34
+ print(f"Analytics logging failed (non-fatal): {e}", flush=True)