Spaces:
Sleeping
Sleeping
File size: 8,532 Bytes
e2ca55c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 | import json
import random
import uuid
from pathlib import Path
from enum import Enum
from server.models import (
Secret,
MindReadObservation,
StepResult,
SubmitResult,
RewardBreakdown,
TaskMeta,
)
from server.oracle import ask_oracle
from server.reward import compute_reward
# JSON file holding every secret record, resolved relative to this module's
# directory so it works regardless of the process's working directory.
SECRETS_PATH = Path(__file__).parent.joinpath("data", "secrets.json")
# Static metadata for every supported task, keyed by task id. Returned to
# clients by MindReadEnv.get_tasks(), so the descriptions are user-facing.
# ``max_steps`` is the question budget an Episode enforces for that task.
# Dashes are real em dashes (U+2014); earlier text carried the mis-decoded
# character "β" in their place.
TASK_META: dict[str, TaskMeta] = {
    "factual_easy": TaskMeta(
        id="factual_easy",
        description="Infer a hidden factual workplace secret (easy) — event, decision, or fact the Oracle knows but hasn't announced.",
        max_steps=8,
        reward_range=[0.0, 1.0],
        difficulty="easy",
        category="factual",
    ),
    "factual_hard": TaskMeta(
        id="factual_hard",
        description="Infer a precise numerical or date-bound secret. Requires specific inference, not just general direction.",
        max_steps=6,
        reward_range=[0.0, 1.0],
        difficulty="hard",
        category="factual",
    ),
    "belief_inference": TaskMeta(
        id="belief_inference",
        description="Infer what the Oracle believes about another person's internal state — emotions, plans, or intentions.",
        max_steps=8,
        reward_range=[0.0, 1.0],
        difficulty="medium",
        category="belief",
    ),
    "goal_inference": TaskMeta(
        id="goal_inference",
        description="Infer the Oracle's hidden personal or professional ambition they haven't disclosed to the team.",
        max_steps=8,
        reward_range=[0.0, 1.0],
        difficulty="medium",
        category="goal",
    ),
    "second_order": TaskMeta(
        id="second_order",
        description="Infer a recursive belief: what the Oracle believes someone else believes — second-order Theory of Mind.",
        max_steps=10,
        reward_range=[0.0, 1.0],
        difficulty="hard",
        category="second_order",
    ),
}
# Instructions shown to the detective, keyed by task id (same keys as
# TASK_META). Embedded in every observation built by Episode.to_observation().
# Dashes are real em dashes (U+2014), not the mis-decoded "β" character.
TASK_DESCRIPTION: dict[str, str] = {
    "factual_easy": (
        "Figure out what factual information the Oracle is privately aware of "
        "but has not publicly disclosed. Ask indirect, strategic questions."
    ),
    "factual_hard": (
        "Infer a specific fact (number, date, or precise detail) the Oracle knows privately. "
        "You need precision — vague guesses score low."
    ),
    "belief_inference": (
        "Determine what the Oracle believes about another person's state of mind, "
        "intentions, or emotional situation. The belief may not be stated but can be inferred."
    ),
    "goal_inference": (
        "Infer the Oracle's hidden personal ambition or undisclosed professional goal. "
        "They won't tell you directly but their answers will reveal it."
    ),
    "second_order": (
        "Determine what the Oracle believes that ANOTHER PERSON believes or thinks. "
        "This is second-order Theory of Mind — you must infer a belief about a belief."
    ),
}
class EpisodeState(str, Enum):
    """Lifecycle stage of an episode.

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase string value (e.g. ``EpisodeState.ACTIVE == "active"``).
    """

    # Not assigned by any code visible in this module.
    IDLE = "idle"
    # Questions may be asked and a hypothesis submitted.
    ACTIVE = "active"
    # A hypothesis was submitted and the reward recorded.
    SCORED = "scored"
