"""
server/rubric.py -- Multi-component reward rubric for the Hypothesis Lab.

Components:
  1. accuracy_score        (0.0-1.0)        -- how close is the hypothesis to ground truth
  2. precision_bonus       (+0.10)          -- hypothesis contains quantitative claims
  3. calibration_score     (0.0-0.20)       -- expressed confidence matches accuracy
  4. efficiency_bonus      (+0.07 / +0.15)  -- submitted early with high accuracy
  5. contradiction_penalty (-0.50)          -- hypothesis contradicts hard constraints

Per-step info-gain scoring is handled by InfoGainTracker.
"""

from __future__ import annotations

import re
import math
from dataclasses import dataclass, field
from typing import Any, Optional

import numpy as np

from .causal_world import CausalWorld


@dataclass
class RubricResult:
    """Full rubric breakdown returned when a hypothesis is scored.

    Each numeric field is one rubric component and ``total`` is their sum.
    ``feedback`` (human-readable advice) and ``ground_truth`` (the revealed
    world summary) are filled in by ``score_hypothesis``.
    """

    accuracy_score: float = 0.0
    precision_bonus: float = 0.0
    calibration_score: float = 0.0
    efficiency_bonus: float = 0.0
    contradiction_penalty: float = 0.0  # 0.0 or negative
    feedback: str = ""
    ground_truth: str = ""

    @property
    def total(self) -> float:
        """Sum of all numeric components (the penalty is already negative)."""
        return (
            self.accuracy_score
            + self.precision_bonus
            + self.calibration_score
            + self.efficiency_bonus
            + self.contradiction_penalty
        )

    def to_dict(self) -> dict[str, float]:
        """Return the numeric components plus ``total``, rounded to 4 decimals.

        The text fields (``feedback``, ``ground_truth``) are not included.
        """
        return {
            "accuracy_score": round(self.accuracy_score, 4),
            "precision_bonus": round(self.precision_bonus, 4),
            "calibration_score": round(self.calibration_score, 4),
            "efficiency_bonus": round(self.efficiency_bonus, 4),
            "contradiction_penalty": round(self.contradiction_penalty, 4),
            "total": round(self.total, 4),
        }


class InfoGainTracker:
    """
    Tracks experiment history and computes per-step information gain rewards.
    Also detects redundant experiments.
    """

    def __init__(self) -> None:
        # Number of experiments run so far on each (cause, effect) edge.
        self._edge_counts: dict[tuple[str, str], int] = {}
        # Distinct experiment types used on each edge (for triangulation).
        self._edge_types: dict[tuple[str, str], set[str]] = {}
        # Running sum of the positive per-step rewards handed out.
        self.cumulative_gain: float = 0.0
        # Number of experiments classified as redundant.
        self.redundant_count: int = 0

    def record_and_score(
        self,
        cause: str,
        effect: str,
        exp_type: str,
        result_value: Any,
    ) -> tuple[float, bool]:
        """
        Record an experiment and return (reward, is_redundant).

        Reward schedule, by how many times this (cause, effect) edge has
        already been probed:
          - 1st observation:                                 +0.20
          - 2nd, new experiment type (triangulation bonus):  +0.25
          - 2nd, same experiment type:                       +0.12
          - 3rd observation:                                 +0.05
          - 4th+: redundant, -0.10 (each one counted once in ``redundant_count``)

        Only positive rewards are added to ``cumulative_gain``.
        ``result_value`` is not used; the reward depends only on history.
        """
        key = (cause, effect)
        prior = self._edge_counts.get(key, 0)
        # Snapshot the types seen *before* this experiment so the
        # triangulation check is not affected by the type added below.
        prior_types = set(self._edge_types.get(key, ()))

        self._edge_counts[key] = prior + 1
        self._edge_types.setdefault(key, set()).add(exp_type)

        is_redundant = prior >= 3
        if prior == 0:
            reward = 0.20
        elif prior == 1:
            reward = 0.25 if exp_type not in prior_types else 0.12
        elif prior == 2:
            reward = 0.05
        else:
            reward = -0.10
            self.redundant_count += 1

        self.cumulative_gain += max(reward, 0.0)
        return round(reward, 4), is_redundant


# (regex, explanation) pairs searched in the lower-cased hypothesis; any match
# means the hypothesis contradicts the experimental setup. Word boundaries keep
# "no" from matching inside words such as "nonlinear", "not" or "know", and the
# bounded gap (at most two words) ties "no" to the noun it actually negates.
HARD_CONSTRAINTS = [
    (r"\ball variables\b.*\bindependent",
     "Claiming all variables are independent contradicts the experimental setup"),
    (r"\bno\s+(?:\w+\s+){0,2}(?:relationship|causal)",
     "Claiming no relationships exist contradicts the experimental setup"),
]

