Hemanth Kunta commited on
Commit
3e987ed
·
1 Parent(s): cb70147

clamp grader outputs to strict score range

Browse files
env/app.py CHANGED
@@ -8,6 +8,7 @@ from fastapi import FastAPI, HTTPException
8
  from env.dataset_gen import generate_dataset
9
  from env.engine import SQLEngine
10
  from env.models import Action, EpisodeState, Observation, Reward, RewardBreakdown
 
11
  from tasks.task1_nulls import Task1
12
  from tasks.task2_schema import Task2
13
  from tasks.task3_drift import Task3
@@ -116,7 +117,7 @@ def step(payload: dict):
116
 
117
  base_score, score_breakdown = task.grade(action.report, gold)
118
  budget_bonus = round(min(0.10, state.query_credits * 0.01), 4)
119
- total = round(min(1.0, base_score + budget_bonus), 4)
120
 
121
  state.audit_score = total
122
  state.report_submitted = True
 
8
  from env.dataset_gen import generate_dataset
9
  from env.engine import SQLEngine
10
  from env.models import Action, EpisodeState, Observation, Reward, RewardBreakdown
11
+ from tasks.base import BaseTask
12
  from tasks.task1_nulls import Task1
13
  from tasks.task2_schema import Task2
14
  from tasks.task3_drift import Task3
 
117
 
118
  base_score, score_breakdown = task.grade(action.report, gold)
119
  budget_bonus = round(min(0.10, state.query_credits * 0.01), 4)
120
+ total = BaseTask.strict_score(round(min(1.0, base_score + budget_bonus), 4))
121
 
122
  state.audit_score = total
123
  state.report_submitted = True
high_grade_agent.py CHANGED
@@ -25,6 +25,7 @@ from env.reasoning_stack import (
25
  validate_and_repair_report,
26
  )
27
  from env.sql_brain import probes_for_task
 
28
 
29
  API_BASE_URL = os.environ.get("API_BASE_URL", "")
30
  MODEL_NAME = os.environ.get("MODEL_NAME", "")
@@ -439,7 +440,7 @@ def run_task(task_id: int, q_table: dict[str, list[float]], memory: MemoryStore)
439
 
440
  out = call_env("step", {"action": {"action_type": "submit_report", "report": report}})
441
  reward = out.get("reward", {})
442
- score = as_float(reward.get("value", 0.0))
443
 
444
  # Persist successful behavior to memory for future episodes.
445
  memory.add(
@@ -465,7 +466,8 @@ def main() -> None:
465
  print("\n=== HIGH-GRADE AGENT RESULTS ===")
466
  for k, v in scores.items():
467
  print(f" {k}: {v:.3f}")
468
- print(f" mean: {sum(scores.values())/len(scores):.3f}")
 
469
 
470
 
471
  if __name__ == "__main__":
 
25
  validate_and_repair_report,
26
  )
27
  from env.sql_brain import probes_for_task
28
+ from tasks.base import BaseTask
29
 
30
  API_BASE_URL = os.environ.get("API_BASE_URL", "")
31
  MODEL_NAME = os.environ.get("MODEL_NAME", "")
 
440
 
441
  out = call_env("step", {"action": {"action_type": "submit_report", "report": report}})
442
  reward = out.get("reward", {})
443
+ score = BaseTask.strict_score(as_float(reward.get("value", 0.0)))
444
 
445
  # Persist successful behavior to memory for future episodes.
446
  memory.add(
 
466
  print("\n=== HIGH-GRADE AGENT RESULTS ===")
467
  for k, v in scores.items():
468
  print(f" {k}: {v:.3f}")
469
+ mean_score = BaseTask.strict_score(sum(scores.values()) / len(scores))
470
+ print(f" mean: {mean_score:.3f}")
471
 
472
 
473
  if __name__ == "__main__":
tasks/base.py CHANGED
@@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
3
  from env.models import AuditReport
4
 
5
 
 
 
 
6
  class BaseTask(ABC):
7
  @abstractmethod
8
  def get_description(self) -> str:
@@ -22,6 +25,18 @@ class BaseTask(ABC):
22
  brier = (confidence - c) ** 2
23
  return base * (1.0 - 0.3 * brier)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  @staticmethod
26
  def count_accuracy(reported: int, actual: int, tolerance: float = 0.15) -> float:
27
  if actual == 0:
 
3
  from env.models import AuditReport
4
 
5
 
6
+ SCORE_EPS = 1e-6
7
+
8
+
9
  class BaseTask(ABC):
10
  @abstractmethod
11
  def get_description(self) -> str:
 
25
  brier = (confidence - c) ** 2
26
  return base * (1.0 - 0.3 * brier)
27
 
28
+ @staticmethod
29
+ def strict_score(value: float, epsilon: float = SCORE_EPS) -> float:
30
+ try:
31
+ score = float(value)
32
+ except Exception:
33
+ score = epsilon
34
+ if not (score > 0.0):
35
+ return epsilon
36
+ if not (score < 1.0):
37
+ return 1.0 - epsilon
38
+ return score
39
+
40
  @staticmethod
41
  def count_accuracy(reported: int, actual: int, tolerance: float = 0.15) -> float:
42
  if actual == 0:
tasks/task1_nulls.py CHANGED
@@ -39,4 +39,4 @@ class Task1(BaseTask):
39
 
40
  weights = {"null_email": 0.30, "null_cid": 0.25, "exact_dups": 0.30, "near_dups": 0.15}
41
  total = sum(scores[k] * weights[k] for k in weights)
42
- return round(min(1.0, total), 4), scores
 
39
 
40
  weights = {"null_email": 0.30, "null_cid": 0.25, "exact_dups": 0.30, "near_dups": 0.15}
41
  total = sum(scores[k] * weights[k] for k in weights)
42
+ return self.strict_score(round(total, 4)), scores
tasks/task2_schema.py CHANGED
@@ -50,4 +50,4 @@ class Task2(BaseTask):
50
 
51
  weights = {"amount_type": 0.25, "date_format": 0.25, "neg_qty": 0.25, "bad_amount": 0.25}
52
  total = sum(scores[k] * weights[k] for k in weights)
53
- return round(min(1.0, total), 4), scores
 
50
 
51
  weights = {"amount_type": 0.25, "date_format": 0.25, "neg_qty": 0.25, "bad_amount": 0.25}
52
  total = sum(scores[k] * weights[k] for k in weights)
53
+ return self.strict_score(round(total, 4)), scores
tasks/task3_drift.py CHANGED
@@ -57,4 +57,4 @@ class Task3(BaseTask):
57
 
58
  weights = {"mean_shift": 0.40, "new_cats": 0.35, "ref_drift": 0.25}
59
  total = sum(scores[k] * weights[k] for k in weights)
60
- return round(min(1.0, total), 4), scores
 
57
 
58
  weights = {"mean_shift": 0.40, "new_cats": 0.35, "ref_drift": 0.25}
59
  total = sum(scores[k] * weights[k] for k in weights)
60
+ return self.strict_score(round(total, 4)), scores
tasks/task4_relational.py CHANGED
@@ -50,4 +50,4 @@ class Task4(BaseTask):
50
 
51
  weights = {"orphans": 0.40, "temporal": 0.35, "aggregates": 0.25}
52
  total = sum(scores[k] * weights[k] for k in weights)
53
- return round(min(1.0, total), 4), scores
 
50
 
51
  weights = {"orphans": 0.40, "temporal": 0.35, "aggregates": 0.25}
52
  total = sum(scores[k] * weights[k] for k in weights)
53
+ return self.strict_score(round(total, 4)), scores