class Episode:
    """Mutable state of a single detective-vs-Oracle episode.

    Holds the hidden secret, the question/answer transcript, the question
    counter, and, once scored, the reward and its breakdown.
    """

    def __init__(self, episode_id: str, secret: Secret, task_id: str):
        """Create a fresh episode in the ACTIVE state.

        Args:
            episode_id: Unique identifier of this episode.
            secret: The hidden secret the detective is trying to infer.
            task_id: Key into ``TASK_META``; selects the question budget.

        Raises:
            KeyError: If ``task_id`` is not a key of ``TASK_META``.
        """
        self.episode_id = episode_id
        self.secret = secret
        self.task_id = task_id
        self.state = EpisodeState.ACTIVE
        # Transcript turns; MindReadEnv.step appends dicts with "role" and
        # "content" keys.
        self.conversation_history: list[dict] = []
        # Number of questions asked so far.
        self.step = 0
        self.max_steps = TASK_META[task_id].max_steps
        # Remain None until the hypothesis has been scored.
        self.reward: float | None = None
        self.breakdown: RewardBreakdown | None = None

    def questions_remaining(self) -> int:
        """Return how many questions are left in the budget (never negative)."""
        return max(0, self.max_steps - self.step)

    def to_observation(self) -> MindReadObservation:
        """Build a snapshot of this episode as seen by the detective.

        Each transcript turn is copied, not only the list. A caller that mutates
        a turn in the returned observation therefore cannot alter this
        episode's history.
        """
        return MindReadObservation(
            episode_id=self.episode_id,
            task_id=self.task_id,
            step=self.step,
            max_steps=self.max_steps,
            context=self.secret.context,
            oracle_persona=self.secret.persona,
            conversation_history=[dict(turn) for turn in self.conversation_history],
            questions_remaining=self.questions_remaining(),
            task_description=TASK_DESCRIPTION[self.task_id],
        )
