Coding Ninja commited on
Commit
1f2ca34
·
1 Parent(s): 22170b0

Push 3: Curriculum controller, hidden-state pipeline, phase detector, trial judge, and full EpisodeManager wiring

Browse files
models.py CHANGED
@@ -100,6 +100,7 @@ class TrialLatentState(BaseModel):
100
  protocol_submitted: bool
101
  interim_complete: bool
102
  trial_complete: bool
 
103
  # Episode tracking (used by rule engine and phase detector)
104
  episode_phase: str
105
  action_history: list[str]
 
100
  protocol_submitted: bool
101
  interim_complete: bool
102
  trial_complete: bool
103
+ adverse_events: int # cumulative count of recorded adverse events
104
  # Episode tracking (used by rule engine and phase detector)
105
  episode_phase: str
106
  action_history: list[str]
pyproject.toml CHANGED
@@ -32,6 +32,9 @@ target-version = "py311"
32
  select = ["E", "F", "W", "I"]
33
  ignore = []
34
 
 
 
 
35
  [tool.pytest.ini_options]
36
  testpaths = ["tests"]
37
  addopts = "-v"
 
32
  select = ["E", "F", "W", "I"]
33
  ignore = []
34
 
35
+ [tool.ruff.lint.per-file-ignores]
36
+ "tests/**" = ["E501"]
37
+
38
  [tool.pytest.ini_options]
39
  testpaths = ["tests"]
40
  addopts = "-v"
server/curriculum/__init__.py CHANGED
@@ -1,7 +1,35 @@
1
  """
2
  curriculum — Curriculum controller and scenario registry.
3
 
4
- Provides advance_curriculum, select_scenario, and the four initial
5
- ScenarioConfig instances (solid_tumor_chemo, autoimmune_biologic,
6
  cns_depression, rare_disease_orphan).
7
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  curriculum — Curriculum controller and scenario registry.
3
 
4
+ Provides advance_curriculum, select_scenario, EpisodeMetrics, and the four
5
+ initial ScenarioConfig instances (solid_tumor_chemo, autoimmune_biologic,
6
  cns_depression, rare_disease_orphan).
7
  """
8
+
9
+ from server.curriculum.controller import (
10
+ EpisodeMetrics,
11
+ advance_curriculum,
12
+ select_scenario,
13
+ )
14
+ from server.curriculum.scenarios import (
15
+ AUTOIMMUNE_BIOLOGIC,
16
+ CNS_DEPRESSION,
17
+ RARE_DISEASE_ORPHAN,
18
+ SCENARIO_LIST,
19
+ SCENARIOS,
20
+ SOLID_TUMOR_CHEMO,
21
+ WARMUP,
22
+ )
23
+
24
+ __all__ = [
25
+ "EpisodeMetrics",
26
+ "advance_curriculum",
27
+ "select_scenario",
28
+ "WARMUP",
29
+ "SOLID_TUMOR_CHEMO",
30
+ "AUTOIMMUNE_BIOLOGIC",
31
+ "CNS_DEPRESSION",
32
+ "RARE_DISEASE_ORPHAN",
33
+ "SCENARIOS",
34
+ "SCENARIO_LIST",
35
+ ]
server/curriculum/controller.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Curriculum controller for the Clinical Trial Designer environment.
3
+
4
+ Exposes:
5
+ - advance_curriculum(tier, metrics) -> int
6
+ - select_scenario(tier, rng) -> ScenarioConfig
7
+
8
+ 5-tier mastery logic:
9
+ Tier 0: warmup
10
+ Tier 1: beginner
11
+ Tier 2: intermediate
12
+ Tier 3: advanced
13
+ Tier 4: expert
14
+
15
+ Graduation rules:
16
+ - 70% rolling success rate over recent episodes → advance one tier
17
+ - 90% success rate after at least 3 episodes → fast-track (skip one tier)
18
+ - Max tier is 4 (expert); cannot advance beyond.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ from dataclasses import dataclass, field
24
+ from typing import Sequence
25
+
26
+ import numpy as np
27
+
28
+ from models import ScenarioConfig
29
+ from server.curriculum.scenarios import (
30
+ AUTOIMMUNE_BIOLOGIC,
31
+ CNS_DEPRESSION,
32
+ RARE_DISEASE_ORPHAN,
33
+ SOLID_TUMOR_CHEMO,
34
+ WARMUP,
35
+ )
36
+
37
+ # ── Constants ────────────────────────────────────────────────────────────────
38
+
39
+ MIN_TIER: int = 0
40
+ MAX_TIER: int = 4
41
+
42
+ MASTERY_THRESHOLD: float = 0.70 # 70% rolling success → graduate
43
+ FAST_TRACK_THRESHOLD: float = 0.90 # 90% success after ≥3 episodes → skip tier
44
+ FAST_TRACK_MIN_EPISODES: int = 3
45
+
46
+ # Rolling window size for success-rate calculation
47
+ ROLLING_WINDOW: int = 10
48
+
49
+ # Tier → ScenarioConfig mapping (one canonical scenario per tier)
50
+ _TIER_SCENARIO: dict[int, ScenarioConfig] = {
51
+ 0: WARMUP,
52
+ 1: SOLID_TUMOR_CHEMO,
53
+ 2: AUTOIMMUNE_BIOLOGIC,
54
+ 3: CNS_DEPRESSION,
55
+ 4: RARE_DISEASE_ORPHAN,
56
+ }
57
+
58
+ TIER_NAMES: dict[int, str] = {
59
+ 0: "warmup",
60
+ 1: "beginner",
61
+ 2: "intermediate",
62
+ 3: "advanced",
63
+ 4: "expert",
64
+ }
65
+
66
+
67
+ # ── EpisodeMetrics ────────────────────────────────────────────────────────────
68
+
69
+
70
+ @dataclass
71
+ class EpisodeMetrics:
72
+ """Performance metrics for a completed episode.
73
+
74
+ Attributes:
75
+ success: Whether the episode ended in a successful trial outcome.
76
+ episode_history: Rolling list of recent success booleans (most recent
77
+ episode appended last). The controller uses the last
78
+ ``ROLLING_WINDOW`` entries to compute the rolling success rate.
79
+ Callers should append the current episode's ``success`` value
80
+ *before* passing this object to ``advance_curriculum``.
81
+ """
82
+
83
+ success: bool
84
+ episode_history: list[bool] = field(default_factory=list)
85
+
86
+
87
+ # ── Public API ────────────────────────────────────────────────────────────────
88
+
89
+
90
+ def advance_curriculum(tier: int, metrics: EpisodeMetrics) -> int:
91
+ """Return the updated curriculum tier after evaluating episode metrics.
92
+
93
+ Args:
94
+ tier: Current curriculum tier (0–4).
95
+ metrics: Performance metrics for the just-completed episode.
96
+ ``metrics.episode_history`` must already include the current
97
+ episode's success value as its last element.
98
+
99
+ Returns:
100
+ The new curriculum tier. May be the same tier (not yet mastered),
101
+ ``tier + 1`` (normal graduation), or ``tier + 2`` (fast-track skip).
102
+ Never exceeds ``MAX_TIER``.
103
+ """
104
+ if tier >= MAX_TIER:
105
+ return MAX_TIER
106
+
107
+ history: Sequence[bool] = metrics.episode_history
108
+ n_episodes = len(history)
109
+
110
+ if n_episodes == 0:
111
+ return tier
112
+
113
+ # Use the most recent ROLLING_WINDOW episodes for the rolling rate
114
+ window = list(history[-ROLLING_WINDOW:])
115
+ rolling_rate = sum(window) / len(window)
116
+
117
+ # Fast-track: 90%+ success after at least 3 episodes → skip one tier
118
+ if n_episodes >= FAST_TRACK_MIN_EPISODES and rolling_rate >= FAST_TRACK_THRESHOLD:
119
+ new_tier = min(tier + 2, MAX_TIER)
120
+ return new_tier
121
+
122
+ # Normal graduation: 70%+ rolling success → advance one tier
123
+ if rolling_rate >= MASTERY_THRESHOLD:
124
+ return min(tier + 1, MAX_TIER)
125
+
126
+ return tier
127
+
128
+
129
+ def select_scenario(tier: int, rng: np.random.Generator) -> ScenarioConfig:
130
+ """Select a ScenarioConfig appropriate for the given curriculum tier.
131
+
132
+ At tier 0 (warmup) the solid_tumor_chemo scenario is returned with an
133
+ inflated effect size (already encoded in the WARMUP ScenarioConfig).
134
+
135
+ Args:
136
+ tier: Current curriculum tier (0–4). Values outside [0, 4] are
137
+ clamped to the valid range.
138
+ rng: A seeded ``numpy.random.Generator`` used for any stochastic
139
+ selection. Currently each tier maps to exactly one scenario, so
140
+ ``rng`` is accepted for API consistency and future extensibility
141
+ (e.g. sampling from a pool of scenarios at the same tier).
142
+
143
+ Returns:
144
+ The ``ScenarioConfig`` for the given tier.
145
+ """
146
+ clamped_tier = max(MIN_TIER, min(tier, MAX_TIER))
147
+ return _TIER_SCENARIO[clamped_tier]
server/curriculum/scenarios.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Scenario registry for the curriculum controller.
3
+
4
+ Defines ScenarioConfig instances for all four scenario IDs plus a tier-0 warmup
5
+ variant of solid_tumor_chemo with an inflated effect size.
6
+ """
7
+
8
+ from models import ScenarioConfig
9
+
10
+ # Tier 0 — warmup (solid_tumor_chemo with inflated effect size, easier)
11
+ WARMUP = ScenarioConfig(
12
+ scenario_id="solid_tumor_chemo_warmup",
13
+ curriculum_tier=0,
14
+ disease_area="oncology",
15
+ effect_size_range=(0.55, 0.85), # inflated vs tier-1 (0.25–0.55)
16
+ side_effect_rate_range=(0.10, 0.25),
17
+ placebo_response_range=(0.05, 0.15),
18
+ dropout_rate_range=(0.05, 0.10),
19
+ budget_usd=8_000_000.0,
20
+ time_budget_days=365,
21
+ min_sample_size=60,
22
+ description=(
23
+ "Warmup scenario: EGFR+ solid-tumour chemotherapy with an inflated "
24
+ "effect size to help the agent learn basic trial-design mechanics."
25
+ ),
26
+ )
27
+
28
+ # Tier 1 — EGFR+ subgroup enrichment
29
+ SOLID_TUMOR_CHEMO = ScenarioConfig(
30
+ scenario_id="solid_tumor_chemo",
31
+ curriculum_tier=1,
32
+ disease_area="oncology",
33
+ effect_size_range=(0.25, 0.55),
34
+ side_effect_rate_range=(0.15, 0.35),
35
+ placebo_response_range=(0.05, 0.15),
36
+ dropout_rate_range=(0.05, 0.15),
37
+ budget_usd=10_000_000.0,
38
+ time_budget_days=540,
39
+ min_sample_size=80,
40
+ description=(
41
+ "EGFR+ solid-tumour chemotherapy. Agent must identify the EGFR+ "
42
+ "biomarker subgroup to unlock the true effect size."
43
+ ),
44
+ )
45
+
46
+ # Tier 2 — U-shaped dose-response
47
+ AUTOIMMUNE_BIOLOGIC = ScenarioConfig(
48
+ scenario_id="autoimmune_biologic",
49
+ curriculum_tier=2,
50
+ disease_area="immunology",
51
+ effect_size_range=(0.20, 0.45),
52
+ side_effect_rate_range=(0.10, 0.30),
53
+ placebo_response_range=(0.15, 0.30),
54
+ dropout_rate_range=(0.08, 0.18),
55
+ budget_usd=15_000_000.0,
56
+ time_budget_days=720,
57
+ min_sample_size=120,
58
+ description=(
59
+ "Autoimmune biologic with a U-shaped dose-response curve. "
60
+ "Agent must run dose-escalation to find the optimal dose window."
61
+ ),
62
+ )
63
+
64
+ # Tier 3 — high placebo response
65
+ CNS_DEPRESSION = ScenarioConfig(
66
+ scenario_id="cns_depression",
67
+ curriculum_tier=3,
68
+ disease_area="psychiatry",
69
+ effect_size_range=(0.15, 0.35),
70
+ side_effect_rate_range=(0.10, 0.25),
71
+ placebo_response_range=(0.35, 0.55), # high placebo response
72
+ dropout_rate_range=(0.10, 0.25),
73
+ budget_usd=20_000_000.0,
74
+ time_budget_days=900,
75
+ min_sample_size=200,
76
+ description=(
77
+ "CNS depression trial with a high placebo-response rate. "
78
+ "Agent must power the study to detect a small drug-placebo delta."
79
+ ),
80
+ )
81
+
82
+ # Tier 4 — rare disease / tiny n
83
+ RARE_DISEASE_ORPHAN = ScenarioConfig(
84
+ scenario_id="rare_disease_orphan",
85
+ curriculum_tier=4,
86
+ disease_area="rare_disease",
87
+ effect_size_range=(0.40, 0.80), # larger effect needed to compensate tiny n
88
+ side_effect_rate_range=(0.05, 0.20),
89
+ placebo_response_range=(0.05, 0.15),
90
+ dropout_rate_range=(0.05, 0.15),
91
+ budget_usd=5_000_000.0,
92
+ time_budget_days=1080,
93
+ min_sample_size=10, # tiny n — orphan disease
94
+ description=(
95
+ "Rare-disease orphan drug trial with a very small patient population. "
96
+ "Agent must justify statistical validity under FDA orphan-drug rules."
97
+ ),
98
+ )
99
+
100
+ # Registry — keyed by scenario_id for O(1) lookup
101
+ SCENARIOS: dict[str, ScenarioConfig] = {
102
+ WARMUP.scenario_id: WARMUP,
103
+ SOLID_TUMOR_CHEMO.scenario_id: SOLID_TUMOR_CHEMO,
104
+ AUTOIMMUNE_BIOLOGIC.scenario_id: AUTOIMMUNE_BIOLOGIC,
105
+ CNS_DEPRESSION.scenario_id: CNS_DEPRESSION,
106
+ RARE_DISEASE_ORPHAN.scenario_id: RARE_DISEASE_ORPHAN,
107
+ }
108
+
109
+ # Convenience list ordered by tier
110
+ SCENARIO_LIST: list[ScenarioConfig] = [
111
+ WARMUP,
112
+ SOLID_TUMOR_CHEMO,
113
+ AUTOIMMUNE_BIOLOGIC,
114
+ CNS_DEPRESSION,
115
+ RARE_DISEASE_ORPHAN,
116
+ ]
server/episode_manager.py CHANGED
@@ -10,38 +10,45 @@ from __future__ import annotations
10
 
11
  import random
12
  import uuid
 
 
 
13
 
14
  from models import (
 
15
  RewardBreakdown,
16
  ScenarioConfig,
17
  TrialAction,
18
  TrialLatentState,
19
  TrialObservation,
20
- TrialResult,
21
  TrialState,
22
  )
 
 
23
  from server.logger import EpisodeLogger
24
  from server.noise_model import NoiseModel
 
 
25
  from server.rules.fda_rules import check_fda_compliance
26
-
27
- # Default scenario used until CurriculumController is fully wired (Push 3).
28
- _DEFAULT_SCENARIO = ScenarioConfig(
29
- scenario_id="solid_tumor_chemo",
30
- curriculum_tier=0,
31
- disease_area="NSCLC",
32
- effect_size_range=(0.3, 0.7),
33
- side_effect_rate_range=(0.05, 0.20),
34
- placebo_response_range=(0.10, 0.25),
35
- dropout_rate_range=(0.05, 0.15),
36
- budget_usd=1_000_000.0,
37
- time_budget_days=365,
38
- min_sample_size=100,
39
- description="Solid tumor chemotherapy — find EGFR+ subgroup",
40
- )
41
 
42
  _MAX_STEPS = 100
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  class EpisodeManager:
46
  """Orchestrates the reset/step lifecycle for a single clinical trial episode.
47
 
@@ -58,22 +65,35 @@ class EpisodeManager:
58
  self._episode_id: str = ""
59
  self._difficulty: float = 0.0
60
  self._scenario: ScenarioConfig | None = None
 
 
 
 
 
61
 
62
  # ------------------------------------------------------------------
63
  # Public API
64
  # ------------------------------------------------------------------
65
 
66
  def reset(self, seed: int | None = None) -> TrialObservation:
67
- """Initialize a new episode and return the initial TrialObservation."""
 
 
 
 
68
  resolved_seed = seed if seed is not None else random.randint(0, 2**31 - 1)
69
  self._episode_id = str(uuid.uuid4())
70
 
71
- # Step 1: Select scenario (stub CurriculumController wired in Push 3)
72
- scenario = _DEFAULT_SCENARIO
 
 
73
  self._scenario = scenario
74
 
75
- # Step 2: Apply domain randomization via NoiseModel (req 9.1, 9.2)
 
76
  noise_model = NoiseModel(seed=resolved_seed)
 
77
  randomized = noise_model.randomize(scenario)
78
 
79
  # Sample concrete hidden values from randomized ranges
@@ -109,6 +129,7 @@ class EpisodeManager:
109
  protocol_submitted=False,
110
  interim_complete=False,
111
  trial_complete=False,
 
112
  episode_phase="literature_review",
113
  action_history=[],
114
  seed=resolved_seed,
@@ -117,26 +138,55 @@ class EpisodeManager:
117
  # Step 4: Build lightweight TrialState for training loop
118
  self._state = self._state_from_latent(self._latent, randomized)
119
 
 
120
  self._clear_cache()
 
121
 
122
- # Step 5: Fresh logger and reward accumulator
123
  self._logger = EpisodeLogger(
124
- curriculum_tier=randomized.curriculum_tier
 
125
  )
