UtkarshSatav committed on
Commit
54a5bf9
·
verified ·
1 Parent(s): 33fd157

Upload folder using huggingface_hub

Browse files
inference.py CHANGED
@@ -235,7 +235,7 @@ def run_task(client: OpenAI, task_name: str) -> None:
235
  max_possible = obs.total_questions # 5 questions, max 1.0 each
236
  if max_possible > 0:
237
  score = sum(rewards) / max_possible
238
- score = min(max(score, 0.0), 1.0)
239
  success = score >= SUCCESS_SCORE_THRESHOLD
240
 
241
  except Exception as exc:
 
235
  max_possible = obs.total_questions # 5 questions, max 1.0 each
236
  if max_possible > 0:
237
  score = sum(rewards) / max_possible
238
+ score = min(max(score, 0.001), 0.999)
239
  success = score >= SUCCESS_SCORE_THRESHOLD
240
 
241
  except Exception as exc:
server/graders.py CHANGED
@@ -7,13 +7,21 @@ Scores agent queries against ground truth with partial credit:
7
  - row_score (0.3): Fraction of expected rows matching
8
  - exact_score (0.4): Full result set matches ground truth exactly
9
 
10
- Total reward per question is in [0.0, 1.0].
11
  """
12
 
13
  from typing import Any, List, Optional, Tuple
14
 
15
  from .database import Database, QueryResult
16
 
 
 
 
 
 
 
 
 
17
 
18
  def _normalize_value(val: Any) -> Any:
19
  """Normalize a value for comparison (handle float/int equivalence, None)."""
@@ -76,7 +84,7 @@ def grade_query(
76
  # --- Syntax Score ---
77
  if not result.success:
78
  return {
79
- "reward": 0.0,
80
  "syntax_score": 0.0,
81
  "column_score": 0.0,
82
  "row_score": 0.0,
@@ -145,7 +153,7 @@ def grade_query(
145
  + W_ROW * row_score
146
  + W_EXACT * exact_score
147
  )
148
- reward = round(min(max(reward, 0.0), 1.0), 4)
149
 
150
  # --- Feedback ---
151
  feedback_parts = []
 
7
  - row_score (0.3): Fraction of expected rows matching
8
  - exact_score (0.4): Full result set matches ground truth exactly
9
 
10
+ Total reward per question is in (0.0, 1.0) — strictly between 0 and 1.
11
  """
12
 
13
  from typing import Any, List, Optional, Tuple
14
 
15
  from .database import Database, QueryResult
16
 
17
+ # Epsilon to ensure scores are strictly between 0 and 1 (never exactly 0.0 or 1.0)
18
+ _EPS = 0.001
19
+
20
+
21
+ def _clamp_reward(reward: float) -> float:
22
+ """Clamp reward to be strictly within (0, 1)."""
23
+ return min(max(reward, _EPS), 1.0 - _EPS)
24
+
25
 
26
  def _normalize_value(val: Any) -> Any:
27
  """Normalize a value for comparison (handle float/int equivalence, None)."""
 
84
  # --- Syntax Score ---
85
  if not result.success:
86
  return {
87
+ "reward": _clamp_reward(0.0),
88
  "syntax_score": 0.0,
89
  "column_score": 0.0,
90
  "row_score": 0.0,
 
153
  + W_ROW * row_score
154
  + W_EXACT * exact_score
155
  )
156
+ reward = round(_clamp_reward(reward), 4)
157
 
158
  # --- Feedback ---
159
  feedback_parts = []
server/gradio_ui.py CHANGED
@@ -70,20 +70,18 @@ def create_gradio_app() -> gr.Blocks:
70
  obs = env.step(SQLAction(query=query))
71
 
72
  feedback = obs.metadata.get("feedback", "")
73
- reward_display = f"{obs.reward:.2f}"
74
 
75
  # Color the reward
76
- if obs.reward >= 0.9:
77
  reward_html = f'<span style="color:#22c55e;font-size:2em;font-weight:bold">{reward_display}</span>'
78
- elif obs.reward >= 0.5:
79
- reward_html = f'<span style="color:#eab308;font-size:2em;font-weight:bold">{reward_display}</span>'
80
  else:
81
  reward_html = f'<span style="color:#ef4444;font-size:2em;font-weight:bold">{reward_display}</span>'
82
 
83
  if obs.done:
84
  rewards = obs.metadata.get("rewards", [])
85
  total = obs.metadata.get("total_reward", sum(rewards))
86
- status = f"**Episode Complete!** | **Total Reward:** {total:.2f} | **Steps:** {len(rewards)}"
87
  next_question = "All questions answered! Click 'Start Task' to try again."
88
  progress = _build_progress_html(len(rewards), obs.total_questions, rewards)
89
  else:
@@ -116,10 +114,10 @@ def create_gradio_app() -> gr.Blocks:
116
  results = []
117
  for q in task["questions"]:
118
  obs = env.step(SQLAction(query=q["ground_truth_sql"]))
