sk16er commited on
Commit
3700a25
·
verified ·
1 Parent(s): 291414a

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +4 -16
predict.py CHANGED
@@ -26,7 +26,7 @@ class Tox21Predictor:
26
  self.models = []
27
  self.scaler = None
28
 
29
- # [cite_start]Load scaler [cite: 5]
30
  if os.path.exists(SCALER_FILE):
31
  try:
32
  with open(SCALER_FILE, "rb") as f:
@@ -134,8 +134,8 @@ _predictor = None
134
 
135
  def predict(smiles_list):
136
  """
137
- Leaderboard Requirement: Returns a dictionary with "predictions" key
138
- containing a list of 12-task float arrays.
139
  """
140
  global _predictor
141
  if _predictor is None:
@@ -143,16 +143,4 @@ def predict(smiles_list):
143
 
144
  raw_results = _predictor.predict(smiles_list)
145
 
146
- predictions = []
147
- for s in smiles_list:
148
- # Build the vector using the strict TASKS order from data.py
149
- # This ensures 'NR-AR' and other tasks are always in the correct index.
150
- task_vector = []
151
- for task in TASKS:
152
- # Default to 0.0 if anything went wrong with that specific task
153
- val = raw_results.get(s, {}).get(task, 0.0)
154
- task_vector.append(float(val))
155
-
156
- predictions.append(task_vector)
157
-
158
- return {"predictions": predictions}
 
26
  self.models = []
27
  self.scaler = None
28
 
29
+ # Load scaler
30
  if os.path.exists(SCALER_FILE):
31
  try:
32
  with open(SCALER_FILE, "rb") as f:
 
134
 
135
  def predict(smiles_list):
136
  """
137
+ Leaderboard Entry Point: Returns a dictionary with "predictions" key
138
+ containing the raw dictionary results from Tox21Predictor.
139
  """
140
  global _predictor
141
  if _predictor is None:
 
143
 
144
  raw_results = _predictor.predict(smiles_list)
145
 
146
+ return {"predictions": raw_results}