126
  self._total_reward = 0.0
127
- self._difficulty = 0.0
128
-
129
- return self._observation_from_latent(self._latent, randomized)
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  def step(
132
  self, action: TrialAction
133
  ) -> tuple[TrialObservation, RewardBreakdown, bool, dict]:
134
- """Advance the episode by one step."""
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  if self._latent is None or self._scenario is None:
136
  raise RuntimeError("No active episode. Call reset() before step().")
137
 
138
  try:
139
- # Check FDA compliance against latent state (req 10.1, 10.4)
140
  compliance = check_fda_compliance(action, self._latent)
141
 
142
  if not compliance.valid:
@@ -146,73 +196,116 @@ class EpisodeManager:
146
  r_info_gain=0.0,
147
  r_efficiency=0.0,
148
  r_novelty=0.0,
149
- r_penalty=0.0,
150
  r_terminal_success=0.0,
151
  r_terminal_calibration=0.0,
152
  )
153
  done = False
 
154
  info: dict = {
155
- "step_index": len(self._latent.action_history),
156
  "action_valid": False,
157
  "violations": compliance.violations,
158
  }
159
- obs = self._observation_from_latent(
160
- self._latent,
161
- self._scenario,
 
 
 
 
 
 
162
  rule_violations=compliance.violations,
 
 
 
 
163
  )
 
164
  if self._logger is not None:
165
- self._logger.log_step(
166
- len(self._latent.action_history), action, obs, reward, done
167
- )
168
  return obs, reward, done, info
169
 
170
- # Valid action: advance latent state
171
- self._latent = self._latent.model_copy(
172
- update={
173
- "action_history": (
174
- self._latent.action_history + [action.action_type.value]
175
- ),
176
- }
177
  )
 
178
 
179
- # Stub TrialResult
180
- _result = TrialResult(
181
- p_value=0.05,
182
- success=False,
183
- power=0.8,
184
- adverse_event_rate=0.1,
185
- confidence_interval=(0.0, 1.0),
186
- failure_reason=None,
187
- )
188
 
189
- # Stub reward
190
- reward = RewardBreakdown(
191
- r_validity=0.0,
192
- r_ordering=0.0,
193
- r_info_gain=0.0,
194
- r_efficiency=0.0,
195
- r_novelty=0.0,
196
- r_penalty=0.0,
197
- r_terminal_success=0.0,
198
- r_terminal_calibration=0.0,
199
  )
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  step_idx = len(self._latent.action_history)
202
  done = step_idx >= _MAX_STEPS or self._latent.trial_complete
203
- info = {"step_index": step_idx, "action_valid": True}
204
 
205
- obs = self._observation_from_latent(self._latent, self._scenario)
206
-
207
- # Update training-loop TrialState
208
- self._state = self._state_from_latent(self._latent, self._scenario)
 
 
 
 
 
 
 
 
 
 
209
 
210
- # Accumulate reward and log step (req 7.1)
211
- self._total_reward += sum(reward.model_dump().values())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  if self._logger is not None:
213
  self._logger.log_step(step_idx, action, obs, reward, done)
 
 
 
 
214
 
215
- # Log summary on episode end (req 7.2)
216
  if done and self._logger is not None:
217
  self._logger.log_summary(
218
  scenario_id=self._scenario.scenario_id,
@@ -223,11 +316,22 @@ class EpisodeManager:
223
  ),
224
  )
225
 
 
 
 
 
 
 
 
 
 
 
 
226
  return obs, reward, done, info
227
 
228
  except RuntimeError:
229
  raise
230
- except Exception as exc: # req 10.4: no unhandled exceptions
231
  reward = RewardBreakdown(
232
  r_validity=-1.0,
233
  r_ordering=0.0,
@@ -244,10 +348,50 @@ class EpisodeManager:
244
  "action_valid": False,
245
  "violations": [f"Internal error: {exc}"],
246
  }
