Spaces:
Sleeping
Sleeping
Upload server/reward.py with huggingface_hub
Browse files- server/reward.py +80 -0
server/reward.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def compute(prev_accuracy, new_accuracy, prev_stats, new_stats, action, steps_taken, max_steps, budget_remaining, target_accuracy, relabeler_used):
|
| 2 |
+
"""
|
| 3 |
+
CRITICAL REQUIREMENT: All reward components must be graders strictly between
|
| 4 |
+
0.0 and 1.0 — exclusive. Neither 0.0 nor 1.0 are valid outputs.
|
| 5 |
+
Valid range: (0.001 ... 0.999)
|
| 6 |
+
|
| 7 |
+
Each sub-grader scores one independent aspect and returns a value in (0.0, 1.0).
|
| 8 |
+
Final reward is a weighted average of all graders — also in (0.0, 1.0).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def clamp(v):
|
| 12 |
+
"""Clamp to strictly open interval (0.0, 1.0)."""
|
| 13 |
+
return max(0.001, min(0.999, float(v)))
|
| 14 |
+
|
| 15 |
+
# --- Grader 1: Format compliance (independent) ---
|
| 16 |
+
# Did the agent produce a valid, well-formed action?
|
| 17 |
+
valid_agents = ["cleaner", "augmenter", "balancer", "relabeler", "validator"]
|
| 18 |
+
if not isinstance(action.get("agent"), str) or action.get("agent") not in valid_agents:
|
| 19 |
+
format_score = 0.001 # invalid agent — minimum non-zero
|
| 20 |
+
elif "target" not in action:
|
| 21 |
+
format_score = 0.4 # valid agent but incomplete fields
|
| 22 |
+
else:
|
| 23 |
+
format_score = 0.999 # fully valid action format
|
| 24 |
+
|
| 25 |
+
# --- Grader 2: Accuracy improvement ---
|
| 26 |
+
# How much did accuracy improve toward target?
|
| 27 |
+
delta_acc = new_accuracy - prev_accuracy
|
| 28 |
+
remaining = max(0.001, target_accuracy - prev_accuracy)
|
| 29 |
+
progress = delta_acc / remaining if remaining > 0 else 0.0
|
| 30 |
+
accuracy_score = clamp(0.5 + progress * 0.49) # neutral at 0.5, better if improving
|
| 31 |
+
|
| 32 |
+
# --- Grader 3: Dataset quality improvement ---
|
| 33 |
+
# Combined missing value reduction + balance improvement
|
| 34 |
+
missing_improvement = prev_stats["missing_pct"] - new_stats["missing_pct"]
|
| 35 |
+
balance_improvement = new_stats["balance_ratio"] - prev_stats["balance_ratio"]
|
| 36 |
+
quality_delta = (missing_improvement + balance_improvement) / 2.0
|
| 37 |
+
quality_score = clamp(0.5 + quality_delta * 2.0)
|
| 38 |
+
|
| 39 |
+
# --- Grader 4: Efficiency ---
|
| 40 |
+
# Did the agent improve anything at all? Penalize wasted steps.
|
| 41 |
+
nothing_changed = (delta_acc <= 0 and missing_improvement <= 0 and balance_improvement <= 0)
|
| 42 |
+
relabeler_overused = relabeler_used and budget_remaining < 3
|
| 43 |
+
if nothing_changed:
|
| 44 |
+
efficiency_score = 0.1 # wasted a step
|
| 45 |
+
elif relabeler_overused:
|
| 46 |
+
efficiency_score = 0.3 # used expensive tool recklessly
|
| 47 |
+
else:
|
| 48 |
+
# Reward using budget efficiently — more budget left = better
|
| 49 |
+
efficiency_score = clamp(0.5 + (budget_remaining / max_steps) * 0.49)
|
| 50 |
+
|
| 51 |
+
# --- Grader 5: Task completion ---
|
| 52 |
+
# Did this action help reach the target threshold?
|
| 53 |
+
if new_accuracy >= target_accuracy:
|
| 54 |
+
# Success — reward scales with how much budget is left (efficiency bonus)
|
| 55 |
+
completion_score = clamp(0.9 + (budget_remaining / max_steps) * 0.09)
|
| 56 |
+
elif new_accuracy > prev_accuracy:
|
| 57 |
+
completion_score = clamp(0.5 + (new_accuracy / target_accuracy) * 0.4)
|
| 58 |
+
else:
|
| 59 |
+
completion_score = 0.1 # no progress toward target
|
| 60 |
+
|
| 61 |
+
# --- Weighted average — stays in (0.0, 1.0) by construction ---
|
| 62 |
+
reward = (
|
| 63 |
+
format_score * 0.15 +
|
| 64 |
+
accuracy_score * 0.35 +
|
| 65 |
+
quality_score * 0.20 +
|
| 66 |
+
efficiency_score * 0.15 +
|
| 67 |
+
completion_score * 0.15
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Final safety clamp — must never be exactly 0.0 or 1.0
|
| 71 |
+
reward = clamp(reward)
|
| 72 |
+
|
| 73 |
+
return round(reward, 4)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def compute_stats(df):
|
| 77 |
+
missing_pct = float(df.isnull().mean().mean())
|
| 78 |
+
label_counts = df["label"].value_counts(normalize=True)
|
| 79 |
+
balance_ratio = float(label_counts.min()) if len(label_counts) > 1 else 1.0
|
| 80 |
+
return {"missing_pct": missing_pct, "balance_ratio": balance_ratio}
|