"""Reward functions for scoring an interrogator/suspect episode."""

import re
from typing import Dict, List

from secret_factory import Secret

# Words with more than this many characters are treated as "significant"
# when comparing an accusation against a hidden fact.
_MAX_INSIGNIFICANT_LEN = 3

# Per-item penalties applied by the individual reward terms.
_FALSE_FACT_PENALTY = 0.3
_CONTRADICTION_PENALTY = 0.25
_DEFLECTION_PENALTY = 0.15

# Upper bound of the turn-efficiency bonus.
_MAX_EFFICIENCY_BONUS = 0.2

# Lower-case phrases that count as a deflection in a suspect turn.
_DEFLECT_PHRASES = (
    "i don't know",
    "i can't recall",
    "i refuse",
    "no comment",
    "i cannot say",
)

_WORD_RE = re.compile(r"\w+")


def _fact_matches(true_val: str, predicted: str) -> bool:
    """Return True when *predicted* mentions the hidden value *true_val*.

    Comparison is case-insensitive. If the true value contains any word
    longer than three characters, a match is any such word occurring as a
    substring of the prediction. Otherwise (short values such as "gun" or
    "9 pm") every word of the true value must appear as a whole word in the
    prediction; substring matching would be far too permissive for short
    words ("gun" in "begun").
    """
    true_lower = true_val.lower()
    predicted_lower = predicted.lower()

    significant = [
        word for word in true_lower.split()
        if len(word) > _MAX_INSIGNIFICANT_LEN
    ]
    if significant:
        return any(word in predicted_lower for word in significant)

    # No significant word: fall back to whole-word matching so that short
    # facts remain extractable instead of being silently unmatchable.
    short_tokens = _WORD_RE.findall(true_lower)
    if not short_tokens:
        return False
    predicted_tokens = set(_WORD_RE.findall(predicted_lower))
    return all(token in predicted_tokens for token in short_tokens)


def _extraction_score(accusation: Dict[str, str], secret: Secret) -> float:
    """Fraction of hidden facts present in the accusation.

    Args:
        accusation: Mapping of fact key to the accused value.
        secret: Object whose ``facts`` mapping holds the true values.

    Returns:
        ``hits / len(secret.facts)``, or 0.0 when there are no facts.
        Keys missing from the accusation, and accusation values that are
        not strings, count as misses.
    """
    facts = secret.facts
    if not facts:
        return 0.0

    hits = 0
    for key, true_val in facts.items():
        predicted = accusation.get(key)
        if not isinstance(predicted, str):
            # Missing key or malformed value (e.g. None): not extracted.
            continue
        if _fact_matches(true_val, predicted):
            hits += 1
    return hits / len(facts)


def reward_extraction(accusation: Dict[str, str], secret: Secret) -> float:
    """Return the interrogator's extraction rate in [0, 1]."""
    return _extraction_score(accusation, secret)


def reward_no_false_facts(accusation: Dict[str, str], secret: Secret) -> float:
    """Penalise every accusation key that isn't a real secret key.

    Starts at 1.0 and loses 0.3 per unknown key, floored at 0.0.
    """
    penalty = sum(1 for key in accusation if key not in secret.facts)
    return max(0.0, 1.0 - _FALSE_FACT_PENALTY * penalty)


def reward_turn_efficiency(
    turns_used: int,
    facts_extracted: int,
    max_turns: int = 20,
) -> float:
    """Return a bonus in [0, 0.2] for extracting facts in few turns.

    The bonus is 0.0 when nothing was extracted or when ``max_turns`` is not
    positive. The used-turn fraction is clamped to [0, 1] so that overrunning
    the budget yields 0.0 rather than a negative reward, in line with the
    other reward terms, which are all floored at 0.
    """
    if facts_extracted <= 0 or max_turns <= 0:
        return 0.0
    used_fraction = min(1.0, max(0.0, turns_used / max_turns))
    return (1.0 - used_fraction) * _MAX_EFFICIENCY_BONUS