class MindReadEnv:
    """In-memory environment that runs Theory-of-Mind inference episodes.

    A detective calls :meth:`reset` to start an episode, :meth:`step` to ask
    the Oracle questions, and :meth:`submit` to score a final hypothesis.
    Every episode created stays in memory for the lifetime of the instance,
    including after it is scored.
    """

    def __init__(self):
        # Secrets grouped by their task_id.
        self._secrets: dict[str, list[Secret]] = {}
        # All episodes created by reset(), keyed by episode_id; never evicted.
        self._episodes: dict[str, Episode] = {}
        self._load_secrets()

    def _load_secrets(self):
        """Load SECRETS_PATH (a JSON list of Secret-field objects), grouped by task."""
        raw = json.loads(SECRETS_PATH.read_text(encoding="utf-8"))
        for item in raw:
            secret = Secret(**item)
            self._secrets.setdefault(secret.task_id, []).append(secret)

    def get_tasks(self) -> list[TaskMeta]:
        """Return metadata for every supported task."""
        return list(TASK_META.values())

    def reset(self, task_id: str, secret_id: str | None = None) -> MindReadObservation:
        """Start a new episode and return its initial observation.

        Args:
            task_id: One of the keys of ``TASK_META``.
            secret_id: Optional id of a specific secret in the task's pool.
                When falsy, a secret is chosen with ``random.choice``.

        Raises:
            ValueError: ``task_id`` is unknown, or ``secret_id`` is not in the
                task's pool.
            RuntimeError: No secrets are loaded for ``task_id``.
        """
        if task_id not in TASK_META:
            raise ValueError(f"Unknown task_id: {task_id}")
        pool = self._secrets.get(task_id, [])
        if not pool:
            raise RuntimeError(f"No secrets available for task: {task_id}")
        if secret_id:
            # First secret in the pool with a matching id.
            secret = next((s for s in pool if s.id == secret_id), None)
            if secret is None:
                raise ValueError(f"secret_id {secret_id!r} not found in task {task_id!r}")
        else:
            secret = random.choice(pool)
        episode_id = str(uuid.uuid4())
        episode = Episode(episode_id=episode_id, secret=secret, task_id=task_id)
        self._episodes[episode_id] = episode
        return episode.to_observation()

    def step(self, episode_id: str, question: str) -> StepResult:
        """Ask the Oracle one question and record the exchange.

        The returned reward is always 0.0; a reward is produced only by
        :meth:`submit`. If the question budget is already exhausted, the Oracle
        is not queried, and the result has ``done=True`` and an ``"error"``
        entry in ``info``.

        Raises:
            KeyError: ``episode_id`` is unknown.
            ValueError: The episode is not active (e.g. already scored).
        """
        ep = self._get_active(episode_id)
        if ep.questions_remaining() == 0:
            return StepResult(
                observation=ep.to_observation(),
                reward=0.0,
                done=True,
                info={"error": "No questions remaining. Please submit a hypothesis."},
            )
        # Query before mutating anything: if ask_oracle raises, the episode's
        # transcript and step counter are left unchanged.
        oracle_answer = ask_oracle(ep.secret, ep.conversation_history, question)
        ep.conversation_history.append({"role": "detective", "content": question})
        ep.conversation_history.append({"role": "oracle", "content": oracle_answer})
        ep.step += 1
        return StepResult(
            observation=ep.to_observation(),
            reward=0.0,
            done=ep.questions_remaining() == 0,
            info={"oracle_response": oracle_answer},
        )

    def submit(
        self,
        episode_id: str,
        hypothesis: str,
        category_prediction: str | None = None,
    ) -> SubmitResult:
        """Score a final hypothesis and move the episode to SCORED.

        Args:
            episode_id: Id of an active episode.
            hypothesis: The detective's guess at the secret.
            category_prediction: Optional guess at the secret's category,
                forwarded to ``compute_reward``.

        Returns:
            A SubmitResult with the reward, its breakdown, and the revealed
            secret content.

        Raises:
            KeyError: ``episode_id`` is unknown.
            ValueError: The episode is not active (e.g. already scored).
        """
        ep = self._get_active(episode_id)
        result = compute_reward(
            hypothesis=hypothesis,
            true_secret=ep.secret.content,
            n_questions_used=ep.step,
            max_questions=ep.max_steps,
            category_predicted=category_prediction,
            category_true=ep.secret.category,
            hint_keywords=ep.secret.hint_keywords,
        )
        components = result["components"]
        breakdown = RewardBreakdown(
            reward=result["reward"],
            semantic_similarity=components["semantic"],
            efficiency_bonus=components["efficiency"],
            category_bonus=components["category_bonus"],
            keyword_bonus=components["keyword_bonus"],
            questions_used=ep.step,
            hypothesis=hypothesis,
        )
        ep.reward = result["reward"]
        ep.breakdown = breakdown
        ep.state = EpisodeState.SCORED
        return SubmitResult(
            reward=result["reward"],
            breakdown=breakdown,
            true_secret=ep.secret.content,
            episode_id=episode_id,
            done=True,
        )

    def get_state(self, episode_id: str) -> MindReadObservation:
        """Return the current observation of any episode, active or scored.

        Raises:
            KeyError: ``episode_id`` is unknown.
        """
        return self._get_episode(episode_id).to_observation()

    def add_secret(self, secret: Secret):
        """Add a secret to its task's pool.

        The secret can be drawn by :meth:`reset` only if its ``task_id`` is a
        key of ``TASK_META``.
        """
        self._secrets.setdefault(secret.task_id, []).append(secret)

    def _get_episode(self, episode_id: str) -> Episode:
        """Look up an episode by id, raising KeyError if it does not exist."""
        try:
            return self._episodes[episode_id]
        except KeyError:
            raise KeyError(f"Episode {episode_id!r} not found") from None

    def _get_active(self, episode_id: str) -> Episode:
        """Look up an episode and require that it is still ACTIVE.

        Raises:
            KeyError: ``episode_id`` is unknown.
            ValueError: The episode exists but is not ACTIVE.
        """
        ep = self._get_episode(episode_id)
        if ep.state != EpisodeState.ACTIVE:
            raise ValueError(f"Episode {episode_id!r} is in state {ep.state.value}, not active")
        return ep
|