Rockerleo committed on
Commit
a744b64
·
verified ·
1 Parent(s): 7996c05

Upload server/mlops_environment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/mlops_environment.py +3 -3
server/mlops_environment.py CHANGED
@@ -159,7 +159,7 @@ class MLOpsEnvironment:
159
 
160
  def step(self, action: MLOpsAction) -> Tuple[MLOpsObservation, float, bool, Dict[str, Any]]:
161
  if self._done:
162
- return self._build_obs({"status": "done", "message": "Episode over. Call reset()."}), 0.0, True, {}
163
 
164
  self._step_count += 1
165
  reward = 0.0
@@ -170,7 +170,7 @@ class MLOpsEnvironment:
170
  self._done = True
171
  score = max(0.01, self._current_score)
172
  result = {"status": "timeout", "message": f"Max steps ({self._max_steps}) reached.", "score": score}
173
- return self._build_obs(result), 0.0, True, {"score": score, "reason": "timeout"}
174
 
175
  atype = action.action_type
176
 
@@ -458,4 +458,4 @@ def grade_task(task_id: str, seed: int, diagnosis: Dict[str, Any]) -> float:
458
  env._artifacts_read = list(env._artifacts.keys())
459
  action = MLOpsAction(action_type="submit_diagnosis", **diagnosis)
460
  _, reward, _, info = env.step(action)
461
- return info.get("score", 0.0)
 
159
 
160
  def step(self, action: MLOpsAction) -> Tuple[MLOpsObservation, float, bool, Dict[str, Any]]:
161
  if self._done:
162
+ return self._build_obs({"status": "done", "message": "Episode over. Call reset()."}), 0.01, True, {"score": max(0.01, min(0.99, self._current_score))}
163
 
164
  self._step_count += 1
165
  reward = 0.0
 
170
  self._done = True
171
  score = max(0.01, self._current_score)
172
  result = {"status": "timeout", "message": f"Max steps ({self._max_steps}) reached.", "score": score}
173
+ return self._build_obs(result), score, True, {"score": score, "reason": "timeout"}
174
 
175
  atype = action.action_type
176
 
 
458
  env._artifacts_read = list(env._artifacts.keys())
459
  action = MLOpsAction(action_type="submit_diagnosis", **diagnosis)
460
  _, reward, _, info = env.step(action)
461
+ return max(0.01, min(0.99, info.get("score", 0.01)))