Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- inference.py +1 -1
- server/graders.py +11 -3
- server/gradio_ui.py +6 -8
- server/sql_env_environment.py +4 -4
inference.py
CHANGED
|
@@ -235,7 +235,7 @@ def run_task(client: OpenAI, task_name: str) -> None:
|
|
| 235 |
max_possible = obs.total_questions # 5 questions, max 1.0 each
|
| 236 |
if max_possible > 0:
|
| 237 |
score = sum(rewards) / max_possible
|
| 238 |
-
score = min(max(score, 0.
|
| 239 |
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 240 |
|
| 241 |
except Exception as exc:
|
|
|
|
| 235 |
max_possible = obs.total_questions # 5 questions, max 1.0 each
|
| 236 |
if max_possible > 0:
|
| 237 |
score = sum(rewards) / max_possible
|
| 238 |
+
score = min(max(score, 0.001), 0.999)
|
| 239 |
success = score >= SUCCESS_SCORE_THRESHOLD
|
| 240 |
|
| 241 |
except Exception as exc:
|
server/graders.py
CHANGED
|
@@ -7,13 +7,21 @@ Scores agent queries against ground truth with partial credit:
|
|
| 7 |
- row_score (0.3): Fraction of expected rows matching
|
| 8 |
- exact_score (0.4): Full result set matches ground truth exactly
|
| 9 |
|
| 10 |
-
Total reward per question is in
|
| 11 |
"""
|
| 12 |
|
| 13 |
from typing import Any, List, Optional, Tuple
|
| 14 |
|
| 15 |
from .database import Database, QueryResult
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def _normalize_value(val: Any) -> Any:
|
| 19 |
"""Normalize a value for comparison (handle float/int equivalence, None)."""
|
|
@@ -76,7 +84,7 @@ def grade_query(
|
|
| 76 |
# --- Syntax Score ---
|
| 77 |
if not result.success:
|
| 78 |
return {
|
| 79 |
-
"reward": 0.0,
|
| 80 |
"syntax_score": 0.0,
|
| 81 |
"column_score": 0.0,
|
| 82 |
"row_score": 0.0,
|
|
@@ -145,7 +153,7 @@ def grade_query(
|
|
| 145 |
+ W_ROW * row_score
|
| 146 |
+ W_EXACT * exact_score
|
| 147 |
)
|
| 148 |
-
reward = round(
|
| 149 |
|
| 150 |
# --- Feedback ---
|
| 151 |
feedback_parts = []
|
|
|
|
| 7 |
- row_score (0.3): Fraction of expected rows matching
|
| 8 |
- exact_score (0.4): Full result set matches ground truth exactly
|
| 9 |
|
| 10 |
+
Total reward per question is in (0.0, 1.0) — strictly between 0 and 1.
|
| 11 |
"""
|
| 12 |
|
| 13 |
from typing import Any, List, Optional, Tuple
|
| 14 |
|
| 15 |
from .database import Database, QueryResult
|
| 16 |
|
| 17 |
+
# Epsilon to ensure scores are strictly between 0 and 1 (never exactly 0.0 or 1.0)
|
| 18 |
+
_EPS = 0.001
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _clamp_reward(reward: float) -> float:
|
| 22 |
+
"""Clamp reward to be strictly within (0, 1)."""
|
| 23 |
+
return min(max(reward, _EPS), 1.0 - _EPS)
|
| 24 |
+
|
| 25 |
|
| 26 |
def _normalize_value(val: Any) -> Any:
|
| 27 |
"""Normalize a value for comparison (handle float/int equivalence, None)."""
|
|
|
|
| 84 |
# --- Syntax Score ---
|
| 85 |
if not result.success:
|
| 86 |
return {
|
| 87 |
+
"reward": _clamp_reward(0.0),
|
| 88 |
"syntax_score": 0.0,
|
| 89 |
"column_score": 0.0,
|
| 90 |
"row_score": 0.0,
|
|
|
|
| 153 |
+ W_ROW * row_score
|
| 154 |
+ W_EXACT * exact_score
|
| 155 |
)
|
| 156 |
+
reward = round(_clamp_reward(reward), 4)
|
| 157 |
|
| 158 |
# --- Feedback ---
|
| 159 |
feedback_parts = []
|
server/gradio_ui.py
CHANGED
|
@@ -70,20 +70,18 @@ def create_gradio_app() -> gr.Blocks:
|
|
| 70 |
obs = env.step(SQLAction(query=query))
|
| 71 |
|
| 72 |
feedback = obs.metadata.get("feedback", "")
|
| 73 |
-
reward_display =
|
| 74 |
|
| 75 |
# Color the reward
|
| 76 |
-
if
|
| 77 |
reward_html = f'<span style="color:#22c55e;font-size:2em;font-weight:bold">{reward_display}</span>'
|
| 78 |
-
elif obs.reward >= 0.5:
|
| 79 |
-
reward_html = f'<span style="color:#eab308;font-size:2em;font-weight:bold">{reward_display}</span>'
|
| 80 |
else:
|
| 81 |
reward_html = f'<span style="color:#ef4444;font-size:2em;font-weight:bold">{reward_display}</span>'
|
| 82 |
|
| 83 |
if obs.done:
|
| 84 |
rewards = obs.metadata.get("rewards", [])
|
| 85 |
total = obs.metadata.get("total_reward", sum(rewards))
|
| 86 |
-
status = f"**Episode Complete!** | **Total Reward:** {total
|
| 87 |
next_question = "All questions answered! Click 'Start Task' to try again."
|
| 88 |
progress = _build_progress_html(len(rewards), obs.total_questions, rewards)
|
| 89 |
else:
|
|
@@ -116,10 +114,10 @@ def create_gradio_app() -> gr.Blocks:
|
|
| 116 |
results = []
|
| 117 |
for q in task["questions"]:
|
| 118 |
obs = env.step(SQLAction(query=q["ground_truth_sql"]))
|
| 119 |
-
results.append(f"**Q{len(results)+1}:** {q['question'][:80]}...\n- SQL: `{q['ground_truth_sql'][:100]}...`\n- Reward: **{obs.reward
|
| 120 |
|
| 121 |
total = sum(env._rewards)
|
| 122 |
-
results.append(f"\n---\n**Total: {total
|
| 123 |
return "\n".join(results)
|
| 124 |
|
| 125 |
def preview_schema():
|
|
@@ -142,7 +140,7 @@ def create_gradio_app() -> gr.Blocks:
|
|
| 142 |
color = "#eab308"
|
| 143 |
else:
|
| 144 |
color = "#ef4444"
|
| 145 |
-
bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:{color};margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1}: {r
|
| 146 |
elif i == len(rewards):
|
| 147 |
bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:#3b82f6;margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1} ▶</div>')
|
| 148 |
else:
|
|
|
|
| 70 |
obs = env.step(SQLAction(query=query))
|
| 71 |
|
| 72 |
feedback = obs.metadata.get("feedback", "")
|
| 73 |
+
reward_display = round(obs.reward) # show 0 or 1
|
| 74 |
|
| 75 |
# Color the reward
|
| 76 |
+
if reward_display == 1:
|
| 77 |
reward_html = f'<span style="color:#22c55e;font-size:2em;font-weight:bold">{reward_display}</span>'
|
|
|
|
|
|
|
| 78 |
else:
|
| 79 |
reward_html = f'<span style="color:#ef4444;font-size:2em;font-weight:bold">{reward_display}</span>'
|
| 80 |
|
| 81 |
if obs.done:
|
| 82 |
rewards = obs.metadata.get("rewards", [])
|
| 83 |
total = obs.metadata.get("total_reward", sum(rewards))
|
| 84 |
+
status = f"**Episode Complete!** | **Total Reward:** {round(total)} | **Steps:** {len(rewards)}"
|
| 85 |
next_question = "All questions answered! Click 'Start Task' to try again."
|
| 86 |
progress = _build_progress_html(len(rewards), obs.total_questions, rewards)
|
| 87 |
else:
|
|
|
|
| 114 |
results = []
|
| 115 |
for q in task["questions"]:
|
| 116 |
obs = env.step(SQLAction(query=q["ground_truth_sql"]))
|
| 117 |
+
results.append(f"**Q{len(results)+1}:** {q['question'][:80]}...\n- SQL: `{q['ground_truth_sql'][:100]}...`\n- Reward: **{round(obs.reward)}**\n")
|
| 118 |
|
| 119 |
total = sum(env._rewards)
|
| 120 |
+
results.append(f"\n---\n**Total: {round(total)} / {len(task['questions'])}**")
|
| 121 |
return "\n".join(results)
|
| 122 |
|
| 123 |
def preview_schema():
|
|
|
|
| 140 |
color = "#eab308"
|
| 141 |
else:
|
| 142 |
color = "#ef4444"
|
| 143 |
+
bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:{color};margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1}: {round(r)}</div>')
|
| 144 |
elif i == len(rewards):
|
| 145 |
bars.append(f'<div style="display:inline-block;width:18%;height:30px;background:#3b82f6;margin:1%;border-radius:4px;text-align:center;line-height:30px;color:white;font-weight:bold">Q{i+1} ▶</div>')
|
| 146 |
else:
|
server/sql_env_environment.py
CHANGED
|
@@ -20,7 +20,7 @@ except ImportError:
|
|
| 20 |
from models import SQLAction, SQLObservation
|
| 21 |
|
| 22 |
from .database import Database
|
| 23 |
-
from .graders import grade_query
|
| 24 |
|
| 25 |
TASKS_DIR = Path(__file__).resolve().parent.parent / "data" / "tasks"
|
| 26 |
|
|
@@ -92,7 +92,7 @@ class SQLEnvironment(Environment):
|
|
| 92 |
self._schema_cache = self._db.get_schema_description()
|
| 93 |
|
| 94 |
return self._make_observation(
|
| 95 |
-
reward=0.0,
|
| 96 |
query_result="",
|
| 97 |
error="",
|
| 98 |
)
|
|
@@ -108,7 +108,7 @@ class SQLEnvironment(Environment):
|
|
| 108 |
if self._done or self._current_q_index >= len(self._questions):
|
| 109 |
self._done = True
|
| 110 |
return self._make_observation(
|
| 111 |
-
reward=0.0,
|
| 112 |
query_result="Episode is over. Call reset() to start a new episode.",
|
| 113 |
error="",
|
| 114 |
)
|
|
@@ -133,7 +133,7 @@ class SQLEnvironment(Environment):
|
|
| 133 |
|
| 134 |
# Apply step penalty (not on first attempt)
|
| 135 |
penalty = STEP_PENALTY * (self._q_steps_used - 1)
|
| 136 |
-
reward =
|
| 137 |
reward = round(reward, 4)
|
| 138 |
|
| 139 |
self._rewards.append(reward)
|
|
|
|
| 20 |
from models import SQLAction, SQLObservation
|
| 21 |
|
| 22 |
from .database import Database
|
| 23 |
+
from .graders import grade_query, _clamp_reward
|
| 24 |
|
| 25 |
TASKS_DIR = Path(__file__).resolve().parent.parent / "data" / "tasks"
|
| 26 |
|
|
|
|
| 92 |
self._schema_cache = self._db.get_schema_description()
|
| 93 |
|
| 94 |
return self._make_observation(
|
| 95 |
+
reward=_clamp_reward(0.0),
|
| 96 |
query_result="",
|
| 97 |
error="",
|
| 98 |
)
|
|
|
|
| 108 |
if self._done or self._current_q_index >= len(self._questions):
|
| 109 |
self._done = True
|
| 110 |
return self._make_observation(
|
| 111 |
+
reward=_clamp_reward(0.0),
|
| 112 |
query_result="Episode is over. Call reset() to start a new episode.",
|
| 113 |
error="",
|
| 114 |
)
|
|
|
|
| 133 |
|
| 134 |
# Apply step penalty (not on first attempt)
|
| 135 |
penalty = STEP_PENALTY * (self._q_steps_used - 1)
|
| 136 |
+
reward = _clamp_reward(raw_reward - penalty)
|
| 137 |
reward = round(reward, 4)
|
| 138 |
|
| 139 |
self._rewards.append(reward)
|