ritvik360 commited on
Commit
ed2e608
·
verified ·
1 Parent(s): ae4bbb1

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. client.py +13 -5
client.py CHANGED
@@ -53,10 +53,18 @@ class NL2SQLEnv:
53
  def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
54
  obs_data = payload.get("observation", {})
55
 
56
- # Read standard OpenEnv top-level keys safely
57
- safe_reward = float(payload.get("reward", 0.0))
58
- safe_done = bool(payload.get("done", False))
59
 
 
 
 
 
 
 
 
 
60
  obs = NL2SQLObservation(
61
  question=obs_data.get("question", ""),
62
  schema_context=obs_data.get("schema_context", ""),
@@ -68,8 +76,8 @@ class NL2SQLEnv:
68
  step=obs_data.get("step", 0),
69
  max_steps=obs_data.get("max_steps", 5),
70
  done=safe_done,
71
- reward=safe_reward,
72
- score=float(obs_data.get("score", 0.0)),
73
  )
74
  return StepResult(
75
  observation=obs,
 
53
  def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
54
  obs_data = payload.get("observation", {})
55
 
56
+ # SAFETY CHECK: Handle JSON 'null' (None) values gracefully
57
+ raw_reward = payload.get("reward")
58
+ safe_reward = float(raw_reward) if raw_reward is not None else 0.0
59
 
60
+ raw_obs_reward = obs_data.get("reward")
61
+ safe_obs_reward = float(raw_obs_reward) if raw_obs_reward is not None else 0.0
62
+
63
+ raw_score = obs_data.get("score")
64
+ safe_score = float(raw_score) if raw_score is not None else 0.0
65
+
66
+ safe_done = bool(payload.get("done") or obs_data.get("done") or False)
67
+
68
  obs = NL2SQLObservation(
69
  question=obs_data.get("question", ""),
70
  schema_context=obs_data.get("schema_context", ""),
 
76
  step=obs_data.get("step", 0),
77
  max_steps=obs_data.get("max_steps", 5),
78
  done=safe_done,
79
+ reward=safe_obs_reward,
80
+ score=safe_score,
81
  )
82
  return StepResult(
83
  observation=obs,