# Words/stems that indicate a given functional form. Matched at the start of a
# word (see _mentions_keyword), so stems like "saturat" still match "saturating".
_RULE_KEYWORDS: dict[str, list[str]] = {
    "linear": [
        "linear", "proportional", "slope", "times", "multiply",
        "increases", "decreases",
    ],
    "threshold": [
        "threshold", "above", "below", "greater", "less", "if", "when",
        "switch", "cutoff",
    ],
    "inverse": ["inverse", "inversely", "reciprocal", "divided", "1/"],
    "quadratic": [
        "quadratic", "squared", "parabol", "x^2", "x²", "nonlinear",
        "curve", "polynomial",
    ],
    "exponential": [
        "exponential", "exp(", "growth", "decay", "e^", "geometric",
    ],
    "logarithmic": [
        "logarithm", "log(", "ln(", "log ", "diminishing returns",
    ],
    "saturating": [
        "saturating", "saturat", "michaelis", "plateau", "asymptote",
        "levels off", "diminishing", "vmax",
    ],
    "piecewise_linear": [
        "piecewise", "breakpoint", "knot", "changes slope", "two-segment",
        "regime change", "kink",
    ],
    "additive": [
        "additive", "sum", "combines", "both contribute", "joint",
    ],
    "multiplicative": [
        "multiplicative", "product", "multiply", "synerg", "interaction",
    ],
    "min": ["minimum", "bottleneck", "limiting factor", "min("],
    "max": ["maximum", "dominant", "max("],
}

# Parameter that best identifies each rule type, used for numeric matching.
_KEY_PARAM_BY_RULE_TYPE: dict[str, str] = {
    "linear": "a",
    "threshold": "threshold",
    "inverse": "a",
    "quadratic": "a",
    "exponential": "k",
    "logarithmic": "a",
    "saturating": "v_max",
    "piecewise_linear": "knot",
}

# Standalone unsigned numbers; the lookbehind skips digits that are part of an
# identifier ("x1") or the tail of a larger number.
_NUMBER_RE = re.compile(r"(?<![\w.])\d+(?:\.\d+)?")


def _mentions_variable(name_l: str, text: str) -> bool:
    """Fuzzy check that a lower-cased variable name appears in *text*.

    Accepts the full name or its first four characters (e.g. "temp" for
    "temperature").
    """
    return name_l in text or name_l[:4] in text


def _mentions_keyword(text: str, keywords: list[str]) -> bool:
    """Return True if any keyword occurs in *text* at the start of a word.

    Only the left edge is anchored (no preceding ASCII letter), so stems and
    prefixes keep working while "linear" no longer matches "nonlinear" and
    "if" no longer matches "different".
    """
    return any(
        re.search(r"(?<![a-z])" + re.escape(kw), text) for kw in keywords
    )


def _mentions_value(value: float, text: str) -> bool:
    """Return True if some number in *text* equals |value| at one decimal.

    Comparing parsed numbers avoids the substring pitfalls of searching for
    ``str(round(value, 1))`` ("0.1" inside "0.15", "2.0" missing "2").
    """
    target = round(abs(value), 1)
    return any(
        round(float(tok), 1) == target for tok in _NUMBER_RE.findall(text)
    )


def _accuracy_score(hypothesis: str, world: CausalWorld) -> float:
    """Score how well the hypothesis captures the ground-truth rules.

    Every rule and every interaction in *world* is worth up to 1.0 point; the
    result is points earned divided by the number of items, capped at 1.0.

    Per rule (nothing is earned unless both cause and effect are named):
      +0.4  cause and effect named
      +0.3  a keyword for the rule's functional form appears
      +0.3  the rule's key parameter appears (at one-decimal precision)
    Per interaction (nothing is earned unless the effect and >= 1 cause are named):
      +0.3  effect and at least one cause named
      +0.2  effect and both causes named
      +0.5  a keyword for the interaction type appears

    Returns 0.0 for a blank hypothesis and 0.5 when there is nothing to score.
    """
    if not hypothesis.strip():
        return 0.0
    text = hypothesis.lower()

    rules = list(world.rules)
    interactions = list(world.interactions)
    total_items = len(rules) + len(interactions)
    if total_items == 0:
        return 0.5

    hits = 0.0
    for rule in rules:
        if not (
            _mentions_variable(rule.cause.lower(), text)
            and _mentions_variable(rule.effect.lower(), text)
        ):
            continue
        hits += 0.4
        if _mentions_keyword(text, _RULE_KEYWORDS.get(rule.rule_type, [])):
            hits += 0.3
        key_param = _key_param_for_rule(rule)
        if key_param is not None and _mentions_value(key_param, text):
            hits += 0.3

    for inter in interactions:
        found_c1 = _mentions_variable(inter.cause1.lower(), text)
        found_c2 = _mentions_variable(inter.cause2.lower(), text)
        found_eff = _mentions_variable(inter.effect.lower(), text)
        # Like the rule loop, require the variables before granting keyword
        # credit, so naming a functional form alone cannot earn points.
        if not (found_eff and (found_c1 or found_c2)):
            continue
        hits += 0.3
        if found_c1 and found_c2:
            hits += 0.2
        if _mentions_keyword(text, _RULE_KEYWORDS.get(inter.interaction_type, [])):
            hits += 0.5

    return min(hits / total_items, 1.0)