119
- results.append(f"**Q{len(results)+1}:** {q['question'][:80]}...\n- SQL: `{q['ground_truth_sql'][:100]}...`\n- Reward: **{obs.reward:.2f}**\n")
120
 
121
  total = sum(env._rewards)
122
- results.append(f"\n---\n**Total: {total:.2f} / {len(task['questions']):.1f}**")
123
  return "\n".join(results)
124
 
125
  def preview_schema():
@@ -142,7 +140,7 @@ def create_gradio_app() -> gr.Blocks:
142
  color = "#eab308"
143
  else:
144
  color = "#ef4444"
145
- bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:{color};margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1}: {r:.2f}</div>')
146
  elif i == len(rewards):
147
  bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:#3b82f6;margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1} ▶</div>')
148
  else:
 
70
  obs = env.step(SQLAction(query=query))
71
 
72
  feedback = obs.metadata.get("feedback", "")
73
+ reward_display = round(obs.reward) # show 0 or 1
74
 
75
  # Color the reward
76
+ if reward_display == 1:
77
  reward_html = f'<span style="color:#22c55e;font-size:2em;font-weight:bold">{reward_display}</span>'
 
 
78
  else:
79
  reward_html = f'<span style="color:#ef4444;font-size:2em;font-weight:bold">{reward_display}</span>'
80
 
81
  if obs.done:
82
  rewards = obs.metadata.get("rewards", [])
83
  total = obs.metadata.get("total_reward", sum(rewards))
84
+ status = f"**Episode Complete!** | **Total Reward:** {round(total)} | **Steps:** {len(rewards)}"
85
  next_question = "All questions answered! Click 'Start Task' to try again."
86
  progress = _build_progress_html(len(rewards), obs.total_questions, rewards)
87
  else:
 
114
  results = []
115
  for q in task["questions"]:
116
  obs = env.step(SQLAction(query=q["ground_truth_sql"]))
117
+ results.append(f"**Q{len(results)+1}:** {q['question'][:80]}...\n- SQL: `{q['ground_truth_sql'][:100]}...`\n- Reward: **{round(obs.reward)}**\n")
118
 
119
  total = sum(env._rewards)
120
+ results.append(f"\n---\n**Total: {round(total)} / {len(task['questions'])}**")
121
  return "\n".join(results)
122
 
123
  def preview_schema():
 
140
  color = "#eab308"
141
  else:
142
  color = "#ef4444"
143
+ bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:{color};margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1}: {round(r)}</div>')
144
  elif i == len(rewards):
145
  bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:#3b82f6;margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1} ▶</div>')
146
  else:
server/sql_env_environment.py CHANGED
@@ -20,7 +20,7 @@ except ImportError:
20
  from models import SQLAction, SQLObservation
21
 
22
  from .database import Database
23
- from .graders import grade_query
24
 
25
  TASKS_DIR = Path(__file__).resolve().parent.parent / "data" / "tasks"
26
 
@@ -92,7 +92,7 @@ class SQLEnvironment(Environment):
92
  self._schema_cache = self._db.get_schema_description()
93
 
94
  return self._make_observation(
95
- reward=0.0,
96
  query_result="",
97
  error="",
98
  )
@@ -108,7 +108,7 @@ class SQLEnvironment(Environment):
108
  if self._done or self._current_q_index >= len(self._questions):
109
  self._done = True
110
  return self._make_observation(
111
- reward=0.0,
112
  query_result="Episode is over. Call reset() to start a new episode.",
113
  error="",
114
  )
@@ -133,7 +133,7 @@ class SQLEnvironment(Environment):
133
 
134
  # Apply step penalty (not on first attempt)
135
  penalty = STEP_PENALTY * (self._q_steps_used - 1)
136
- reward = max(raw_reward - penalty, 0.0)
137
  reward = round(reward, 4)
138
 
139
  self._rewards.append(reward)
 
20
  from models import SQLAction, SQLObservation
21
 
22
  from .database import Database
23
+ from .graders import grade_query, _clamp_reward
24
 
25
  TASKS_DIR = Path(__file__).resolve().parent.parent / "data" / "tasks"
26
 
 
92
  self._schema_cache = self._db.get_schema_description()
93
 
94
  return self._make_observation(
95
+ reward=_clamp_reward(0.0),
96
  query_result="",
97
  error="",
98
  )
 
108
  if self._done or self._current_q_index >= len(self._questions):
109
  self._done = True
110
  return self._make_observation(
111
+ reward=_clamp_reward(0.0),
112
  query_result="Episode is over. Call reset() to start a new episode.",
113
  error="",
114
  )
 
133
 
134
  # Apply step penalty (not on first attempt)
135
  penalty = STEP_PENALTY * (self._q_steps_used - 1)
136
+ reward = _clamp_reward(raw_reward - penalty)
137
  reward = round(reward, 4)
138
 
139
  self._rewards.append(reward)