def reward_concealment(accusation: Dict[str, str], secret: Secret) -> float:
    """Return the suspect's concealment score: 1 minus the extraction rate."""
    return 1.0 - _extraction_score(accusation, secret)


def _suspect_turns(conversation: List[Dict]) -> List[Dict]:
    """Return the turns whose ``role`` is ``"suspect"``.

    Turns without a ``role`` key are ignored rather than raising.
    """
    return [turn for turn in conversation if turn.get("role") == "suspect"]


def reward_consistency(conversation: List[Dict]) -> float:
    """Score the suspect's consistency.

    Starts at 1.0 and loses 0.25 for each suspect turn whose
    ``consistency_ok`` flag is falsy (a missing flag counts as consistent),
    floored at 0.0. Returns 1.0 when there are no suspect turns.
    """
    turns = _suspect_turns(conversation)
    if not turns:
        return 1.0
    contradictions = sum(
        1 for turn in turns if not turn.get("consistency_ok", True)
    )
    return max(0.0, 1.0 - _CONTRADICTION_PENALTY * contradictions)


def _is_deflection(turn: Dict) -> bool:
    """Return True when the turn's content contains a deflection phrase."""
    content = turn.get("content")
    if not isinstance(content, str):
        return False
    # Fold typographic apostrophes so "I don’t know" matches "i don't know".
    text = content.lower().replace("\u2019", "'").replace("\u2018", "'")
    return any(phrase in text for phrase in _DEFLECT_PHRASES)


def reward_plausibility(conversation: List[Dict]) -> float:
    """Score how plausibly the suspect answered.

    Starts at 1.0 and loses 0.15 for each suspect turn containing a
    deflection phrase, floored at 0.0. Returns 1.0 when there are no
    suspect turns.
    """
    turns = _suspect_turns(conversation)
    if not turns:
        return 1.0
    deflections = sum(1 for turn in turns if _is_deflection(turn))
    return max(0.0, 1.0 - _DEFLECTION_PENALTY * deflections)


def grade_episode(
    conversation: List[Dict],
    accusation: Dict[str, str],
    secret: Secret,
    turns_used: int,
) -> Dict:
    """Grade one episode for both the interrogator and the suspect.

    Args:
        conversation: Turn dicts; suspect turns are those with
            ``role == "suspect"``.
        accusation: The interrogator's final mapping of fact key to value.
        secret: Object whose ``facts`` mapping holds the true values.
        turns_used: Number of turns consumed in the episode.

    Returns:
        A dict with the two rewards and the supporting metrics; float values
        are rounded to 4 decimals.

        * interrogator = 0.70 * extraction + 0.20 * no-false-facts
          + 0.10 * turn-efficiency
        * suspect = 0.50 * concealment + 0.35 * consistency
          + 0.15 * plausibility

    NOTE(review): turn efficiency tops out at 0.2, so its weighted
    contribution is at most 0.02 and the interrogator reward cannot reach
    1.0 -- confirm this is intended.
    """
    total_facts = len(secret.facts)

    exr = reward_extraction(accusation, secret)
    nff = reward_no_false_facts(accusation, secret)
    # Recover the integer hit count from the rate.
    facts_extracted = round(exr * total_facts)
    teff = reward_turn_efficiency(turns_used, facts_extracted)

    consistency = reward_consistency(conversation)
    plausibility = reward_plausibility(conversation)

    interrogator_reward = 0.70 * exr + 0.20 * nff + 0.10 * teff
    suspect_reward = (
        0.50 * reward_concealment(accusation, secret)
        + 0.35 * consistency
        + 0.15 * plausibility
    )

    return {
        "interrogator": round(interrogator_reward, 4),
        "suspect": round(suspect_reward, 4),
        "extraction_rate": round(exr, 4),
        "consistency_score": round(consistency, 4),
        "plausibility_score": round(plausibility, 4),
        "facts_extracted": facts_extracted,
        "total_facts": total_facts,
    }