def _key_param_for_rule(rule) -> Optional[float]:
    """Return the most important parameter for a rule type, for matching.

    Returns None for rule types without a designated key parameter, or when
    the parameter is missing from ``rule.params``.
    """
    params = rule.params
    param_name = _KEY_PARAM_BY_RULE_TYPE.get(rule.rule_type)
    return None if param_name is None else params.get(param_name)


def _precision_bonus(text: str) -> float:
    """Return 0.10 if the text has at least two numbers other than "0"/"1"."""
    numbers = re.findall(r"-?\d+\.?\d*", text)
    meaningful = [n for n in numbers if n not in ("0", "1")]
    return 0.10 if len(meaningful) >= 2 else 0.0


def _calibration_score(expressed: Optional[float], actual: float) -> float:
    """Score based on |expressed_confidence - actual_accuracy|.

    0.20 for a perfect match, decreasing linearly to 0.0 at an error of 0.5
    or more. No confidence given -> 0.0.
    """
    if expressed is None:
        return 0.0
    error = abs(expressed - actual)
    return max(0.0, 0.20 * (1.0 - error / 0.5))


def _constraint_penalty(text: str) -> float:
    """Return -0.50 if the text matches any HARD_CONSTRAINTS pattern, else 0.0."""
    text_l = text.lower()
    if any(re.search(pattern, text_l) for pattern, _ in HARD_CONSTRAINTS):
        return -0.50
    return 0.0


def _build_feedback(result: RubricResult) -> str:
    """Compose a short, space-joined feedback string from the component scores."""
    lines = []
    if result.accuracy_score >= 0.75:
        lines.append("Strong accuracy -- you identified most causal relationships.")
    elif result.accuracy_score >= 0.40:
        lines.append("Partial accuracy -- some relationships identified correctly.")
    else:
        lines.append("Low accuracy -- try running more diverse experiments.")

    if result.precision_bonus > 0:
        lines.append("Good precision -- quantitative claims detected.")
    else:
        lines.append("Tip: include numerical values (slopes, thresholds) for precision bonus.")

    if result.efficiency_bonus > 0:
        lines.append("Efficient submission -- well-timed.")
    else:
        lines.append("Tip: submit earlier when confident to earn efficiency bonus.")

    if result.calibration_score >= 0.15:
        lines.append("Well-calibrated confidence.")
    elif result.calibration_score > 0:
        lines.append("Confidence calibration could improve.")

    if result.contradiction_penalty < 0:
        lines.append("WARNING: hypothesis contradicts known physical constraints.")
    return " ".join(lines)


def score_hypothesis(
    hypothesis_text: str,
    hypothesis_equations: Optional[list[str]],
    confidence: Optional[float],
    world: CausalWorld,
    budget_remaining: int,
    budget_total: int,
) -> RubricResult:
    """
    Score a submitted hypothesis against the ground truth world.

    The free-text hypothesis and any equations are concatenated and scored
    together. The efficiency bonus is 0.15 when >= 30% of the budget remains
    and accuracy >= 0.60, or 0.07 when >= 15% remains and accuracy >= 0.40.

    Returns a RubricResult with all component scores, feedback text, and the
    revealed ground truth.
    """
    full_text = hypothesis_text or ""
    if hypothesis_equations:
        full_text += " " + " ".join(hypothesis_equations)

    result = RubricResult()
    result.accuracy_score = _accuracy_score(full_text, world)
    result.precision_bonus = _precision_bonus(full_text)
    result.calibration_score = _calibration_score(confidence, result.accuracy_score)
    result.contradiction_penalty = _constraint_penalty(full_text)

    # max(..., 1) guards against a zero budget.
    ratio = budget_remaining / max(budget_total, 1)
    if ratio >= 0.30 and result.accuracy_score >= 0.60:
        result.efficiency_bonus = 0.15
    elif ratio >= 0.15 and result.accuracy_score >= 0.40:
        result.efficiency_bonus = 0.07

    result.ground_truth = world.ground_truth_summary()
    result.feedback = _build_feedback(result)
    return result