247
- obs = self._observation_from_latent(
248
- self._latent,
249
- self._scenario,
250
- rule_violations=[f"Internal error: {exc}"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  )
252
  return obs, reward, False, info
253
 
@@ -276,9 +420,20 @@ class EpisodeManager:
276
  """Build the lightweight TrialState from latent state."""
277
  step_count = len(latent.action_history)
278
  unique_actions = len(set(latent.action_history))
279
- action_diversity = (
280
- unique_actions / step_count if step_count > 0 else 0.0
281
- )
 
 
 
 
 
 
 
 
 
 
 
282
  return TrialState(
283
  episode_id=self._episode_id,
284
  step_count=step_count,
@@ -287,37 +442,6 @@ class EpisodeManager:
287
  curriculum_tier=str(scenario.curriculum_tier),
288
  curriculum_stats={},
289
  action_diversity=action_diversity,
290
- phase_compliance_rate=0.0, # wired in Push 3 with PhaseDetector
291
  is_resolved=latent.trial_complete,
292
  )
293
-
294
- def _observation_from_latent(
295
- self,
296
- latent: TrialLatentState,
297
- scenario: ScenarioConfig,
298
- rule_violations: list[str] | None = None,
299
- ) -> TrialObservation:
300
- """Build a TrialObservation from latent state — noisy, agent-facing."""
301
- return TrialObservation(
302
- scenario_description=scenario.description,
303
- phase_data={
304
- "episode_phase": latent.episode_phase,
305
- "observed_effect_estimate": None,
306
- "observed_side_effect_rate": None,
307
- "phase_i_complete": latent.phase_i_complete,
308
- "interim_complete": latent.interim_complete,
309
- "protocol_submitted": latent.protocol_submitted,
310
- },
311
- resource_status={
312
- "budget_remaining": latent.budget_remaining,
313
- "time_remaining_days": latent.time_remaining_days,
314
- "patients_enrolled": latent.patients_enrolled,
315
- },
316
- rule_violations=rule_violations or [],
317
- available_actions=[], # wired in Push 3 with TransitionEngine
318
- steps_taken=len(latent.action_history),
319
- max_steps=_MAX_STEPS,
320
- hint="", # populated by TrialJudge at junior difficulty (Push 3)
321
- done=latent.trial_complete,
322
- reward=0.0, # filled in by step() after reward computation
323
- )
 
10
 
11
  import random
12
  import uuid
13
+ from datetime import datetime, timezone
14
+
15
+ import numpy as np
16
 
17
  from models import (
18
+ EpisodeTranscript,
19
  RewardBreakdown,
20
  ScenarioConfig,
21
  TrialAction,
22
  TrialLatentState,
23
  TrialObservation,
 
24
  TrialState,
25
  )
26
+ from server.curriculum.controller import select_scenario
27
+ from server.judge import TrialJudge
28
  from server.logger import EpisodeLogger
29
  from server.noise_model import NoiseModel
30
+ from server.phase_detector import detect_phase
31
+ from server.reward.reward_computer import compute_reward
32
  from server.rules.fda_rules import check_fda_compliance
33
+ from server.simulator.output_generator import OutputGenerator
34
+ from server.simulator.transition_engine import TransitionEngine
35
+ from server.simulator.trial_simulator import simulate_trial
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  _MAX_STEPS = 100
38
 
39
 
40
+ def _phase_order_correct_at(phase: str, prior_history: list[str]) -> bool:
41
+ """Return True if `phase` is a valid next phase given `prior_history`."""
42
+ from server.phase_detector import PHASE_ORDER
43
+
44
+ if not prior_history:
45
+ return True
46
+ last = prior_history[-1]
47
+ last_idx = PHASE_ORDER.index(last) if last in PHASE_ORDER else 0
48
+ current_idx = PHASE_ORDER.index(phase) if phase in PHASE_ORDER else 0
49
+ return current_idx >= last_idx and (current_idx - last_idx) <= 1
50
+
51
+
52
  class EpisodeManager:
53
  """Orchestrates the reset/step lifecycle for a single clinical trial episode.
54
 
 
65
  self._episode_id: str = ""
66
  self._difficulty: float = 0.0
67
  self._scenario: ScenarioConfig | None = None
68
+ self._phase_history: list[str] = []
69
+ self._noise_model: NoiseModel | None = None
70
+ self._curriculum_tier: int = 0
71
+ self._transition_engine: TransitionEngine = TransitionEngine()
72
+ self._judge: TrialJudge = TrialJudge()
73
 
74
  # ------------------------------------------------------------------
75
  # Public API
76
  # ------------------------------------------------------------------
77
 
78
  def reset(self, seed: int | None = None) -> TrialObservation:
79
+ """Initialize a new episode and return the initial TrialObservation.
80
+
81
+ Seeded resets are reproducible: same seed → same scenario selection
82
+ and initial TrialLatentState (Req 8.5, 9.4).
83
+ """
84
  resolved_seed = seed if seed is not None else random.randint(0, 2**31 - 1)
85
  self._episode_id = str(uuid.uuid4())
86
 
87
+ # Step 1: Select scenario via CurriculumController (Req 8.3, 8.5)
88
+ # Use a seeded RNG so scenario selection is reproducible for same seed.
89
+ scenario_rng = np.random.default_rng(resolved_seed)
90
+ scenario = select_scenario(self._curriculum_tier, scenario_rng)
91
  self._scenario = scenario
92
 
93
+ # Step 2: Apply domain randomization via NoiseModel (Req 9.1, 9.2)
94
+ # NoiseModel is seeded so same seed → same randomized config.
95
  noise_model = NoiseModel(seed=resolved_seed)
96
+ self._noise_model = noise_model
97
  randomized = noise_model.randomize(scenario)
98
 
99
  # Sample concrete hidden values from randomized ranges
 
129
  protocol_submitted=False,
130
  interim_complete=False,
131
  trial_complete=False,
132
+ adverse_events=0,
133
  episode_phase="literature_review",
134
  action_history=[],
135
  seed=resolved_seed,
 
138
  # Step 4: Build lightweight TrialState for training loop
139
  self._state = self._state_from_latent(self._latent, randomized)
140
 
141
+ # Step 5: Clear power cache (Req 14.3)
142
  self._clear_cache()
143
+ self._phase_history = []
144
 
145
+ # Step 6: Fresh logger (episode_id matches this episode), reward accumulator
146
  self._logger = EpisodeLogger(
147
+ episode_id=self._episode_id,
148
+ curriculum_tier=randomized.curriculum_tier,
149
  )
150
  self._total_reward = 0.0
151
+ # Difficulty scales linearly with curriculum tier: tier 0 → 0.0, tier 4 → 1.0
152
+ self._difficulty = scenario.curriculum_tier / 4.0
153
+
154
+ # Step 7: Return initial TrialObservation via OutputGenerator
155
+ output_gen = OutputGenerator(noise_model)
156
+ return output_gen.generate(
157
+ latent=self._latent,
158
+ trial_state=self._state,
159
+ steps_taken=0,
160
+ max_steps=_MAX_STEPS,
161
+ rule_violations=[],
162
+ done=False,
163
+ reward=0.0,
164
+ scenario_description=scenario.description,
165
+ hint="",
166
+ )
167
 
168
  def step(
169
  self, action: TrialAction
170
  ) -> tuple[TrialObservation, RewardBreakdown, bool, dict]:
171
+ """Advance the episode by one step.
172
+
173
+ Full pipeline (Req 8.5, 9.4, 7.1):
174
+ 1. Validate active episode
175
+ 2. check_fda_compliance → ComplianceResult
176
+ 3. TransitionEngine.apply_transition() mutates TrialLatentState
177
+ 4. OutputGenerator.generate() produces noisy TrialObservation
178
+ 5. compute_reward() → RewardBreakdown
179
+ 6. PhaseDetector.detect_phase() classifies action
180
+ 7. TrialJudge.verify() for hint/feedback
181
+ 8. Check terminal condition
182
+ 9. Log full EpisodeTranscript to JSONL
183
+ 10. Return (obs, reward_breakdown, done, info)
184
+ """
185
  if self._latent is None or self._scenario is None:
186
  raise RuntimeError("No active episode. Call reset() before step().")
187
 
188
  try:
189
+ # Step 1: Check FDA compliance (read-only, does not mutate state)
190
  compliance = check_fda_compliance(action, self._latent)
191
 
192
  if not compliance.valid:
 
196
  r_info_gain=0.0,
197
  r_efficiency=0.0,
198
  r_novelty=0.0,
199
+ r_penalty=-0.5 * len(compliance.violations),
200
  r_terminal_success=0.0,
201
  r_terminal_calibration=0.0,
202
  )
203
  done = False
204
+ step_idx = len(self._latent.action_history)
205
  info: dict = {
206
+ "step_index": step_idx,
207
  "action_valid": False,
208
  "violations": compliance.violations,
209
  }
210
+ # Build observation without mutating latent
211
+ noise_model = self._noise_model or NoiseModel(seed=self._latent.seed)
212
+ output_gen = OutputGenerator(noise_model)
213
+ obs = output_gen.generate(
214
+ latent=self._latent,
215
+ trial_state=self._state
216
+ or self._state_from_latent(self._latent, self._scenario),
217
+ steps_taken=step_idx,
218
+ max_steps=_MAX_STEPS,
219
  rule_violations=compliance.violations,
220
+ done=False,
221
+ reward=reward.total,
222
+ scenario_description=self._scenario.description,
223
+ hint="",
224
  )
225
+ # Log invalid step
226
  if self._logger is not None:
227
+ self._logger.log_step(step_idx, action, obs, reward, done)
 
 
228
  return obs, reward, done, info
229
 
230
+ # Step 2: TransitionEngine mutates TrialLatentState
231
+ updated_latent = self._transition_engine.apply_transition(
232
+ self._latent, action
 
 
 
 
233
  )
234
+ self._latent = updated_latent
235
 
236
+ # Step 3: Detect phase and update phase history
237
+ phase_name, phase_order_correct = detect_phase(action, self._phase_history)
238
+ self._phase_history = self._phase_history + [phase_name]
 
 
 
 
 
 
239
 
240
+ # Step 4: Simulate trial result for reward computation
241
+ result = simulate_trial(self._latent, action)
242
+
243
+ # Step 5: Compute reward (all 8 components)
244
+ reward = compute_reward(
245
+ action=action,
246
+ latent=self._latent,
247
+ result=result,
248
+ phase_history=self._phase_history[:-1], # history before this step
 
249
  )
250
 
251
+ # Step 6: TrialJudge verification (hint + overconfidence penalty)
252
+ self._state = self._state_from_latent(self._latent, self._scenario)
253
+ judge_result = self._judge.verify(action, self._state, self._latent)
254
+ hint = judge_result.hint or ""
255
+
256
+ # Apply overconfidence penalty to r_penalty
257
+ if judge_result.overconfidence_penalty != 0.0:
258
+ reward = reward.model_copy(
259
+ update={
260
+ "r_penalty": (
261
+ reward.r_penalty + judge_result.overconfidence_penalty
262
+ )
263
+ }
264
+ )
265
+
266
+ # Step 7: Check terminal condition
267
  step_idx = len(self._latent.action_history)
268
  done = step_idx >= _MAX_STEPS or self._latent.trial_complete
 
269
 
270
+ # Step 8: Generate noisy observation via OutputGenerator
271
+ noise_model = self._noise_model or NoiseModel(seed=self._latent.seed)
272
+ output_gen = OutputGenerator(noise_model)
273
+ obs = output_gen.generate(
274
+ latent=self._latent,
275
+ trial_state=self._state,
276
+ steps_taken=step_idx,
277
+ max_steps=_MAX_STEPS,
278
+ rule_violations=[],
279
+ done=done,
280
+ reward=reward.total,
281
+ scenario_description=self._scenario.description,
282
+ hint=hint,
283
+ )
284
 
285
+ # Step 9: Accumulate total reward
286
+ self._total_reward += reward.total
287
+
288
+ # Step 10: Log full EpisodeTranscript record to JSONL (Req 7.1)
289
+ transcript = EpisodeTranscript(
290
+ episode_id=self._episode_id,
291
+ step=step_idx,
292
+ action=action,
293
+ observation=obs,
294
+ reward_breakdown=reward.model_dump(),
295
+ total_reward=reward.total,
296
+ phase_detected=phase_name,
297
+ phase_order_correct=phase_order_correct,
298
+ hidden_state_snapshot=self._latent,
299
+ timestamp=datetime.now(timezone.utc).isoformat(),
300
+ )
301
  if self._logger is not None:
302
  self._logger.log_step(step_idx, action, obs, reward, done)
303
+ # Also write the full EpisodeTranscript as a separate JSONL record
304
+ self._logger._append_jsonl(
305
+ {"type": "transcript", **transcript.model_dump(mode="json")}
306
+ )
307
 
308
+ # Log summary on episode end (Req 7.2)
309
  if done and self._logger is not None:
310
  self._logger.log_summary(
311
  scenario_id=self._scenario.scenario_id,
 
316
  ),
317
  )
318
 
319
+ info = {
320
+ "step_index": step_idx,
321
+ "action_valid": True,
322
+ "phase_detected": phase_name,
323
+ "phase_order_correct": phase_order_correct,
324
+ "judge_passed": judge_result.passed,
325
+ "judge_feedback": judge_result.feedback,
326
+ "judge_hint": hint,
327
+ "overconfidence_penalty": judge_result.overconfidence_penalty,
328
+ }
329
+
330
  return obs, reward, done, info
331
 
332
  except RuntimeError:
333
  raise
334
+ except Exception as exc: # Req 10.4: no unhandled exceptions
335
  reward = RewardBreakdown(
336
  r_validity=-1.0,
337
  r_ordering=0.0,
 
348
  "action_valid": False,
349
  "violations": [f"Internal error: {exc}"],
350
  }
351
+ noise_model = self._noise_model or NoiseModel(
352
+ seed=self._latent.seed if self._latent else 0
353
+ )
354
+ output_gen = OutputGenerator(noise_model)
355
+ obs = (
356
+ output_gen.generate(
357
+ latent=self._latent,
358
+ trial_state=self._state
359
+ or TrialState(
360
+ episode_id=self._episode_id,
361
+ step_count=step_idx,
362
+ difficulty=self._difficulty,
363
+ scenario_id=self._scenario.scenario_id
364
+ if self._scenario
365
+ else "",
366
+ curriculum_tier="0",
367
+ curriculum_stats={},
368
+ action_diversity=0.0,
369
+ phase_compliance_rate=0.0,
370
+ is_resolved=False,
371
+ ),
372
+ steps_taken=step_idx,
373
+ max_steps=_MAX_STEPS,
374
+ rule_violations=[f"Internal error: {exc}"],
375
+ done=False,
376
+ reward=reward.total,
377
+ scenario_description=(
378
+ self._scenario.description if self._scenario else ""
379
+ ),
380
+ hint="",
381
+ )
382
+ if self._latent is not None
383
+ else TrialObservation(
384
+ scenario_description="",
385
+ phase_data={},
386
+ resource_status={},
387
+ rule_violations=[f"Internal error: {exc}"],
388
+ available_actions=[],
389
+ steps_taken=step_idx,
390
+ max_steps=_MAX_STEPS,
391
+ hint="",
392
+ done=False,
393
+ reward=0.0,
394
+ )
395
  )
396
  return obs, reward, False, info
397
 
 
420
  """Build the lightweight TrialState from latent state."""
421
  step_count = len(latent.action_history)
422
  unique_actions = len(set(latent.action_history))
423
+ action_diversity = unique_actions / step_count if step_count > 0 else 0.0
424
+
425
+ # Compute phase compliance rate from phase history
426
+ phase_steps = len(self._phase_history)
427
+ if phase_steps > 0:
428
+ correct_count = sum(
429
+ 1
430
+ for i, ph in enumerate(self._phase_history)
431
+ if _phase_order_correct_at(ph, self._phase_history[:i])
432
+ )
433
+ phase_compliance_rate = correct_count / phase_steps
434
+ else:
435
+ phase_compliance_rate = 0.0
436
+
437
  return TrialState(
438
  episode_id=self._episode_id,
439
  step_count=step_count,
 
442
  curriculum_tier=str(scenario.curriculum_tier),
443
  curriculum_stats={},
444
  action_diversity=action_diversity,
445
+ phase_compliance_rate=phase_compliance_rate,
446
  is_resolved=latent.trial_complete,
447
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server/judge.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trial Judge — multi-layer verification for clinical trial design decisions.
3
+
4
+ Layer 1 (programmatic, authoritative, never overridden):
5
+ - power >= 0.80
6
+ - p_value < 0.05
7
+ - FDA compliance passes
8
+ - budget_remaining > 0
9
+
10
+ Layer 2 (persona-scaled LLM stub):
11
+ - junior (difficulty < 0.4): gives hints, lenient feedback
12
+ - senior (0.4–0.7): balanced feedback
13
+ - principal (> 0.7): strict, no hints
14
+
15
+ Overconfidence penalty: -0.5 per high-confidence wrong claim
16
+ (action.confidence >= 0.8 and the claim is incorrect per Layer 1).
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from pydantic import BaseModel
22
+
23
+ from models import TrialAction, TrialLatentState, TrialState
24
+ from server.rules.fda_rules import check_fda_compliance
25
+ from server.simulator.power_calculator import calculate_power
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Result model
29
+ # ---------------------------------------------------------------------------
30
+
31
+
32
+ class JudgeResult(BaseModel):
33
+ """Output of TrialJudge.verify()."""
34
+
35
+ passed: bool
36
+ violations: list[str]
37
+ feedback: str
38
+ hint: str | None
39
+ overconfidence_penalty: float
40
+ persona: str
41
+
42
+
43
+ # ---------------------------------------------------------------------------
44
+ # Persona thresholds
45
+ # ---------------------------------------------------------------------------
46
+
47
+ _JUNIOR_MAX = 0.4
48
+ _SENIOR_MAX = 0.7
49
+ _HIGH_CONFIDENCE_THRESHOLD = 0.8
50
+ _OVERCONFIDENCE_PENALTY = -0.5
51
+
52
+
53
+ def _select_persona(difficulty: float) -> str:
54
+ if difficulty < _JUNIOR_MAX:
55
+ return "junior"
56
+ if difficulty <= _SENIOR_MAX:
57
+ return "senior"
58
+ return "principal"
59
+
60
+
61
+ # ---------------------------------------------------------------------------
62
+ # Layer 2: rule-based LLM stub
63
+ # ---------------------------------------------------------------------------
64
+
65
+
66
+ def _generate_feedback(
67
+ persona: str,
68
+ violations: list[str],
69
+ passed: bool,
70
+ action: TrialAction,
71
+ latent: TrialLatentState,
72
+ ) -> tuple[str, str | None]:
73
+ """Return (feedback, hint) for the given persona.
74
+
75
+ This is a rule-based stub that can be replaced with a real LLM call later.
76
+ The stub generates contextually appropriate strings without an LLM.
77
+ """
78
+ action_name = action.action_type.value.replace("_", " ")
79
+
80
+ if passed:
81
+ if persona == "junior":
82
+ feedback = (
83
+ f"Good work on '{action_name}'! Your trial design looks solid. "
84
+ f"Power and significance thresholds are met. Keep it up!"
85
+ )
86
+ hint = (
87
+ "Tip: continue building on this foundation — "
88
+ "consider biomarker stratification next to improve precision."
89
+ )
90
+ elif persona == "senior":
91
+ feedback = (
92
+ f"'{action_name}' passes all programmatic checks. "
93
+ f"Statistical power and p-value criteria are satisfied. "
94
+ f"Proceed to the next design step."
95
+ )
96
+ hint = None
97
+ else: # principal
98
+ feedback = (
99
+ f"'{action_name}' meets minimum criteria. "
100
+ f"Ensure alpha-spending and interim analysis boundaries "
101
+ f"are pre-specified before submission."
102
+ )
103
+ hint = None
104
+ else:
105
+ violation_summary = "; ".join(violations) if violations else "unknown issue"
106
+ if persona == "junior":
107
+ feedback = (
108
+ f"'{action_name}' did not pass verification. "
109
+ f"Issues found: {violation_summary}. "
110
+ f"Review the requirements and try again."
111
+ )
112
+ hint = _build_hint_for_violations(violations, latent)
113
+ elif persona == "senior":
114
+ feedback = (
115
+ f"'{action_name}' failed verification. "
116
+ f"Violations: {violation_summary}. "
117
+ f"Address these before proceeding."
118
+ )
119
+ hint = None
120
+ else: # principal
121
+ feedback = (
122
+ f"'{action_name}' is non-compliant. "
123
+ f"Violations: {violation_summary}. "
124
+ f"No further guidance will be provided — resolve independently."
125
+ )
126
+ hint = None
127
+
128
+ return feedback, hint
129
+
130
+
131
+ def _build_hint_for_violations(
132
+ violations: list[str], latent: TrialLatentState
133
+ ) -> str | None:
134
+ """Build a contextual hint for junior persona based on violation content."""
135
+ if not violations:
136
+ return None
137
+
138
+ first = violations[0].lower()
139
+
140
+ if "power" in first:
141
+ return (
142
+ "Hint: current power is below 0.80. "
143
+ "Try increasing the sample size — "
144
+ "more patients enrolled improves statistical power."
145
+ )
146
+ if "p-value" in first or "p_value" in first or "significance" in first:
147
+ return (
148
+ "Hint: the p-value threshold of 0.05 is not met. "
149
+ "Consider a larger effect size or more patients."
150
+ )
151
+ if "budget" in first:
152
+ return (
153
+ f"Hint: budget is exhausted (remaining: {latent.budget_remaining:.2f}). "
154
+ f"Look for cost-saving measures or request a protocol amendment."
155
+ )
156
+ if "fda" in first or "compliance" in first or "permitted" in first:
157
+ return (
158
+ f"Hint: this action is not allowed in the current phase "
159
+ f"('{latent.episode_phase}'). "
160
+ f"Check the transition table for permitted actions."
161
+ )
162
+ if "sample size" in first:
163
+ return "Hint: the minimum regulatory sample size is 30 participants."
164
+ if "protocol" in first:
165
+ return "Hint: submit the protocol before attempting FDA review."
166
+ if "phase i" in first:
167
+ return "Hint: complete Phase I before submitting to FDA review."
168
+ if "interim" in first:
169
+ return "Hint: run an interim analysis before the primary analysis."
170
+ if "patients" in first or "enrolled" in first:
171
+ return "Hint: enroll patients before running analyses."
172
+
173
+ # Generic fallback
174
+ return f"Hint: {violations[0]}"
175
+
176
+
177
+ # ---------------------------------------------------------------------------
178
+ # Main judge class
179
+ # ---------------------------------------------------------------------------
180
+
181
+
182
+ class TrialJudge:
183
+ """Multi-layer trial design verifier.
184
+
185
+ Layer 1 is programmatic and authoritative — its result is never overridden.
186
+ Layer 2 is persona-scaled and provides human-readable feedback and hints.
187
+ """
188
+
189
+ def verify(
190
+ self,
191
+ action: TrialAction,
192
+ state: TrialState,
193
+ latent: TrialLatentState,
194
+ ) -> JudgeResult:
195
+ """Verify the action against both programmatic and persona layers.
196
+
197
+ Args:
198
+ action: The agent's action to evaluate.
199
+ state: Lightweight training-loop metadata (carries difficulty).
200
+ latent: Hidden ground-truth + episode tracking state.
201
+
202
+ Returns:
203
+ JudgeResult with pass/fail, violations, feedback, hint, and penalty.
204
+ """
205
+ violations: list[str] = []
206
+
207
+ # ------------------------------------------------------------------
208
+ # Layer 1: Programmatic checks (authoritative, never overridden)
209
+ # ------------------------------------------------------------------
210
+
211
+ # 1a. Budget check
212
+ if latent.budget_remaining <= 0:
213
+ violations.append(
214
+ f"Budget exhausted: budget_remaining={latent.budget_remaining:.2f} "
215
+ f"(must be > 0)."
216
+ )
217
+
218
+ # 1b. Statistical power check
219
+ n = max(latent.patients_enrolled, 1)
220
+ power = calculate_power(latent.true_effect_size, n)
221
+ if power < 0.80:
222
+ violations.append(
223
+ f"Insufficient statistical power: {power:.3f} < 0.80 "
224
+ f"(effect_size={latent.true_effect_size:.3f}, n={n})."
225
+ )
226
+
227
+ # 1c. p-value check — derive from power/effect/n
228
+ # We use the same normal approximation as the simulator.
229
+ import math
230
+
231
+ from scipy.stats import norm
232
+
233
+ if n > 0 and latent.true_effect_size != 0.0:
234
+ n_per_arm = n / 2.0
235
+ se = 1.0 / math.sqrt(n_per_arm) if n_per_arm > 0 else 1.0
236
+ z_stat = latent.true_effect_size / se
237
+ p_value = float(2.0 * norm.sf(abs(z_stat)))
238
+ else:
239
+ p_value = 1.0
240
+
241
+ if p_value >= 0.05:
242
+ violations.append(
243
+ f"p-value not significant: {p_value:.4f} >= 0.05 "
244
+ f"(n={n}, effect_size={latent.true_effect_size:.3f})."
245
+ )
246
+
247
+ # 1d. FDA compliance check
248
+ compliance = check_fda_compliance(action, latent)
249
+ if not compliance.valid:
250
+ violations.extend(compliance.violations)
251
+
252
+ passed = len(violations) == 0
253
+
254
+ # ------------------------------------------------------------------
255
+ # Overconfidence penalty
256
+ # ------------------------------------------------------------------
257
+ # A "high-confidence wrong claim" is when the agent's confidence is
258
+ # >= 0.8 but Layer 1 found violations (the claim is incorrect).
259
+ overconfidence_penalty = 0.0
260
+ if not passed and action.confidence >= _HIGH_CONFIDENCE_THRESHOLD:
261
+ # One penalty per violation that was caused by a wrong claim
262
+ overconfidence_penalty = _OVERCONFIDENCE_PENALTY * len(violations)
263
+
264
+ # ------------------------------------------------------------------
265
+ # Layer 2: Persona-scaled feedback (never overrides Layer 1 result)
266
+ # ------------------------------------------------------------------
267
+ persona = _select_persona(state.difficulty)
268
+ feedback, hint = _generate_feedback(persona, violations, passed, action, latent)
269
+
270
+ return JudgeResult(
271
+ passed=passed,
272
+ violations=violations,
273
+ feedback=feedback,
274
+ hint=hint,
275
+ overconfidence_penalty=overconfidence_penalty,
276
+ persona=persona,
277
+ )
server/logger.py CHANGED
@@ -31,9 +31,7 @@ class EpisodeLogger:
31
  episode_id: str | None = None,
32
  curriculum_tier: int = 0,
33
  ) -> None:
34
- self._log_path: Path = (
35
- log_path if log_path is not None else settings.log_path
36
- )
37
  self._episode_id: str = (
38
  episode_id if episode_id is not None else str(uuid.uuid4())
39
  )
 
31
  episode_id: str | None = None,
32
  curriculum_tier: int = 0,
33
  ) -> None:
34
+ self._log_path: Path = log_path if log_path is not None else settings.log_path
 
 
35
  self._episode_id: str = (
36
  episode_id if episode_id is not None else str(uuid.uuid4())
37
  )
server/noise_model.py CHANGED
@@ -36,6 +36,11 @@ class NoiseModel:
36
  self._seed = seed
37
  self._rng: np.random.Generator = np.random.default_rng(seed)
38
 
 
 
 
 
 
39
  def randomize(self, config: ScenarioConfig) -> ScenarioConfig:
40
  """Return a new ScenarioConfig with domain-randomized parameters.
41
 
 
36
  self._seed = seed
37
  self._rng: np.random.Generator = np.random.default_rng(seed)
38
 
39
+ @property
40
+ def rng(self) -> np.random.Generator:
41
+ """Public access to the seeded Generator."""
42
+ return self._rng
43
+
44
  def randomize(self, config: ScenarioConfig) -> ScenarioConfig:
45
  """Return a new ScenarioConfig with domain-randomized parameters.
46
 
server/phase_detector.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase Detector — classifies TrialActions into clinical workflow phases.
3
+
4
+ Clinical workflow phase order:
5
+ literature_review → hypothesis → design → enrollment →
6
+ monitoring → analysis → submission
7
+
8
+ Phase-order bonus: +0.2 for correct order (no regression, no skips)
9
+ Skip penalty: -0.3 per skipped phase
10
+
11
+ Requirements: 8.5, 9.4
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from models import ActionType, TrialAction
17
+
18
+ # Ordered list of clinical workflow phases
19
+ PHASE_ORDER: list[str] = [
20
+ "literature_review",
21
+ "hypothesis",
22
+ "design",
23
+ "enrollment",
24
+ "monitoring",
25
+ "analysis",
26
+ "submission",
27
+ ]
28
+
29
+ # Reward constants
30
+ PHASE_BONUS: float = 0.2
31
+ PHASE_SKIP_PENALTY: float = -0.3
32
+
33
+ # Mapping from ActionType to phase name.
34
+ # literature_review has no direct action — used as default for unknown.
35
+ _ACTION_TO_PHASE: dict[ActionType, str] = {
36
+ # hypothesis
37
+ ActionType.ESTIMATE_EFFECT_SIZE: "hypothesis",
38
+ ActionType.ADD_BIOMARKER_STRATIFICATION: "hypothesis",
39
+ # design
40
+ ActionType.SET_PRIMARY_ENDPOINT: "design",
41
+ ActionType.SET_SAMPLE_SIZE: "design",
42
+ ActionType.SET_INCLUSION_CRITERIA: "design",
43
+ ActionType.SET_EXCLUSION_CRITERIA: "design",
44
+ ActionType.SET_DOSING_SCHEDULE: "design",
45
+ ActionType.SET_CONTROL_ARM: "design",
46
+ ActionType.SET_RANDOMIZATION_RATIO: "design",
47
+ ActionType.SET_BLINDING: "design",
48
+ ActionType.REQUEST_PROTOCOL_AMENDMENT: "design",
49
+ # enrollment
50
+ ActionType.ENROLL_PATIENTS: "enrollment",
51
+ # monitoring
52
+ ActionType.RUN_DOSE_ESCALATION: "monitoring",
53
+ ActionType.OBSERVE_SAFETY_SIGNAL: "monitoring",
54
+ ActionType.RUN_INTERIM_ANALYSIS: "monitoring",
55
+ ActionType.MODIFY_SAMPLE_SIZE: "monitoring",
56
+ # analysis
57
+ ActionType.RUN_PRIMARY_ANALYSIS: "analysis",
58
+ ActionType.SYNTHESIZE_CONCLUSION: "analysis",
59
+ # submission
60
+ ActionType.SUBMIT_TO_FDA_REVIEW: "submission",
61
+ }
62
+
63
+
64
+ def detect_phase(action: TrialAction, history: list[str]) -> tuple[str, bool]:
65
+ """Classify a TrialAction into a clinical workflow phase.
66
+
67
+ Args:
68
+ action: The agent's action for this step.
69
+ history: List of phase names (strings) from previous steps in the episode.
70
+
71
+ Returns:
72
+ A tuple of (phase_name, phase_order_correct) where:
73
+ - phase_name is the detected phase string
74
+ - phase_order_correct is True iff the phase transition is valid
75
+ (no regression, no skipped phases)
76
+ """
77
+ phase_name = _ACTION_TO_PHASE.get(action.action_type, "literature_review")
78
+
79
+ if not history:
80
+ # First action — any phase is valid
81
+ return phase_name, True
82
+
83
+ last_phase = history[-1]
84
+ last_idx = PHASE_ORDER.index(last_phase) if last_phase in PHASE_ORDER else 0
85
+ current_idx = PHASE_ORDER.index(phase_name) if phase_name in PHASE_ORDER else 0
86
+
87
+ # Regression: going backwards is not correct
88
+ if current_idx < last_idx:
89
+ return phase_name, False
90
+
91
+ # Skipped phases: any phase between last+1 and current-1 (exclusive) is a skip
92
+ skipped = current_idx - last_idx - 1
93
+ if skipped > 0:
94
+ return phase_name, False
95
+
96
+ # Staying in same phase or advancing by exactly one — correct
97
+ return phase_name, True
98
+
99
+
100
+ def compute_phase_ordering_reward(action: TrialAction, history: list[str]) -> float:
101
+ """Compute the r_ordering reward component using phase detection.
102
+
103
+ Returns:
104
+ +PHASE_BONUS if phase order is correct.
105
+ PHASE_SKIP_PENALTY * num_skipped_phases if phases were skipped.
106
+ 0.0 if there is a regression (going backwards).
107
+ """
108
+ phase_name = _ACTION_TO_PHASE.get(action.action_type, "literature_review")
109
+
110
+ if not history:
111
+ return PHASE_BONUS
112
+
113
+ last_phase = history[-1]
114
+ last_idx = PHASE_ORDER.index(last_phase) if last_phase in PHASE_ORDER else 0
115
+ current_idx = PHASE_ORDER.index(phase_name) if phase_name in PHASE_ORDER else 0
116
+
117
+ if current_idx < last_idx:
118
+ # Regression — no bonus, no skip penalty
119
+ return 0.0
120
+
121
+ skipped = current_idx - last_idx - 1
122
+ if skipped > 0:
123
+ return PHASE_SKIP_PENALTY * skipped
124
+
125
+ return PHASE_BONUS
server/reward/reward_computer.py CHANGED
@@ -18,6 +18,7 @@ from models import (
18
  TrialLatentState,
19
  TrialResult,
20
  )
 
21
  from server.rules.fda_rules import check_fda_compliance
22
 
23
  # Reward magnitude constants
@@ -29,13 +30,13 @@ _TERMINAL_CALIBRATION = 5.0
29
  _INFO_GAIN_BASE = 0.5
30
  _EFFICIENCY_SCALE = 2.0
31
  _NOVELTY_BASE = 0.2
32
- _ORDERING_BONUS = 0.2
33
 
34
 
35
  def compute_reward(
36
  action: TrialAction,
37
  latent: TrialLatentState,
38
  result: TrialResult,
 
39
  ) -> RewardBreakdown:
40
  """Compute all eight reward components for a single step.
41
 
@@ -46,6 +47,7 @@ def compute_reward(
46
  action: The agent's action.
47
  latent: Hidden ground-truth + episode tracking state.
48
  result: The simulated trial result.
 
49
 
50
  Returns:
51
  A RewardBreakdown with all eight keys populated.
@@ -54,11 +56,9 @@ def compute_reward(
54
 
55
  r_validity = _VALIDITY_VALID if compliance.valid else _VALIDITY_INVALID
56
  r_penalty = (
57
- _PENALTY_INVALID * len(compliance.violations)
58
- if not compliance.valid
59
- else 0.0
60
  )
61
- r_ordering = _ordering_reward(action, latent)
62
  r_info_gain = _info_gain_reward(action, result)
63
  r_efficiency = _efficiency_reward(latent)
64
  r_novelty = _novelty_reward(action, latent)
@@ -81,18 +81,11 @@ def compute_reward(
81
  # Component helpers
82
  # ---------------------------------------------------------------------------
83
 
84
- def _ordering_reward(action: TrialAction, latent: TrialLatentState) -> float:
85
- """Bonus for actions that match the expected clinical workflow phase."""
86
- from server.rules.fda_rules import TRANSITION_TABLE
87
- permitted = TRANSITION_TABLE.get(latent.episode_phase, set())
88
- if action.action_type in permitted:
89
- return _ORDERING_BONUS
90
- return 0.0
91
-
92
 
93
  def _info_gain_reward(action: TrialAction, result: TrialResult) -> float:
94
  """Reward for information-gathering actions that produce useful results."""
95
  from models import ActionType
 
96
  info_actions = {
97
  ActionType.ESTIMATE_EFFECT_SIZE,
98
  ActionType.OBSERVE_SAFETY_SIGNAL,
@@ -110,9 +103,7 @@ def _efficiency_reward(latent: TrialLatentState) -> float:
110
  initial_budget = 1_000_000.0
111
  if initial_budget <= 0:
112
  return 0.0
113
- budget_fraction = min(
114
- max(latent.budget_remaining / initial_budget, 0.0), 1.0
115
- )
116
  return _EFFICIENCY_SCALE * budget_fraction
117
 
118
 
@@ -123,9 +114,7 @@ def _novelty_reward(action: TrialAction, latent: TrialLatentState) -> float:
123
  return 0.0
124
 
125
 
126
- def _terminal_success_reward(
127
- latent: TrialLatentState, result: TrialResult
128
- ) -> float:
129
  """Positive reward when the episode ends with a successful trial (req 6.4)."""
130
  if latent.trial_complete and result.success and result.failure_reason is None:
131
  return _TERMINAL_SUCCESS
@@ -150,6 +139,6 @@ def _terminal_calibration_reward(
150
  centre_error = abs(ci_centre - true_effect)
151
  calibration_score = max(0.0, 1.0 - centre_error)
152
  width_penalty = min(ci_width, 1.0)
153
- calibration_score *= (1.0 - width_penalty * 0.5)
154
 
155
  return _TERMINAL_CALIBRATION * calibration_score
 
18
  TrialLatentState,
19
  TrialResult,
20
  )
21
+ from server.phase_detector import compute_phase_ordering_reward
22
  from server.rules.fda_rules import check_fda_compliance
23
 
24
  # Reward magnitude constants
 
30
  _INFO_GAIN_BASE = 0.5
31
  _EFFICIENCY_SCALE = 2.0
32
  _NOVELTY_BASE = 0.2
 
33
 
34
 
35
  def compute_reward(
36
  action: TrialAction,
37
  latent: TrialLatentState,
38
  result: TrialResult,
39
+ phase_history: list[str] | None = None,
40
  ) -> RewardBreakdown:
41
  """Compute all eight reward components for a single step.
42
 
 
47
  action: The agent's action.
48
  latent: Hidden ground-truth + episode tracking state.
49
  result: The simulated trial result.
50
+ phase_history: List of phase names from previous steps (for r_ordering).
51
 
52
  Returns:
53
  A RewardBreakdown with all eight keys populated.
 
56
 
57
  r_validity = _VALIDITY_VALID if compliance.valid else _VALIDITY_INVALID
58
  r_penalty = (
59
+ _PENALTY_INVALID * len(compliance.violations) if not compliance.valid else 0.0
 
 
60
  )
61
+ r_ordering = compute_phase_ordering_reward(action, phase_history or [])
62
  r_info_gain = _info_gain_reward(action, result)
63
  r_efficiency = _efficiency_reward(latent)
64
  r_novelty = _novelty_reward(action, latent)
 
81
  # Component helpers
82
  # ---------------------------------------------------------------------------
83
 
 
 
 
 
 
 
 
 
84
 
85
  def _info_gain_reward(action: TrialAction, result: TrialResult) -> float:
86
  """Reward for information-gathering actions that produce useful results."""
87
  from models import ActionType
88
+
89
  info_actions = {
90
  ActionType.ESTIMATE_EFFECT_SIZE,
91
  ActionType.OBSERVE_SAFETY_SIGNAL,
 
103
  initial_budget = 1_000_000.0
104
  if initial_budget <= 0:
105
  return 0.0
106
+ budget_fraction = min(max(latent.budget_remaining / initial_budget, 0.0), 1.0)
 
 
107
  return _EFFICIENCY_SCALE * budget_fraction
108
 
109
 
 
114
  return 0.0
115
 
116
 
117
+ def _terminal_success_reward(latent: TrialLatentState, result: TrialResult) -> float:
 
 
118
  """Positive reward when the episode ends with a successful trial (req 6.4)."""
119
  if latent.trial_complete and result.success and result.failure_reason is None:
120
  return _TERMINAL_SUCCESS
 
139
  centre_error = abs(ci_centre - true_effect)
140
  calibration_score = max(0.0, 1.0 - centre_error)
141
  width_penalty = min(ci_width, 1.0)
142
+ calibration_score *= 1.0 - width_penalty * 0.5
143
 
144
  return _TERMINAL_CALIBRATION * calibration_score
server/reward/shaping.py CHANGED
@@ -32,9 +32,7 @@ def _budget_efficiency(
32
  return min(max(latent.budget_remaining / initial_budget, 0.0), 1.0)
33
 
34
 
35
- def potential(
36
- latent: TrialLatentState, initial_budget: float = 1_000_000.0
37
- ) -> float:
38
  """φ(s) = milestone_completion × budget_efficiency."""
39
  return _milestone_completion(latent) * _budget_efficiency(latent, initial_budget)
40
 
 
32
  return min(max(latent.budget_remaining / initial_budget, 0.0), 1.0)
33
 
34
 
35
+ def potential(latent: TrialLatentState, initial_budget: float = 1_000_000.0) -> float:
 
 
36
  """φ(s) = milestone_completion × budget_efficiency."""
37
  return _milestone_completion(latent) * _budget_efficiency(latent, initial_budget)
38
 
server/rules/prerequisite_rules.py CHANGED
@@ -20,9 +20,7 @@ _HISTORY_PREREQUISITES: dict[ActionType, list[ActionType]] = {
20
  }
21
 
22
 
23
- def check_prerequisites(
24
- action: TrialAction, latent: TrialLatentState
25
- ) -> list[str]:
26
  """Return a list of prerequisite violation strings for *action* given *latent*.
27
 
28
  Returns an empty list when all prerequisites are satisfied.
 
20
  }
21
 
22
 
23
+ def check_prerequisites(action: TrialAction, latent: TrialLatentState) -> list[str]:
 
 
24
  """Return a list of prerequisite violation strings for *action* given *latent*.
25
 
26
  Returns an empty list when all prerequisites are satisfied.
server/simulator/__init__.py CHANGED
@@ -2,5 +2,5 @@
2
  simulator — Trial outcome simulation and power calculation.
3
 
4
  Provides simulate_trial, calculate_power (with episode-scoped cache),
5
- compute_reward, and the seeded hidden-state generator.
6
  """
 
2
  simulator — Trial outcome simulation and power calculation.
3
 
4
  Provides simulate_trial, calculate_power (with episode-scoped cache),
5
+ compute_reward, TransitionEngine, and the seeded hidden-state generator.
6
  """
server/simulator/output_generator.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OutputGenerator — produces a noisy TrialObservation from a TrialLatentState.
3
+
4
+ Follows the Bio Experiment pattern: TransitionEngine updates hidden state,
5
+ OutputGenerator produces noisy observations from it. Agent never sees clean
6
+ hidden values.
7
+
8
+ Key responsibilities:
9
+ - Inject measurement noise and site variability via NoiseModel's seeded RNG
10
+ - Populate phase_data with noisy (not raw) experimental results
11
+ - Populate resource_status from latent state resource fields
12
+ - Populate available_actions based on current milestone flags and phase
13
+ - Never expose true_effect_size, true_side_effect_rate, or other hidden values
14
+ directly — always add noise before returning to the agent
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import numpy as np
20
+
21
+ from models import ActionType, TrialLatentState, TrialObservation, TrialState
22
+ from server.noise_model import NoiseModel
23
+ from server.rules.fda_rules import TRANSITION_TABLE
24
+ from server.rules.prerequisite_rules import _HISTORY_PREREQUISITES
25
+
26
+
27
+ class OutputGenerator:
28
+ """Produces a noisy TrialObservation from a TrialLatentState.
29
+
30
+ The agent never sees clean hidden values — all experimental results are
31
+ perturbed by measurement noise and site variability before being returned.
32
+
33
+ Args:
34
+ noise_model: Seeded NoiseModel used to draw observation noise.
35
+ """
36
+
37
+ def __init__(self, noise_model: NoiseModel) -> None:
38
+ self._noise_model = noise_model
39
+
40
+ def generate(
41
+ self,
42
+ latent: TrialLatentState,
43
+ trial_state: TrialState,
44
+ *,
45
+ steps_taken: int,
46
+ max_steps: int,
47
+ rule_violations: list[str],
48
+ done: bool,
49
+ reward: float,
50
+ scenario_description: str,
51
+ hint: str = "",
52
+ ) -> TrialObservation:
53
+ """Generate a noisy TrialObservation from the current latent state.
54
+
55
+ Args:
56
+ latent: Updated hidden state from TransitionEngine.
57
+ trial_state: Episode metadata (difficulty, curriculum tier, etc.).
58
+ steps_taken: Number of steps taken so far in the episode.
59
+ max_steps: Maximum steps allowed in the episode.
60
+ rule_violations: List of rule violation strings from this step.
61
+ done: Whether the episode is finished.
62
+ reward: Reward signal for this step.
63
+ scenario_description: Human-readable scenario description.
64
+ hint: Optional hint string (only populated at junior difficulty).
65
+
66
+ Returns:
67
+ A TrialObservation with noisy phase_data, resource_status, and
68
+ available_actions. Raw hidden values are never included.
69
+ """
70
+ rng = self._noise_model.rng
71
+
72
+ phase_data = self._build_phase_data(latent, rng)
73
+ resource_status = self._build_resource_status(latent)
74
+ available_actions = self._build_available_actions(latent)
75
+
76
+ return TrialObservation(
77
+ scenario_description=scenario_description,
78
+ phase_data=phase_data,
79
+ resource_status=resource_status,
80
+ rule_violations=rule_violations,
81
+ available_actions=available_actions,
82
+ steps_taken=steps_taken,
83
+ max_steps=max_steps,
84
+ hint=hint,
85
+ done=done,
86
+ reward=reward,
87
+ )
88
+
89
+ # ------------------------------------------------------------------
90
+ # Private helpers
91
+ # ------------------------------------------------------------------
92
+
93
+ def _build_phase_data(
94
+ self,
95
+ latent: TrialLatentState,
96
+ rng: "np.random.Generator",
97
+ ) -> dict:
98
+ """Build noisy phase_data dict — never exposes raw hidden values.
99
+
100
+ Measurement noise (latent.measurement_noise) is applied to effect-size
101
+ estimates. Site variability (latent.site_variability) is applied to
102
+ adverse-event-rate estimates.
103
+ """
104
+ import numpy as np # local import to keep module-level deps minimal
105
+
106
+ noise_std = max(latent.measurement_noise, 1e-6)
107
+ site_std = max(latent.site_variability, 1e-6)
108
+
109
+ phase_data: dict = {
110
+ "current_phase": latent.episode_phase,
111
+ "patients_enrolled": latent.patients_enrolled,
112
+ # Milestones — these are observable flags, not hidden values
113
+ "phase_i_complete": latent.phase_i_complete,
114
+ "mtd_identified": latent.mtd_identified,
115
+ "effect_estimated": latent.effect_estimated,
116
+ "protocol_submitted": latent.protocol_submitted,
117
+ "interim_complete": latent.interim_complete,
118
+ "trial_complete": latent.trial_complete,
119
+ }
120
+
121
+ # Noisy effect-size estimate — only available after ESTIMATE_EFFECT_SIZE
122
+ if latent.effect_estimated:
123
+ noisy_effect = float(latent.true_effect_size + rng.normal(0.0, noise_std))
124
+ phase_data["observed_effect_size"] = round(noisy_effect, 4)
125
+
126
+ # Noisy confidence interval width (derived from noise level)
127
+ ci_half_width = float(rng.normal(noise_std * 2, noise_std * 0.5))
128
+ ci_half_width = max(ci_half_width, 0.01)
129
+ phase_data["effect_size_ci"] = (
130
+ round(noisy_effect - ci_half_width, 4),
131
+ round(noisy_effect + ci_half_width, 4),
132
+ )
133
+
134
+ # Noisy adverse-event rate — only available after OBSERVE_SAFETY_SIGNAL
135
+ # or RUN_DOSE_ESCALATION
136
+ if (
137
+ latent.phase_i_complete
138
+ or ActionType.OBSERVE_SAFETY_SIGNAL.value in latent.action_history
139
+ ):
140
+ noisy_ae_rate = float(
141
+ latent.true_side_effect_rate + rng.normal(0.0, site_std)
142
+ )
143
+ noisy_ae_rate = float(np.clip(noisy_ae_rate, 0.0, 1.0))
144
+ phase_data["observed_adverse_event_rate"] = round(noisy_ae_rate, 4)
145
+
146
+ # Noisy placebo response — only available after interim or primary analysis
147
+ if latent.interim_complete or latent.trial_complete:
148
+ noisy_placebo = float(
149
+ latent.placebo_response_rate + rng.normal(0.0, noise_std)
150
+ )
151
+ noisy_placebo = float(np.clip(noisy_placebo, 0.0, 1.0))
152
+ phase_data["observed_placebo_response"] = round(noisy_placebo, 4)
153
+
154
+ # Noisy dose-response curve — only available after Phase I
155
+ if latent.phase_i_complete and latent.true_dose_response:
156
+ noisy_dose_response: dict[str, float] = {}
157
+ for dose, response in latent.true_dose_response.items():
158
+ noisy_resp = float(response + rng.normal(0.0, noise_std))
159
+ noisy_resp = float(np.clip(noisy_resp, 0.0, 1.0))
160
+ noisy_dose_response[str(dose)] = round(noisy_resp, 4)
161
+ phase_data["observed_dose_response"] = noisy_dose_response
162
+
163
+ # Dropout rate estimate — noisy, only after enrollment begins
164
+ if latent.patients_enrolled > 0:
165
+ noisy_dropout = float(
166
+ latent.dropout_rate + rng.normal(0.0, noise_std * 0.5)
167
+ )
168
+ noisy_dropout = float(np.clip(noisy_dropout, 0.0, 1.0))
169
+ phase_data["observed_dropout_rate"] = round(noisy_dropout, 4)
170
+
171
+ # Responder population hint — only after biomarker stratification
172
+ if ActionType.ADD_BIOMARKER_STRATIFICATION.value in latent.action_history:
173
+ # Reveal population label but NOT the true criteria (hidden)
174
+ phase_data["responder_population_hint"] = latent.true_responder_population
175
+
176
+ return phase_data
177
+
178
+ def _build_resource_status(self, latent: TrialLatentState) -> dict:
179
+ """Build resource_status from latent state resource fields."""
180
+ return {
181
+ "budget_remaining": latent.budget_remaining,
182
+ "time_remaining_days": latent.time_remaining_days,
183
+ "patients_enrolled": latent.patients_enrolled,
184
+ }
185
+
186
+ def _build_available_actions(self, latent: TrialLatentState) -> list[str]:
187
+ """Return the list of valid action strings given current milestone flags.
188
+
189
+ Filters the phase-permitted actions through prerequisite checks so the
190
+ agent only sees actions it can actually take right now.
191
+ """
192
+ phase_permitted: set[ActionType] = TRANSITION_TABLE.get(
193
+ latent.episode_phase, set()
194
+ )
195
+
196
+ available: list[str] = []
197
+ for action_type in sorted(phase_permitted, key=lambda a: a.value):
198
+ if self._prerequisites_met(action_type, latent):
199
+ available.append(action_type.value)
200
+
201
+ return available
202
+
203
+ def _prerequisites_met(
204
+ self, action_type: ActionType, latent: TrialLatentState
205
+ ) -> bool:
206
+ """Return True if all prerequisites for *action_type* are satisfied."""
207
+ # History-based prerequisites
208
+ required_actions = _HISTORY_PREREQUISITES.get(action_type, [])
209
+ for required in required_actions:
210
+ if required.value not in latent.action_history:
211
+ return False
212
+
213
+ # State-flag prerequisites (mirrors prerequisite_rules.py logic)
214
+ if action_type == ActionType.REQUEST_PROTOCOL_AMENDMENT:
215
+ if not latent.protocol_submitted:
216
+ return False
217
+
218
+ if action_type == ActionType.SUBMIT_TO_FDA_REVIEW:
219
+ if not latent.protocol_submitted or not latent.phase_i_complete:
220
+ return False
221
+
222
+ if action_type == ActionType.RUN_PRIMARY_ANALYSIS:
223
+ if not latent.interim_complete:
224
+ return False
225
+
226
+ if action_type == ActionType.RUN_INTERIM_ANALYSIS:
227
+ if latent.patients_enrolled <= 0:
228
+ return False
229
+
230
+ if action_type == ActionType.MODIFY_SAMPLE_SIZE:
231
+ if ActionType.SET_SAMPLE_SIZE.value not in latent.action_history:
232
+ return False
233
+
234
+ if action_type == ActionType.SYNTHESIZE_CONCLUSION:
235
+ if not latent.trial_complete:
236
+ return False
237
+
238
+ return True
server/simulator/transition_engine.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TransitionEngine — mutates TrialLatentState per action.
3
+
4
+ Follows the Bio Experiment pattern: TransitionEngine updates hidden state,
5
+ OutputGenerator produces noisy observations from it. Agent never sees clean
6
+ hidden values.
7
+
8
+ Key responsibilities:
9
+ - Enroll patients (ENROLL_PATIENTS)
10
+ - Spend budget and advance time
11
+ - Record adverse events
12
+ - Set milestone flags (phase_i_complete, mtd_identified, effect_estimated,
13
+ protocol_submitted, interim_complete, trial_complete)
14
+ - Degrade data quality on soft violations
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import random
20
+
21
+ from models import ActionType, TrialAction, TrialLatentState
22
+
23
+
24
+ class TransitionEngine:
25
+ """Mutates TrialLatentState in response to agent actions.
26
+
27
+ All state transitions are deterministic given the same seed and action
28
+ sequence (reproducibility requirement 9.2).
29
+ """
30
+
31
+ # Cost and time constants (per action type)
32
+ _ACTION_COSTS: dict[ActionType, float] = {
33
+ ActionType.SET_PRIMARY_ENDPOINT: 5_000.0,
34
+ ActionType.SET_SAMPLE_SIZE: 2_000.0,
35
+ ActionType.SET_INCLUSION_CRITERIA: 3_000.0,
36
+ ActionType.SET_EXCLUSION_CRITERIA: 3_000.0,
37
+ ActionType.SET_DOSING_SCHEDULE: 10_000.0,
38
+ ActionType.SET_CONTROL_ARM: 5_000.0,
39
+ ActionType.SET_RANDOMIZATION_RATIO: 2_000.0,
40
+ ActionType.SET_BLINDING: 4_000.0,
41
+ ActionType.RUN_DOSE_ESCALATION: 50_000.0,
42
+ ActionType.OBSERVE_SAFETY_SIGNAL: 15_000.0,
43
+ ActionType.ESTIMATE_EFFECT_SIZE: 20_000.0,
44
+ ActionType.RUN_INTERIM_ANALYSIS: 30_000.0,
45
+ ActionType.MODIFY_SAMPLE_SIZE: 5_000.0,
46
+ ActionType.ADD_BIOMARKER_STRATIFICATION: 25_000.0,
47
+ ActionType.SUBMIT_TO_FDA_REVIEW: 100_000.0,
48
+ ActionType.REQUEST_PROTOCOL_AMENDMENT: 15_000.0,
49
+ ActionType.RUN_PRIMARY_ANALYSIS: 50_000.0,
50
+ ActionType.SYNTHESIZE_CONCLUSION: 10_000.0,
51
+ ActionType.ENROLL_PATIENTS: 0.0, # cost computed per patient
52
+ }
53
+
54
+ _ACTION_TIME_DAYS: dict[ActionType, int] = {
55
+ ActionType.SET_PRIMARY_ENDPOINT: 7,
56
+ ActionType.SET_SAMPLE_SIZE: 3,
57
+ ActionType.SET_INCLUSION_CRITERIA: 5,
58
+ ActionType.SET_EXCLUSION_CRITERIA: 5,
59
+ ActionType.SET_DOSING_SCHEDULE: 14,
60
+ ActionType.SET_CONTROL_ARM: 7,
61
+ ActionType.SET_RANDOMIZATION_RATIO: 3,
62
+ ActionType.SET_BLINDING: 5,
63
+ ActionType.RUN_DOSE_ESCALATION: 90,
64
+ ActionType.OBSERVE_SAFETY_SIGNAL: 30,
65
+ ActionType.ESTIMATE_EFFECT_SIZE: 45,
66
+ ActionType.RUN_INTERIM_ANALYSIS: 60,
67
+ ActionType.MODIFY_SAMPLE_SIZE: 7,
68
+ ActionType.ADD_BIOMARKER_STRATIFICATION: 30,
69
+ ActionType.SUBMIT_TO_FDA_REVIEW: 180,
70
+ ActionType.REQUEST_PROTOCOL_AMENDMENT: 30,
71
+ ActionType.RUN_PRIMARY_ANALYSIS: 90,
72
+ ActionType.SYNTHESIZE_CONCLUSION: 14,
73
+ ActionType.ENROLL_PATIENTS: 0, # time computed per patient
74
+ }
75
+
76
+ # Cost per patient enrolled (varies by disease area complexity)
77
+ _COST_PER_PATIENT: float = 10_000.0
78
+ _DAYS_PER_PATIENT: float = 2.0
79
+
80
+ def __init__(self) -> None:
81
+ """Initialize the TransitionEngine."""
82
+ pass
83
+
84
+ def apply_transition(
85
+ self, latent: TrialLatentState, action: TrialAction
86
+ ) -> TrialLatentState:
87
+ """Apply *action* to *latent* and return the updated state.
88
+
89
+ Does NOT mutate the input latent state — returns a new copy with
90
+ updated fields.
91
+
92
+ Args:
93
+ latent: Current hidden state.
94
+ action: Agent action to apply.
95
+
96
+ Returns:
97
+ Updated TrialLatentState with mutated fields.
98
+ """
99
+ # Create a mutable copy
100
+ updated = latent.model_copy(deep=True)
101
+
102
+ # Update action history
103
+ updated.action_history.append(action.action_type.value)
104
+
105
+ # Compute step-specific RNG
106
+ step_index = len(updated.action_history)
107
+ rng = random.Random(latent.seed ^ step_index)
108
+
109
+ # --- Budget and time consumption ---
110
+ base_cost = self._ACTION_COSTS.get(action.action_type, 0.0)
111
+ base_time = self._ACTION_TIME_DAYS.get(action.action_type, 0)
112
+
113
+ if action.action_type == ActionType.ENROLL_PATIENTS:
114
+ n_patients = action.parameters.get("n_patients", 0)
115
+ base_cost = n_patients * self._COST_PER_PATIENT
116
+ base_time = int(n_patients * self._DAYS_PER_PATIENT)
117
+ updated.patients_enrolled += n_patients
118
+
119
+ updated.budget_remaining -= base_cost
120
+ updated.time_remaining_days -= base_time
121
+
122
+ # --- Milestone flag updates ---
123
+ if action.action_type == ActionType.RUN_DOSE_ESCALATION:
124
+ updated.phase_i_complete = True
125
+ updated.mtd_identified = True
126
+
127
+ if action.action_type == ActionType.ESTIMATE_EFFECT_SIZE:
128
+ updated.effect_estimated = True
129
+
130
+ if action.action_type == ActionType.SUBMIT_TO_FDA_REVIEW:
131
+ updated.protocol_submitted = True
132
+
133
+ if action.action_type == ActionType.RUN_INTERIM_ANALYSIS:
134
+ updated.interim_complete = True
135
+
136
+ if action.action_type == ActionType.RUN_PRIMARY_ANALYSIS:
137
+ updated.trial_complete = True
138
+
139
+ # --- Soft violation: degrade data quality ---
140
+ # If action confidence is low (< 0.5), increase measurement noise
141
+ if action.confidence < 0.5:
142
+ degradation_factor = 1.0 + (0.5 - action.confidence)
143
+ updated.measurement_noise = min(
144
+ updated.measurement_noise * degradation_factor, 0.5
145
+ )
146
+
147
+ # If budget is negative (soft violation), degrade site variability
148
+ if updated.budget_remaining < 0:
149
+ updated.site_variability = min(updated.site_variability * 1.2, 0.5)
150
+
151
+ # If time is negative (soft violation), increase dropout rate
152
+ if updated.time_remaining_days < 0:
153
+ updated.dropout_rate = min(updated.dropout_rate * 1.15, 0.8)
154
+
155
+ # --- Adverse event recording (stochastic) ---
156
+ # On certain actions, record adverse events based on true_side_effect_rate
157
+ if action.action_type in {
158
+ ActionType.ENROLL_PATIENTS,
159
+ ActionType.OBSERVE_SAFETY_SIGNAL,
160
+ ActionType.RUN_DOSE_ESCALATION,
161
+ }:
162
+ # Adverse events increase site variability slightly
163
+ if rng.random() < updated.true_side_effect_rate:
164
+ updated.adverse_events += 1
165
+ updated.site_variability = min(updated.site_variability + 0.02, 0.5)
166
+
167
+ return updated
server/simulator/trial_simulator.py CHANGED
@@ -84,6 +84,7 @@ def simulate_trial(
84
  se = 1.0 / math.sqrt(n_per_arm)
85
  z_stat = observed_effect / se if se > 0 else 0.0
86
  from scipy.stats import norm
 
87
  p_value = float(2.0 * norm.sf(abs(z_stat)))
88
  else:
89
  p_value = 1.0
@@ -93,6 +94,7 @@ def simulate_trial(
93
 
94
  if n_per_arm > 0:
95
  from scipy.stats import norm as _norm
 
96
  z_95 = _norm.ppf(0.975)
97
  se = 1.0 / math.sqrt(n_per_arm)
98
  ci_low = observed_effect - z_95 * se
@@ -104,9 +106,7 @@ def simulate_trial(
104
  0.0,
105
  latent.site_variability if latent.site_variability > 0 else 0.01,
106
  )
107
- adverse_event_rate = min(
108
- max(latent.true_side_effect_rate + ae_noise, 0.0), 1.0
109
- )
110
 
111
  return TrialResult(
112
  p_value=p_value,
 
84
  se = 1.0 / math.sqrt(n_per_arm)
85
  z_stat = observed_effect / se if se > 0 else 0.0
86
  from scipy.stats import norm
87
+
88
  p_value = float(2.0 * norm.sf(abs(z_stat)))
89
  else:
90
  p_value = 1.0
 
94
 
95
  if n_per_arm > 0:
96
  from scipy.stats import norm as _norm
97
+
98
  z_95 = _norm.ppf(0.975)
99
  se = 1.0 / math.sqrt(n_per_arm)
100
  ci_low = observed_effect - z_95 * se
 
106
  0.0,
107
  latent.site_variability if latent.site_variability > 0 else 0.01,
108
  )
109
+ adverse_event_rate = min(max(latent.true_side_effect_rate + ae_noise, 0.0), 1.0)
 
 
110
 
111
  return TrialResult(
112
  p_value=p_value,
tests/test_curriculum_controller.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for server/curriculum/controller.py
3
+
4
+ Verifies:
5
+ - advance_curriculum mastery logic (70% → graduate, 90% → fast-track)
6
+ - select_scenario tier mapping
7
+ - Edge cases (empty history, max tier, clamping)
8
+ """
9
+
10
+ import numpy as np
11
+
12
+ from server.curriculum.controller import (
13
+ MAX_TIER,
14
+ EpisodeMetrics,
15
+ advance_curriculum,
16
+ select_scenario,
17
+ )
18
+ from server.curriculum.scenarios import (
19
+ AUTOIMMUNE_BIOLOGIC,
20
+ CNS_DEPRESSION,
21
+ RARE_DISEASE_ORPHAN,
22
+ SOLID_TUMOR_CHEMO,
23
+ WARMUP,
24
+ )
25
+
26
+ # ── advance_curriculum tests ──────────────────────────────────────────────────
27
+
28
+
29
+ def test_advance_curriculum_empty_history():
30
+ """Empty history → stay at current tier."""
31
+ metrics = EpisodeMetrics(success=True, episode_history=[])
32
+ assert advance_curriculum(0, metrics) == 0
33
+ assert advance_curriculum(2, metrics) == 2
34
+
35
+
36
+ def test_advance_curriculum_no_mastery():
37
+ """Below 70% success → stay at current tier."""
38
+ # 6/10 = 60% → no graduation
39
+ history = [True, False, True, False, True, False, True, False, True, False]
40
+ metrics = EpisodeMetrics(success=False, episode_history=history)
41
+ assert advance_curriculum(1, metrics) == 1
42
+
43
+
44
+ def test_advance_curriculum_normal_graduation():
45
+ """70%+ rolling success → advance one tier."""
46
+ # 7/10 = 70% → graduate
47
+ history = [True, True, True, True, True, True, True, False, False, False]
48
+ metrics = EpisodeMetrics(success=False, episode_history=history)
49
+ assert advance_curriculum(0, metrics) == 1
50
+ assert advance_curriculum(2, metrics) == 3
51
+
52
+
53
+ def test_advance_curriculum_fast_track():
54
+ """90%+ success after ≥3 episodes → skip one tier (advance by 2)."""
55
+ # 9/10 = 90% → fast-track
56
+ history = [True, True, True, True, True, True, True, True, True, False]
57
+ metrics = EpisodeMetrics(success=False, episode_history=history)
58
+ assert advance_curriculum(0, metrics) == 2 # skip tier 1
59
+ assert advance_curriculum(1, metrics) == 3 # skip tier 2
60
+
61
+
62
+ def test_advance_curriculum_fast_track_requires_min_episodes():
63
+ """Fast-track requires at least 3 episodes."""
64
+ # 2 episodes, 100% success → not enough for fast-track
65
+ history = [True, True]
66
+ metrics = EpisodeMetrics(success=True, episode_history=history)
67
+ # Should not fast-track (only 2 episodes), but 100% ≥ 70% → normal graduate
68
+ assert advance_curriculum(0, metrics) == 1
69
+
70
+ # 3 episodes, 100% success → fast-track
71
+ history = [True, True, True]
72
+ metrics = EpisodeMetrics(success=True, episode_history=history)
73
+ assert advance_curriculum(0, metrics) == 2
74
+
75
+
76
+ def test_advance_curriculum_max_tier_clamp():
77
+ """Cannot advance beyond MAX_TIER (4)."""
78
+ history = [True] * 10 # 100% success
79
+ metrics = EpisodeMetrics(success=True, episode_history=history)
80
+ assert advance_curriculum(MAX_TIER, metrics) == MAX_TIER
81
+ assert advance_curriculum(MAX_TIER - 1, metrics) == MAX_TIER # fast-track clamped
82
+
83
+
84
+ def test_advance_curriculum_rolling_window():
85
+ """Only the most recent 10 episodes count for rolling rate."""
86
+ # 20 episodes: first 10 are all False, last 10 are 9 True + 1 False
87
+ # Rolling window (last 10) = 9/10 = 90% → fast-track
88
+ history = [False] * 10 + [True] * 9 + [False]
89
+ metrics = EpisodeMetrics(success=False, episode_history=history)
90
+ assert advance_curriculum(0, metrics) == 2
91
+
92
+
93
+ def test_advance_curriculum_exactly_70_percent():
94
+ """Exactly 70% success → should graduate."""
95
+ history = [True] * 7 + [False] * 3
96
+ metrics = EpisodeMetrics(success=False, episode_history=history)
97
+ assert advance_curriculum(1, metrics) == 2
98
+
99
+
100
+ def test_advance_curriculum_exactly_90_percent():
101
+ """Exactly 90% success after ≥3 episodes → fast-track."""
102
+ history = [True] * 9 + [False]
103
+ metrics = EpisodeMetrics(success=False, episode_history=history)
104
+ assert advance_curriculum(0, metrics) == 2
105
+
106
+
107
+ # ── select_scenario tests ─────────────────────────────────────────────────────
108
+
109
+
110
+ def test_select_scenario_tier_mapping():
111
+ """Each tier maps to the correct ScenarioConfig."""
112
+ rng = np.random.default_rng(42)
113
+ assert select_scenario(0, rng) == WARMUP
114
+ assert select_scenario(1, rng) == SOLID_TUMOR_CHEMO
115
+ assert select_scenario(2, rng) == AUTOIMMUNE_BIOLOGIC
116
+ assert select_scenario(3, rng) == CNS_DEPRESSION
117
+ assert select_scenario(4, rng) == RARE_DISEASE_ORPHAN
118
+
119
+
120
+ def test_select_scenario_clamping():
121
+ """Out-of-range tiers are clamped to [MIN_TIER, MAX_TIER]."""
122
+ rng = np.random.default_rng(42)
123
+ # Below MIN_TIER → clamp to 0
124
+ assert select_scenario(-1, rng) == WARMUP
125
+ assert select_scenario(-100, rng) == WARMUP
126
+ # Above MAX_TIER → clamp to 4
127
+ assert select_scenario(5, rng) == RARE_DISEASE_ORPHAN
128
+ assert select_scenario(100, rng) == RARE_DISEASE_ORPHAN
129
+
130
+
131
+ def test_select_scenario_deterministic():
132
+ """Same tier + rng seed → same scenario (currently deterministic anyway)."""
133
+ rng1 = np.random.default_rng(42)
134
+ rng2 = np.random.default_rng(42)
135
+ assert select_scenario(2, rng1) == select_scenario(2, rng2)
136
+
137
+
138
+ # ── Integration test: full curriculum progression ─────────────────────────────
139
+
140
+
141
+ def test_full_curriculum_progression():
142
+ """Simulate a full curriculum progression from tier 0 → 4."""
143
+ tier = 0
144
+ history: list[bool] = []
145
+
146
+ # Tier 0 → 1 (normal graduation at 70%)
147
+ for _ in range(7):
148
+ history.append(True)
149
+ for _ in range(3):
150
+ history.append(False)
151
+ metrics = EpisodeMetrics(success=False, episode_history=history)
152
+ tier = advance_curriculum(tier, metrics)
153
+ assert tier == 1
154
+
155
+ # Tier 1 → 3 (fast-track at 90%)
156
+ history = [True] * 9 + [False]
157
+ metrics = EpisodeMetrics(success=False, episode_history=history)
158
+ tier = advance_curriculum(tier, metrics)
159
+ assert tier == 3
160
+
161
+ # Tier 3 → 4 (normal graduation)
162
+ history = [True] * 7 + [False] * 3
163
+ metrics = EpisodeMetrics(success=False, episode_history=history)
164
+ tier = advance_curriculum(tier, metrics)
165
+ assert tier == 4
166
+
167
+ # Tier 4 → 4 (max tier, cannot advance)
168
+ history = [True] * 10
169
+ metrics = EpisodeMetrics(success=True, episode_history=history)
170
+ tier = advance_curriculum(tier, metrics)
171
+ assert tier == 4
tests/test_episode_logger_wiring.py CHANGED
@@ -38,9 +38,7 @@ class TestLoggerCreatedOnReset:
38
  def test_logger_exists_after_reset(self, manager: EpisodeManager) -> None:
39
  assert manager._logger is not None
40
 
41
- def test_logger_replaced_on_second_reset(
42
- self, manager: EpisodeManager
43
- ) -> None:
44
  first_id = manager._logger.episode_id
45
  manager.reset()
46
  second_id = manager._logger.episode_id
@@ -53,9 +51,7 @@ class TestLoggerCreatedOnReset:
53
  class TestLogStepCalledOnStep:
54
  """Requirement 7.1: log_step() is called for every step."""
55
 
56
- def test_log_step_called_for_invalid_action(
57
- self, manager: EpisodeManager
58
- ) -> None:
59
  mock_logger = MagicMock()
60
  manager._logger = mock_logger
61
 
 
38
  def test_logger_exists_after_reset(self, manager: EpisodeManager) -> None:
39
  assert manager._logger is not None
40
 
41
+ def test_logger_replaced_on_second_reset(self, manager: EpisodeManager) -> None:
 
 
42
  first_id = manager._logger.episode_id
43
  manager.reset()
44
  second_id = manager._logger.episode_id
 
51
  class TestLogStepCalledOnStep:
52
  """Requirement 7.1: log_step() is called for every step."""
53
 
54
+ def test_log_step_called_for_invalid_action(self, manager: EpisodeManager) -> None:
 
 
55
  mock_logger = MagicMock()
56
  manager._logger = mock_logger
57
 
tests/test_episode_manager_compliance.py CHANGED
@@ -33,17 +33,13 @@ def manager() -> EpisodeManager:
33
  class TestInvalidActionReturnsNegativeRValidity:
34
  """Requirement 10.1: invalid actions → negative r_validity, latent unchanged."""
35
 
36
- def test_invalid_action_r_validity_negative(
37
- self, manager: EpisodeManager
38
- ) -> None:
39
  # SUBMIT_TO_FDA_REVIEW not permitted in literature_review phase
40
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
41
  _, reward, _, _ = manager.step(action)
42
  assert reward.r_validity < 0, "r_validity must be negative for invalid action"
43
 
44
- def test_invalid_action_state_unchanged(
45
- self, manager: EpisodeManager
46
- ) -> None:
47
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
48
  history_before = list(manager._latent.action_history)
49
  step_before = len(history_before)
@@ -62,9 +58,7 @@ class TestInvalidActionReturnsNegativeRValidity:
62
  assert len(obs.rule_violations) > 0
63
  assert len(info["violations"]) > 0
64
 
65
- def test_invalid_action_done_is_false(
66
- self, manager: EpisodeManager
67
- ) -> None:
68
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
69
  _, _, done, _ = manager.step(action)
70
  assert done is False
@@ -108,9 +102,7 @@ class TestNoUnhandledExceptions:
108
  with pytest.raises(RuntimeError, match="No active episode"):
109
  em.step(action)
110
 
111
- def test_multiple_invalid_steps_do_not_raise(
112
- self, manager: EpisodeManager
113
- ) -> None:
114
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
115
  for _ in range(5):
116
  _, reward, _, _ = manager.step(action)
 
33
  class TestInvalidActionReturnsNegativeRValidity:
34
  """Requirement 10.1: invalid actions → negative r_validity, latent unchanged."""
35
 
36
+ def test_invalid_action_r_validity_negative(self, manager: EpisodeManager) -> None:
 
 
37
  # SUBMIT_TO_FDA_REVIEW not permitted in literature_review phase
38
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
39
  _, reward, _, _ = manager.step(action)
40
  assert reward.r_validity < 0, "r_validity must be negative for invalid action"
41
 
42
+ def test_invalid_action_state_unchanged(self, manager: EpisodeManager) -> None:
 
 
43
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
44
  history_before = list(manager._latent.action_history)
45
  step_before = len(history_before)
 
58
  assert len(obs.rule_violations) > 0
59
  assert len(info["violations"]) > 0
60
 
61
+ def test_invalid_action_done_is_false(self, manager: EpisodeManager) -> None:
 
 
62
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
63
  _, _, done, _ = manager.step(action)
64
  assert done is False
 
102
  with pytest.raises(RuntimeError, match="No active episode"):
103
  em.step(action)
104
 
105
+ def test_multiple_invalid_steps_do_not_raise(self, manager: EpisodeManager) -> None:
 
 
106
  action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
107
  for _ in range(5):
108
  _, reward, _, _ = manager.step(action)
tests/test_judge.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for server/judge.py — TrialJudge multi-layer verification.
3
+
4
+ Covers:
5
+ - Layer 1 programmatic checks (power, p-value, FDA compliance, budget)
6
+ - Layer 2 persona selection (junior/senior/principal)
7
+ - Overconfidence penalty
8
+ - Hint generation for junior persona
9
+ - No unhandled exceptions on any valid input (req 10.4)
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import pytest
15
+
16
+ from models import ActionType, TrialAction, TrialLatentState, TrialState
17
+ from server.judge import JudgeResult, TrialJudge, _select_persona
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Fixtures
21
+ # ---------------------------------------------------------------------------
22
+
23
+
24
+ def _make_latent(**overrides) -> TrialLatentState:
25
+ defaults = dict(
26
+ true_effect_size=0.8,
27
+ true_side_effect_rate=0.05,
28
+ true_responder_population="all",
29
+ true_responder_criteria=[],
30
+ true_dose_response={},
31
+ true_mechanism="unknown",
32
+ placebo_response_rate=0.1,
33
+ dropout_rate=0.05,
34
+ site_variability=0.0,
35
+ measurement_noise=0.0,
36
+ budget_remaining=500_000.0,
37
+ time_remaining_days=300,
38
+ patients_enrolled=200,
39
+ phase_i_complete=True,
40
+ mtd_identified=True,
41
+ effect_estimated=True,
42
+ protocol_submitted=True,
43
+ interim_complete=True,
44
+ trial_complete=True,
45
+ adverse_events=0,
46
+ episode_phase="analysis",
47
+ action_history=["run_primary_analysis"],
48
+ seed=42,
49
+ )
50
+ defaults.update(overrides)
51
+ return TrialLatentState(**defaults)
52
+
53
+
54
+ def _make_state(difficulty: float = 0.3) -> TrialState:
55
+ return TrialState(
56
+ episode_id="test-ep",
57
+ step_count=5,
58
+ difficulty=difficulty,
59
+ scenario_id="solid_tumor_chemo",
60
+ curriculum_tier="0",
61
+ curriculum_stats={},
62
+ action_diversity=0.8,
63
+ phase_compliance_rate=1.0,
64
+ is_resolved=False,
65
+ )
66
+
67
+
68
+ def _make_action(
69
+ action_type: ActionType = ActionType.RUN_PRIMARY_ANALYSIS,
70
+ confidence: float = 0.5,
71
+ **params,
72
+ ) -> TrialAction:
73
+ return TrialAction(
74
+ action_type=action_type,
75
+ parameters=params,
76
+ justification="test",
77
+ confidence=confidence,
78
+ )
79
+
80
+
81
+ # ---------------------------------------------------------------------------
82
+ # Persona selection
83
+ # ---------------------------------------------------------------------------
84
+
85
+
86
+ def test_persona_junior():
87
+ assert _select_persona(0.0) == "junior"
88
+ assert _select_persona(0.39) == "junior"
89
+
90
+
91
+ def test_persona_senior():
92
+ assert _select_persona(0.4) == "senior"
93
+ assert _select_persona(0.7) == "senior"
94
+
95
+
96
+ def test_persona_principal():
97
+ assert _select_persona(0.71) == "principal"
98
+ assert _select_persona(1.0) == "principal"
99
+
100
+
101
+ # ---------------------------------------------------------------------------
102
+ # Layer 1: budget check
103
+ # ---------------------------------------------------------------------------
104
+
105
+
106
+ def test_budget_exhausted_fails():
107
+ judge = TrialJudge()
108
+ latent = _make_latent(budget_remaining=0.0)
109
+ result = judge.verify(_make_action(), _make_state(), latent)
110
+ assert not result.passed
111
+ assert any("budget" in v.lower() for v in result.violations)
112
+
113
+
114
+ def test_budget_negative_fails():
115
+ judge = TrialJudge()
116
+ latent = _make_latent(budget_remaining=-100.0)
117
+ result = judge.verify(_make_action(), _make_state(), latent)
118
+ assert not result.passed
119
+ assert any("budget" in v.lower() for v in result.violations)
120
+
121
+
122
+ def test_budget_positive_passes_budget_check():
123
+ judge = TrialJudge()
124
+ latent = _make_latent(budget_remaining=1.0)
125
+ # Other checks may still fail, but budget violation should not be present
126
+ result = judge.verify(_make_action(), _make_state(), latent)
127
+ assert not any("budget" in v.lower() for v in result.violations)
128
+
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # Layer 1: power check
132
+ # ---------------------------------------------------------------------------
133
+
134
+
135
+ def test_low_power_fails():
136
+ judge = TrialJudge()
137
+ # Very small effect + few patients → low power
138
+ latent = _make_latent(true_effect_size=0.01, patients_enrolled=10)
139
+ result = judge.verify(_make_action(), _make_state(), latent)
140
+ assert not result.passed
141
+ assert any("power" in v.lower() for v in result.violations)
142
+
143
+
144
+ def test_sufficient_power_no_power_violation():
145
+ judge = TrialJudge()
146
+ # Large effect + many patients → high power
147
+ latent = _make_latent(true_effect_size=1.5, patients_enrolled=500)
148
+ result = judge.verify(_make_action(), _make_state(), latent)
149
+ assert not any("power" in v.lower() for v in result.violations)
150
+
151
+
152
+ # ---------------------------------------------------------------------------
153
+ # Layer 1: p-value check
154
+ # ---------------------------------------------------------------------------
155
+
156
+
157
+ def test_nonsignificant_pvalue_fails():
158
+ judge = TrialJudge()
159
+ # Zero effect → p-value = 1.0
160
+ latent = _make_latent(true_effect_size=0.0, patients_enrolled=100)
161
+ result = judge.verify(_make_action(), _make_state(), latent)
162
+ assert not result.passed
163
+ assert any("p-value" in v.lower() for v in result.violations)
164
+
165
+
166
+ def test_significant_pvalue_no_pvalue_violation():
167
+ judge = TrialJudge()
168
+ # Large effect + many patients → very small p-value
169
+ latent = _make_latent(true_effect_size=2.0, patients_enrolled=1000)
170
+ result = judge.verify(_make_action(), _make_state(), latent)
171
+ assert not any("p-value" in v.lower() for v in result.violations)
172
+
173
+
174
+ # ---------------------------------------------------------------------------
175
+ # Layer 1: FDA compliance
176
+ # ---------------------------------------------------------------------------
177
+
178
+
179
+ def test_fda_violation_propagated():
180
+ judge = TrialJudge()
181
+ # Action not permitted in current phase
182
+ latent = _make_latent(episode_phase="literature_review")
183
+ action = _make_action(action_type=ActionType.SUBMIT_TO_FDA_REVIEW)
184
+ result = judge.verify(action, _make_state(), latent)
185
+ assert not result.passed
186
+ assert len(result.violations) > 0
187
+
188
+
189
+ # ---------------------------------------------------------------------------
190
+ # Overconfidence penalty
191
+ # ---------------------------------------------------------------------------
192
+
193
+
194
+ def test_overconfidence_penalty_applied_when_high_confidence_and_wrong():
195
+ judge = TrialJudge()
196
+ latent = _make_latent(budget_remaining=0.0) # guaranteed violation
197
+ action = _make_action(confidence=0.9)
198
+ result = judge.verify(action, _make_state(), latent)
199
+ assert not result.passed
200
+ assert result.overconfidence_penalty < 0.0
201
+
202
+
203
+ def test_no_overconfidence_penalty_when_low_confidence():
204
+ judge = TrialJudge()
205
+ latent = _make_latent(budget_remaining=0.0) # violation present
206
+ action = _make_action(confidence=0.5)
207
+ result = judge.verify(action, _make_state(), latent)
208
+ assert result.overconfidence_penalty == 0.0
209
+
210
+
211
+ def test_no_overconfidence_penalty_when_passed():
212
+ judge = TrialJudge()
213
+ # Use large effect + many patients to pass power/p-value, valid phase/action
214
+ latent = _make_latent(
215
+ true_effect_size=2.0,
216
+ patients_enrolled=1000,
217
+ budget_remaining=500_000.0,
218
+ episode_phase="analysis",
219
+ interim_complete=True,
220
+ trial_complete=True,
221
+ )
222
+ action = _make_action(action_type=ActionType.RUN_PRIMARY_ANALYSIS, confidence=0.95)
223
+ result = judge.verify(action, _make_state(), latent)
224
+ if result.passed:
225
+ assert result.overconfidence_penalty == 0.0
226
+
227
+
228
+ def test_overconfidence_penalty_scales_with_violation_count():
229
+ judge = TrialJudge()
230
+ # Multiple violations: budget + low power + non-significant p-value
231
+ latent = _make_latent(
232
+ budget_remaining=0.0,
233
+ true_effect_size=0.0,
234
+ patients_enrolled=1,
235
+ )
236
+ action = _make_action(confidence=0.9)
237
+ result = judge.verify(action, _make_state(), latent)
238
+ assert result.overconfidence_penalty <= -1.0 # at least 2 violations × -0.5
239
+
240
+
241
+ # ---------------------------------------------------------------------------
242
+ # Layer 2: persona in result
243
+ # ---------------------------------------------------------------------------
244
+
245
+
246
+ def test_junior_persona_in_result():
247
+ judge = TrialJudge()
248
+ latent = _make_latent(budget_remaining=0.0)
249
+ result = judge.verify(_make_action(), _make_state(difficulty=0.2), latent)
250
+ assert result.persona == "junior"
251
+
252
+
253
+ def test_senior_persona_in_result():
254
+ judge = TrialJudge()
255
+ latent = _make_latent(budget_remaining=0.0)
256
+ result = judge.verify(_make_action(), _make_state(difficulty=0.5), latent)
257
+ assert result.persona == "senior"
258
+
259
+
260
+ def test_principal_persona_in_result():
261
+ judge = TrialJudge()
262
+ latent = _make_latent(budget_remaining=0.0)
263
+ result = judge.verify(_make_action(), _make_state(difficulty=0.9), latent)
264
+ assert result.persona == "principal"
265
+
266
+
267
+ # ---------------------------------------------------------------------------
268
+ # Layer 2: hints
269
+ # ---------------------------------------------------------------------------
270
+
271
+
272
+ def test_junior_gets_hint_on_failure():
273
+ judge = TrialJudge()
274
+ latent = _make_latent(budget_remaining=0.0)
275
+ result = judge.verify(_make_action(), _make_state(difficulty=0.2), latent)
276
+ assert not result.passed
277
+ assert result.hint is not None and len(result.hint) > 0
278
+
279
+
280
+ def test_senior_no_hint_on_failure():
281
+ judge = TrialJudge()
282
+ latent = _make_latent(budget_remaining=0.0)
283
+ result = judge.verify(_make_action(), _make_state(difficulty=0.5), latent)
284
+ assert result.hint is None
285
+
286
+
287
+ def test_principal_no_hint_on_failure():
288
+ judge = TrialJudge()
289
+ latent = _make_latent(budget_remaining=0.0)
290
+ result = judge.verify(_make_action(), _make_state(difficulty=0.9), latent)
291
+ assert result.hint is None
292
+
293
+
294
+ def test_junior_gets_hint_on_pass():
295
+ judge = TrialJudge()
296
+ latent = _make_latent(
297
+ true_effect_size=2.0,
298
+ patients_enrolled=1000,
299
+ budget_remaining=500_000.0,
300
+ episode_phase="analysis",
301
+ interim_complete=True,
302
+ trial_complete=True,
303
+ )
304
+ action = _make_action(action_type=ActionType.RUN_PRIMARY_ANALYSIS)
305
+ result = judge.verify(action, _make_state(difficulty=0.2), latent)
306
+ if result.passed:
307
+ assert result.hint is not None
308
+
309
+
310
+ # ---------------------------------------------------------------------------
311
+ # JudgeResult model
312
+ # ---------------------------------------------------------------------------
313
+
314
+
315
+ def test_judge_result_is_pydantic_model():
316
+ result = JudgeResult(
317
+ passed=True,
318
+ violations=[],
319
+ feedback="ok",
320
+ hint=None,
321
+ overconfidence_penalty=0.0,
322
+ persona="senior",
323
+ )
324
+ assert result.passed is True
325
+ assert result.persona == "senior"
326
+
327
+
328
+ # ---------------------------------------------------------------------------
329
+ # Req 10.4: no unhandled exceptions
330
+ # ---------------------------------------------------------------------------
331
+
332
+
333
+ @pytest.mark.parametrize(
334
+ "action_type",
335
+ list(ActionType),
336
+ )
337
+ def test_no_exception_for_any_action_type(action_type):
338
+ """TrialJudge.verify must never raise for any valid action type (req 10.4)."""
339
+ judge = TrialJudge()
340
+ latent = _make_latent()
341
+ state = _make_state()
342
+ action = TrialAction(
343
+ action_type=action_type,
344
+ parameters={},
345
+ justification="test",
346
+ confidence=0.5,
347
+ )
348
+ # Must not raise
349
+ result = judge.verify(action, state, latent)
350
+ assert isinstance(result, JudgeResult)
tests/test_noise_model.py CHANGED
@@ -44,16 +44,12 @@ class TestNoiseModelIdempotence:
44
  r2 = NoiseModel(seed=42).randomize(base_scenario)
45
  assert r1.time_budget_days == r2.time_budget_days
46
 
47
- def test_same_seed_same_dropout_range(
48
- self, base_scenario: ScenarioConfig
49
- ) -> None:
50
  r1 = NoiseModel(seed=42).randomize(base_scenario)
51
  r2 = NoiseModel(seed=42).randomize(base_scenario)
52
  assert r1.dropout_rate_range == r2.dropout_rate_range
53
 
54
- def test_same_seed_same_placebo_range(
55
- self, base_scenario: ScenarioConfig
56
- ) -> None:
57
  r1 = NoiseModel(seed=42).randomize(base_scenario)
58
  r2 = NoiseModel(seed=42).randomize(base_scenario)
59
  assert r1.placebo_response_range == r2.placebo_response_range
@@ -114,9 +110,7 @@ class TestNoiseModelRanges:
114
  assert result.side_effect_rate_range == base_scenario.side_effect_rate_range
115
  assert result.min_sample_size == base_scenario.min_sample_size
116
 
117
- def test_time_budget_at_least_one_day(
118
- self, base_scenario: ScenarioConfig
119
- ) -> None:
120
  for seed in range(50):
121
  result = NoiseModel(seed=seed).randomize(base_scenario)
122
  assert result.time_budget_days >= 1
 
44
  r2 = NoiseModel(seed=42).randomize(base_scenario)
45
  assert r1.time_budget_days == r2.time_budget_days
46
 
47
+ def test_same_seed_same_dropout_range(self, base_scenario: ScenarioConfig) -> None:
 
 
48
  r1 = NoiseModel(seed=42).randomize(base_scenario)
49
  r2 = NoiseModel(seed=42).randomize(base_scenario)
50
  assert r1.dropout_rate_range == r2.dropout_rate_range
51
 
52
+ def test_same_seed_same_placebo_range(self, base_scenario: ScenarioConfig) -> None:
 
 
53
  r1 = NoiseModel(seed=42).randomize(base_scenario)
54
  r2 = NoiseModel(seed=42).randomize(base_scenario)
55
  assert r1.placebo_response_range == r2.placebo_response_range
 
110
  assert result.side_effect_rate_range == base_scenario.side_effect_rate_range
111
  assert result.min_sample_size == base_scenario.min_sample_size
112
 
113
+ def test_time_budget_at_least_one_day(self, base_scenario: ScenarioConfig) -> None:
 
 
114
  for seed in range(50):
115
  result = NoiseModel(seed=seed).randomize(base_scenario)
116
  assert result.time_budget_days >= 1
tests/test_output_generator.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for OutputGenerator — noisy TrialObservation generation (Task 15).
3
+
4
+ Requirements 9.1, 9.2, 9.3, 9.4:
5
+ - OutputGenerator produces a TrialObservation from a TrialLatentState
6
+ - Agent never sees raw hidden values (noise is always injected)
7
+ - phase_data, resource_status, available_actions are correctly populated
8
+ - Measurement noise and site variability are applied via NoiseModel
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import pytest
14
+
15
+ from models import ActionType, TrialLatentState, TrialState
16
+ from server.noise_model import NoiseModel
17
+ from server.simulator.output_generator import OutputGenerator
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Fixtures
21
+ # ---------------------------------------------------------------------------
22
+
23
+
24
+ @pytest.fixture()
25
+ def base_latent() -> TrialLatentState:
26
+ """A minimal TrialLatentState for testing."""
27
+ return TrialLatentState(
28
+ true_effect_size=0.5,
29
+ true_side_effect_rate=0.10,
30
+ true_responder_population="BRCA1+",
31
+ true_responder_criteria=["BRCA1+", "age < 65"],
32
+ true_dose_response={10.0: 0.2, 20.0: 0.4, 40.0: 0.7},
33
+ true_mechanism="PARP inhibition",
34
+ placebo_response_rate=0.15,
35
+ dropout_rate=0.08,
36
+ site_variability=0.05,
37
+ measurement_noise=0.05,
38
+ budget_remaining=500_000.0,
39
+ time_remaining_days=200,
40
+ patients_enrolled=0,
41
+ phase_i_complete=False,
42
+ mtd_identified=False,
43
+ effect_estimated=False,
44
+ protocol_submitted=False,
45
+ interim_complete=False,
46
+ trial_complete=False,
47
+ adverse_events=0,
48
+ episode_phase="design",
49
+ action_history=[],
50
+ seed=42,
51
+ )
52
+
53
+
54
+ @pytest.fixture()
55
+ def trial_state() -> TrialState:
56
+ return TrialState(
57
+ episode_id="ep-001",
58
+ step_count=1,
59
+ difficulty=0.5,
60
+ scenario_id="solid_tumor_chemo",
61
+ curriculum_tier="tier_0",
62
+ curriculum_stats={},
63
+ action_diversity=0.0,
64
+ phase_compliance_rate=1.0,
65
+ is_resolved=False,
66
+ )
67
+
68
+
69
+ @pytest.fixture()
70
+ def generator() -> OutputGenerator:
71
+ return OutputGenerator(noise_model=NoiseModel(seed=42))
72
+
73
+
74
+ def _make_obs(generator, latent, trial_state, **kwargs):
75
+ defaults = dict(
76
+ steps_taken=1,
77
+ max_steps=20,
78
+ rule_violations=[],
79
+ done=False,
80
+ reward=0.0,
81
+ scenario_description="Test scenario",
82
+ hint="",
83
+ )
84
+ defaults.update(kwargs)
85
+ return generator.generate(latent, trial_state, **defaults)
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Basic structure tests
90
+ # ---------------------------------------------------------------------------
91
+
92
+
93
+ class TestObservationStructure:
94
+ """TrialObservation has all required fields populated."""
95
+
96
+ def test_returns_trial_observation(self, generator, base_latent, trial_state):
97
+ from models import TrialObservation
98
+
99
+ obs = _make_obs(generator, base_latent, trial_state)
100
+ assert isinstance(obs, TrialObservation)
101
+
102
+ def test_scenario_description_passed_through(
103
+ self, generator, base_latent, trial_state
104
+ ):
105
+ obs = _make_obs(
106
+ generator, base_latent, trial_state, scenario_description="My scenario"
107
+ )
108
+ assert obs.scenario_description == "My scenario"
109
+
110
+ def test_steps_taken_and_max_steps(self, generator, base_latent, trial_state):
111
+ obs = _make_obs(
112
+ generator, base_latent, trial_state, steps_taken=5, max_steps=30
113
+ )
114
+ assert obs.steps_taken == 5
115
+ assert obs.max_steps == 30
116
+
117
+ def test_done_and_reward_passed_through(self, generator, base_latent, trial_state):
118
+ obs = _make_obs(generator, base_latent, trial_state, done=True, reward=1.5)
119
+ assert obs.done is True
120
+ assert obs.reward == 1.5
121
+
122
+ def test_rule_violations_passed_through(self, generator, base_latent, trial_state):
123
+ violations = ["violation A", "violation B"]
124
+ obs = _make_obs(generator, base_latent, trial_state, rule_violations=violations)
125
+ assert obs.rule_violations == violations
126
+
127
+ def test_hint_passed_through(self, generator, base_latent, trial_state):
128
+ obs = _make_obs(generator, base_latent, trial_state, hint="Try Phase I first")
129
+ assert obs.hint == "Try Phase I first"
130
+
131
+
132
+ # ---------------------------------------------------------------------------
133
+ # resource_status tests
134
+ # ---------------------------------------------------------------------------
135
+
136
+
137
+ class TestResourceStatus:
138
+ """resource_status reflects latent state resource fields."""
139
+
140
+ def test_budget_remaining(self, generator, base_latent, trial_state):
141
+ obs = _make_obs(generator, base_latent, trial_state)
142
+ assert obs.resource_status["budget_remaining"] == base_latent.budget_remaining
143
+
144
+ def test_time_remaining_days(self, generator, base_latent, trial_state):
145
+ obs = _make_obs(generator, base_latent, trial_state)
146
+ assert (
147
+ obs.resource_status["time_remaining_days"]
148
+ == base_latent.time_remaining_days
149
+ )
150
+
151
+ def test_patients_enrolled(self, generator, base_latent, trial_state):
152
+ latent = base_latent.model_copy(update={"patients_enrolled": 50})
153
+ obs = _make_obs(generator, latent, trial_state)
154
+ assert obs.resource_status["patients_enrolled"] == 50
155
+
156
+ def test_resource_status_has_three_keys(self, generator, base_latent, trial_state):
157
+ obs = _make_obs(generator, base_latent, trial_state)
158
+ assert set(obs.resource_status.keys()) == {
159
+ "budget_remaining",
160
+ "time_remaining_days",
161
+ "patients_enrolled",
162
+ }
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # phase_data tests — noise injection
167
+ # ---------------------------------------------------------------------------
168
+
169
+
170
+ class TestPhaseDataNoiseInjection:
171
+ """Agent never sees raw hidden values — noise is always injected."""
172
+
173
+ def test_true_effect_size_not_in_phase_data(
174
+ self, generator, base_latent, trial_state
175
+ ):
176
+ """Raw true_effect_size must never appear directly in phase_data."""
177
+ latent = base_latent.model_copy(update={"effect_estimated": True})
178
+ obs = _make_obs(generator, latent, trial_state)
179
+ # observed_effect_size should differ from true value (noise injected)
180
+ # We can't guarantee they differ by chance, but the key should be present
181
+ assert "observed_effect_size" in obs.phase_data
182
+
183
+ def test_effect_size_not_exposed_before_estimation(
184
+ self, generator, base_latent, trial_state
185
+ ):
186
+ """observed_effect_size should not appear before ESTIMATE_EFFECT_SIZE."""
187
+ obs = _make_obs(generator, base_latent, trial_state)
188
+ assert "observed_effect_size" not in obs.phase_data
189
+
190
+ def test_effect_size_exposed_after_estimation(
191
+ self, generator, base_latent, trial_state
192
+ ):
193
+ latent = base_latent.model_copy(update={"effect_estimated": True})
194
+ obs = _make_obs(generator, latent, trial_state)
195
+ assert "observed_effect_size" in obs.phase_data
196
+ assert "effect_size_ci" in obs.phase_data
197
+
198
+ def test_ae_rate_not_exposed_before_phase_i(
199
+ self, generator, base_latent, trial_state
200
+ ):
201
+ """Adverse event rate should not appear before Phase I or safety signal."""
202
+ obs = _make_obs(generator, base_latent, trial_state)
203
+ assert "observed_adverse_event_rate" not in obs.phase_data
204
+
205
+ def test_ae_rate_exposed_after_phase_i(self, generator, base_latent, trial_state):
206
+ latent = base_latent.model_copy(update={"phase_i_complete": True})
207
+ obs = _make_obs(generator, latent, trial_state)
208
+ assert "observed_adverse_event_rate" in obs.phase_data
209
+
210
+ def test_ae_rate_exposed_after_safety_signal(
211
+ self, generator, base_latent, trial_state
212
+ ):
213
+ latent = base_latent.model_copy(
214
+ update={"action_history": [ActionType.OBSERVE_SAFETY_SIGNAL.value]}
215
+ )
216
+ obs = _make_obs(generator, latent, trial_state)
217
+ assert "observed_adverse_event_rate" in obs.phase_data
218
+
219
+ def test_ae_rate_is_clipped_to_0_1(self, generator, base_latent, trial_state):
220
+ latent = base_latent.model_copy(
221
+ update={"phase_i_complete": True, "true_side_effect_rate": 0.99}
222
+ )
223
+ obs = _make_obs(generator, latent, trial_state)
224
+ rate = obs.phase_data["observed_adverse_event_rate"]
225
+ assert 0.0 <= rate <= 1.0
226
+
227
+ def test_placebo_response_not_exposed_before_interim(
228
+ self, generator, base_latent, trial_state
229
+ ):
230
+ obs = _make_obs(generator, base_latent, trial_state)
231
+ assert "observed_placebo_response" not in obs.phase_data
232
+
233
+ def test_placebo_response_exposed_after_interim(
234
+ self, generator, base_latent, trial_state
235
+ ):
236
+ latent = base_latent.model_copy(update={"interim_complete": True})
237
+ obs = _make_obs(generator, latent, trial_state)
238
+ assert "observed_placebo_response" in obs.phase_data
239
+
240
+ def test_dose_response_not_exposed_before_phase_i(
241
+ self, generator, base_latent, trial_state
242
+ ):
243
+ obs = _make_obs(generator, base_latent, trial_state)
244
+ assert "observed_dose_response" not in obs.phase_data
245
+
246
+ def test_dose_response_exposed_after_phase_i(
247
+ self, generator, base_latent, trial_state
248
+ ):
249
+ latent = base_latent.model_copy(update={"phase_i_complete": True})
250
+ obs = _make_obs(generator, latent, trial_state)
251
+ assert "observed_dose_response" in obs.phase_data
252
+ # All dose-response values should be clipped to [0, 1]
253
+ for v in obs.phase_data["observed_dose_response"].values():
254
+ assert 0.0 <= v <= 1.0
255
+
256
+ def test_dropout_rate_not_exposed_before_enrollment(
257
+ self, generator, base_latent, trial_state
258
+ ):
259
+ obs = _make_obs(generator, base_latent, trial_state)
260
+ assert "observed_dropout_rate" not in obs.phase_data
261
+
262
+ def test_dropout_rate_exposed_after_enrollment(
263
+ self, generator, base_latent, trial_state
264
+ ):
265
+ latent = base_latent.model_copy(update={"patients_enrolled": 10})
266
+ obs = _make_obs(generator, latent, trial_state)
267
+ assert "observed_dropout_rate" in obs.phase_data
268
+
269
+ def test_responder_population_hint_not_exposed_without_biomarker(
270
+ self, generator, base_latent, trial_state
271
+ ):
272
+ obs = _make_obs(generator, base_latent, trial_state)
273
+ assert "responder_population_hint" not in obs.phase_data
274
+
275
+ def test_responder_population_hint_exposed_after_biomarker(
276
+ self, generator, base_latent, trial_state
277
+ ):
278
+ latent = base_latent.model_copy(
279
+ update={"action_history": [ActionType.ADD_BIOMARKER_STRATIFICATION.value]}
280
+ )
281
+ obs = _make_obs(generator, latent, trial_state)
282
+ assert "responder_population_hint" in obs.phase_data
283
+ # Population label is revealed but NOT the true criteria
284
+ assert obs.phase_data["responder_population_hint"] == "BRCA1+"
285
+ assert "true_responder_criteria" not in obs.phase_data
286
+
287
+ def test_milestone_flags_in_phase_data(self, generator, base_latent, trial_state):
288
+ """Milestone flags are observable (not hidden values)."""
289
+ obs = _make_obs(generator, base_latent, trial_state)
290
+ assert "phase_i_complete" in obs.phase_data
291
+ assert "mtd_identified" in obs.phase_data
292
+ assert "effect_estimated" in obs.phase_data
293
+ assert "protocol_submitted" in obs.phase_data
294
+ assert "interim_complete" in obs.phase_data
295
+ assert "trial_complete" in obs.phase_data
296
+
297
+ def test_true_mechanism_not_in_phase_data(
298
+ self, generator, base_latent, trial_state
299
+ ):
300
+ """true_mechanism is a hidden value and must never appear in phase_data."""
301
+ obs = _make_obs(generator, base_latent, trial_state)
302
+ assert "true_mechanism" not in obs.phase_data
303
+
304
+ def test_true_responder_criteria_not_in_phase_data(
305
+ self, generator, base_latent, trial_state
306
+ ):
307
+ """true_responder_criteria is hidden and must never appear in phase_data."""
308
+ obs = _make_obs(generator, base_latent, trial_state)
309
+ assert "true_responder_criteria" not in obs.phase_data
310
+
311
+
312
+ # ---------------------------------------------------------------------------
313
+ # available_actions tests
314
+ # ---------------------------------------------------------------------------
315
+
316
+
317
+ class TestAvailableActions:
318
+ """available_actions reflects phase-permitted actions filtered by prerequisites."""
319
+
320
+ def test_available_actions_is_list_of_strings(
321
+ self, generator, base_latent, trial_state
322
+ ):
323
+ obs = _make_obs(generator, base_latent, trial_state)
324
+ assert isinstance(obs.available_actions, list)
325
+ assert all(isinstance(a, str) for a in obs.available_actions)
326
+
327
+ def test_design_phase_actions(self, generator, base_latent, trial_state):
328
+ """In design phase with empty history, basic design actions are available."""
329
+ obs = _make_obs(generator, base_latent, trial_state)
330
+ # SET_SAMPLE_SIZE, SET_INCLUSION_CRITERIA, SET_EXCLUSION_CRITERIA should be available
331
+ assert ActionType.SET_SAMPLE_SIZE.value in obs.available_actions
332
+ assert ActionType.SET_INCLUSION_CRITERIA.value in obs.available_actions
333
+ assert ActionType.SET_EXCLUSION_CRITERIA.value in obs.available_actions
334
+
335
+ def test_dosing_schedule_requires_primary_endpoint(
336
+ self, generator, base_latent, trial_state
337
+ ):
338
+ """SET_DOSING_SCHEDULE requires SET_PRIMARY_ENDPOINT in history."""
339
+ obs = _make_obs(generator, base_latent, trial_state)
340
+ # Without SET_PRIMARY_ENDPOINT in history, SET_DOSING_SCHEDULE should not be available
341
+ assert ActionType.SET_DOSING_SCHEDULE.value not in obs.available_actions
342
+
343
+ def test_dosing_schedule_available_after_primary_endpoint(
344
+ self, generator, base_latent, trial_state
345
+ ):
346
+ latent = base_latent.model_copy(
347
+ update={"action_history": [ActionType.SET_PRIMARY_ENDPOINT.value]}
348
+ )
349
+ obs = _make_obs(generator, latent, trial_state)
350
+ assert ActionType.SET_DOSING_SCHEDULE.value in obs.available_actions
351
+
352
+ def test_synthesize_conclusion_requires_trial_complete(
353
+ self, generator, base_latent, trial_state
354
+ ):
355
+ latent = base_latent.model_copy(
356
+ update={"episode_phase": "submission", "trial_complete": False}
357
+ )
358
+ obs = _make_obs(generator, latent, trial_state)
359
+ assert ActionType.SYNTHESIZE_CONCLUSION.value not in obs.available_actions
360
+
361
+ def test_synthesize_conclusion_available_when_trial_complete(
362
+ self, generator, base_latent, trial_state
363
+ ):
364
+ latent = base_latent.model_copy(
365
+ update={"episode_phase": "submission", "trial_complete": True}
366
+ )
367
+ obs = _make_obs(generator, latent, trial_state)
368
+ assert ActionType.SYNTHESIZE_CONCLUSION.value in obs.available_actions
369
+
370
+ def test_run_interim_analysis_requires_patients(
371
+ self, generator, base_latent, trial_state
372
+ ):
373
+ latent = base_latent.model_copy(
374
+ update={"episode_phase": "monitoring", "patients_enrolled": 0}
375
+ )
376
+ obs = _make_obs(generator, latent, trial_state)
377
+ assert ActionType.RUN_INTERIM_ANALYSIS.value not in obs.available_actions
378
+
379
+ def test_run_interim_analysis_available_with_patients(
380
+ self, generator, base_latent, trial_state
381
+ ):
382
+ latent = base_latent.model_copy(
383
+ update={"episode_phase": "monitoring", "patients_enrolled": 50}
384
+ )
385
+ obs = _make_obs(generator, latent, trial_state)
386
+ assert ActionType.RUN_INTERIM_ANALYSIS.value in obs.available_actions
387
+
388
+ def test_run_primary_analysis_requires_interim_complete(
389
+ self, generator, base_latent, trial_state
390
+ ):
391
+ latent = base_latent.model_copy(
392
+ update={"episode_phase": "analysis", "interim_complete": False}
393
+ )
394
+ obs = _make_obs(generator, latent, trial_state)
395
+ assert ActionType.RUN_PRIMARY_ANALYSIS.value not in obs.available_actions
396
+
397
+ def test_run_primary_analysis_available_after_interim(
398
+ self, generator, base_latent, trial_state
399
+ ):
400
+ latent = base_latent.model_copy(
401
+ update={"episode_phase": "analysis", "interim_complete": True}
402
+ )
403
+ obs = _make_obs(generator, latent, trial_state)
404
+ assert ActionType.RUN_PRIMARY_ANALYSIS.value in obs.available_actions
405
+
406
+ def test_unknown_phase_returns_empty_actions(
407
+ self, generator, base_latent, trial_state
408
+ ):
409
+ latent = base_latent.model_copy(update={"episode_phase": "unknown_phase"})
410
+ obs = _make_obs(generator, latent, trial_state)
411
+ assert obs.available_actions == []
412
+
413
+
414
+ # ---------------------------------------------------------------------------
415
+ # Determinism tests
416
+ # ---------------------------------------------------------------------------
417
+
418
+
419
+ class TestDeterminism:
420
+ """Same seed + same latent state → same observation (requirement 9.2)."""
421
+
422
+ def test_same_seed_same_observed_effect(self, base_latent, trial_state):
423
+ latent = base_latent.model_copy(update={"effect_estimated": True})
424
+ obs1 = OutputGenerator(NoiseModel(seed=99)).generate(
425
+ latent,
426
+ trial_state,
427
+ steps_taken=1,
428
+ max_steps=20,
429
+ rule_violations=[],
430
+ done=False,
431
+ reward=0.0,
432
+ scenario_description="S",
433
+ hint="",
434
+ )
435
+ obs2 = OutputGenerator(NoiseModel(seed=99)).generate(
436
+ latent,
437
+ trial_state,
438
+ steps_taken=1,
439
+ max_steps=20,
440
+ rule_violations=[],
441
+ done=False,
442
+ reward=0.0,
443
+ scenario_description="S",
444
+ hint="",
445
+ )
446
+ assert (
447
+ obs1.phase_data["observed_effect_size"]
448
+ == obs2.phase_data["observed_effect_size"]
449
+ )
450
+
451
+ def test_different_seeds_different_observed_effect(self, base_latent, trial_state):
452
+ latent = base_latent.model_copy(update={"effect_estimated": True})
453
+ obs1 = OutputGenerator(NoiseModel(seed=1)).generate(
454
+ latent,
455
+ trial_state,
456
+ steps_taken=1,
457
+ max_steps=20,
458
+ rule_violations=[],
459
+ done=False,
460
+ reward=0.0,
461
+ scenario_description="S",
462
+ hint="",
463
+ )
464
+ obs2 = OutputGenerator(NoiseModel(seed=2)).generate(
465
+ latent,
466
+ trial_state,
467
+ steps_taken=1,
468
+ max_steps=20,
469
+ rule_violations=[],
470
+ done=False,
471
+ reward=0.0,
472
+ scenario_description="S",
473
+ hint="",
474
+ )
475
+ # Different seeds should (almost certainly) produce different noisy values
476
+ assert (
477
+ obs1.phase_data["observed_effect_size"]
478
+ != obs2.phase_data["observed_effect_size"]
479
+ )
tests/test_phase_detector.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for server/phase_detector.py
3
+
4
+ Validates Requirements 8.5 and 9.4:
5
+ - detect_phase classifies actions into correct clinical workflow phases
6
+ - phase_order_correct is True for valid transitions, False for regressions/skips
7
+ - compute_phase_ordering_reward returns correct bonus/penalty values
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import pytest
13
+
14
+ from models import ActionType, TrialAction
15
+ from server.phase_detector import (
16
+ PHASE_BONUS,
17
+ PHASE_ORDER,
18
+ PHASE_SKIP_PENALTY,
19
+ compute_phase_ordering_reward,
20
+ detect_phase,
21
+ )
22
+
23
+
24
+ def _action(action_type: ActionType) -> TrialAction:
25
+ return TrialAction(
26
+ action_type=action_type,
27
+ parameters={},
28
+ justification="test",
29
+ confidence=0.5,
30
+ )
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Phase mapping tests
35
+ # ---------------------------------------------------------------------------
36
+
37
+
38
+ class TestPhaseMapping:
39
+ def test_hypothesis_actions(self):
40
+ for at in [
41
+ ActionType.ESTIMATE_EFFECT_SIZE,
42
+ ActionType.ADD_BIOMARKER_STRATIFICATION,
43
+ ]:
44
+ phase, _ = detect_phase(_action(at), [])
45
+ assert phase == "hypothesis", f"{at} should map to hypothesis"
46
+
47
+ def test_design_actions(self):
48
+ design_actions = [
49
+ ActionType.SET_PRIMARY_ENDPOINT,
50
+ ActionType.SET_SAMPLE_SIZE,
51
+ ActionType.SET_INCLUSION_CRITERIA,
52
+ ActionType.SET_EXCLUSION_CRITERIA,
53
+ ActionType.SET_DOSING_SCHEDULE,
54
+ ActionType.SET_CONTROL_ARM,
55
+ ActionType.SET_RANDOMIZATION_RATIO,
56
+ ActionType.SET_BLINDING,
57
+ ActionType.REQUEST_PROTOCOL_AMENDMENT,
58
+ ]
59
+ for at in design_actions:
60
+ phase, _ = detect_phase(_action(at), [])
61
+ assert phase == "design", f"{at} should map to design"
62
+
63
+ def test_enrollment_action(self):
64
+ phase, _ = detect_phase(_action(ActionType.ENROLL_PATIENTS), [])
65
+ assert phase == "enrollment"
66
+
67
+ def test_monitoring_actions(self):
68
+ monitoring_actions = [
69
+ ActionType.RUN_DOSE_ESCALATION,
70
+ ActionType.OBSERVE_SAFETY_SIGNAL,
71
+ ActionType.RUN_INTERIM_ANALYSIS,
72
+ ActionType.MODIFY_SAMPLE_SIZE,
73
+ ]
74
+ for at in monitoring_actions:
75
+ phase, _ = detect_phase(_action(at), [])
76
+ assert phase == "monitoring", f"{at} should map to monitoring"
77
+
78
+ def test_analysis_actions(self):
79
+ for at in [ActionType.RUN_PRIMARY_ANALYSIS, ActionType.SYNTHESIZE_CONCLUSION]:
80
+ phase, _ = detect_phase(_action(at), [])
81
+ assert phase == "analysis", f"{at} should map to analysis"
82
+
83
+ def test_submission_action(self):
84
+ phase, _ = detect_phase(_action(ActionType.SUBMIT_TO_FDA_REVIEW), [])
85
+ assert phase == "submission"
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Phase order correctness tests
90
+ # ---------------------------------------------------------------------------
91
+
92
+
93
+ class TestPhaseOrderCorrectness:
94
+ def test_empty_history_always_correct(self):
95
+ for at in ActionType:
96
+ _, correct = detect_phase(_action(at), [])
97
+ assert correct is True, f"Empty history should always be correct for {at}"
98
+
99
+ def test_same_phase_is_correct(self):
100
+ _, correct = detect_phase(_action(ActionType.SET_SAMPLE_SIZE), ["design"])
101
+ assert correct is True
102
+
103
+ def test_advance_one_phase_is_correct(self):
104
+ _, correct = detect_phase(_action(ActionType.ENROLL_PATIENTS), ["design"])
105
+ assert correct is True
106
+
107
+ def test_regression_is_incorrect(self):
108
+ # Going from enrollment back to design
109
+ _, correct = detect_phase(_action(ActionType.SET_SAMPLE_SIZE), ["enrollment"])
110
+ assert correct is False
111
+
112
+ def test_skip_one_phase_is_incorrect(self):
113
+ # Jumping from hypothesis to enrollment (skipping design)
114
+ _, correct = detect_phase(_action(ActionType.ENROLL_PATIENTS), ["hypothesis"])
115
+ assert correct is False
116
+
117
+ def test_skip_multiple_phases_is_incorrect(self):
118
+ # Jumping from design to analysis (skipping enrollment + monitoring)
119
+ _, correct = detect_phase(_action(ActionType.RUN_PRIMARY_ANALYSIS), ["design"])
120
+ assert correct is False
121
+
122
+ def test_valid_full_sequence(self):
123
+ """Walk through the full phase sequence and verify all transitions are correct."""
124
+ history: list[str] = []
125
+ sequence = [
126
+ ActionType.ESTIMATE_EFFECT_SIZE, # hypothesis
127
+ ActionType.SET_PRIMARY_ENDPOINT, # design
128
+ ActionType.ENROLL_PATIENTS, # enrollment
129
+ ActionType.RUN_DOSE_ESCALATION, # monitoring
130
+ ActionType.RUN_PRIMARY_ANALYSIS, # analysis
131
+ ActionType.SUBMIT_TO_FDA_REVIEW, # submission
132
+ ]
133
+ for at in sequence:
134
+ phase, correct = detect_phase(_action(at), history)
135
+ assert correct is True, (
136
+ f"Expected correct order for {at} with history {history}"
137
+ )
138
+ history.append(phase)
139
+
140
+
141
+ # ---------------------------------------------------------------------------
142
+ # PHASE_ORDER constant
143
+ # ---------------------------------------------------------------------------
144
+
145
+
146
+ class TestPhaseOrderConstant:
147
+ def test_phase_order_has_seven_phases(self):
148
+ assert len(PHASE_ORDER) == 7
149
+
150
+ def test_phase_order_sequence(self):
151
+ assert PHASE_ORDER == [
152
+ "literature_review",
153
+ "hypothesis",
154
+ "design",
155
+ "enrollment",
156
+ "monitoring",
157
+ "analysis",
158
+ "submission",
159
+ ]
160
+
161
+
162
+ # ---------------------------------------------------------------------------
163
+ # compute_phase_ordering_reward tests
164
+ # ---------------------------------------------------------------------------
165
+
166
+
167
+ class TestComputePhaseOrderingReward:
168
+ def test_empty_history_returns_bonus(self):
169
+ reward = compute_phase_ordering_reward(_action(ActionType.SET_SAMPLE_SIZE), [])
170
+ assert reward == PHASE_BONUS
171
+
172
+ def test_correct_advance_returns_bonus(self):
173
+ reward = compute_phase_ordering_reward(
174
+ _action(ActionType.ENROLL_PATIENTS), ["design"]
175
+ )
176
+ assert reward == PHASE_BONUS
177
+
178
+ def test_same_phase_returns_bonus(self):
179
+ reward = compute_phase_ordering_reward(
180
+ _action(ActionType.SET_SAMPLE_SIZE), ["design"]
181
+ )
182
+ assert reward == PHASE_BONUS
183
+
184
+ def test_regression_returns_zero(self):
185
+ reward = compute_phase_ordering_reward(
186
+ _action(ActionType.SET_SAMPLE_SIZE), ["enrollment"]
187
+ )
188
+ assert reward == 0.0
189
+
190
+ def test_skip_one_phase_returns_single_penalty(self):
191
+ # hypothesis → enrollment skips design (1 skip)
192
+ reward = compute_phase_ordering_reward(
193
+ _action(ActionType.ENROLL_PATIENTS), ["hypothesis"]
194
+ )
195
+ assert reward == pytest.approx(PHASE_SKIP_PENALTY * 1)
196
+
197
+ def test_skip_two_phases_returns_double_penalty(self):
198
+ # design → monitoring skips enrollment (1 skip)
199
+ # design → analysis skips enrollment + monitoring (2 skips)
200
+ reward = compute_phase_ordering_reward(
201
+ _action(ActionType.RUN_PRIMARY_ANALYSIS), ["design"]
202
+ )
203
+ assert reward == pytest.approx(PHASE_SKIP_PENALTY * 2)
204
+
205
+ def test_constants_values(self):
206
+ assert PHASE_BONUS == 0.2
207
+ assert PHASE_SKIP_PENALTY == -0.3