Spaces:
Sleeping
Sleeping
Coding Ninja committed on
Commit ·
1f2ca34
1
Parent(s): 22170b0
Push 3: Curriculum controller, hidden-state pipeline, phase detector, trial judge, and full EpisodeManager wiring
Browse files- models.py +1 -0
- pyproject.toml +3 -0
- server/curriculum/__init__.py +30 -2
- server/curriculum/controller.py +147 -0
- server/curriculum/scenarios.py +116 -0
- server/episode_manager.py +233 -109
- server/judge.py +277 -0
- server/logger.py +1 -3
- server/noise_model.py +5 -0
- server/phase_detector.py +125 -0
- server/reward/reward_computer.py +9 -20
- server/reward/shaping.py +1 -3
- server/rules/prerequisite_rules.py +1 -3
- server/simulator/__init__.py +1 -1
- server/simulator/output_generator.py +238 -0
- server/simulator/transition_engine.py +167 -0
- server/simulator/trial_simulator.py +3 -3
- tests/test_curriculum_controller.py +171 -0
- tests/test_episode_logger_wiring.py +2 -6
- tests/test_episode_manager_compliance.py +4 -12
- tests/test_judge.py +350 -0
- tests/test_noise_model.py +3 -9
- tests/test_output_generator.py +479 -0
- tests/test_phase_detector.py +207 -0
models.py
CHANGED
|
@@ -100,6 +100,7 @@ class TrialLatentState(BaseModel):
|
|
| 100 |
protocol_submitted: bool
|
| 101 |
interim_complete: bool
|
| 102 |
trial_complete: bool
|
|
|
|
| 103 |
# Episode tracking (used by rule engine and phase detector)
|
| 104 |
episode_phase: str
|
| 105 |
action_history: list[str]
|
|
|
|
| 100 |
protocol_submitted: bool
|
| 101 |
interim_complete: bool
|
| 102 |
trial_complete: bool
|
| 103 |
+
adverse_events: int # cumulative count of recorded adverse events
|
| 104 |
# Episode tracking (used by rule engine and phase detector)
|
| 105 |
episode_phase: str
|
| 106 |
action_history: list[str]
|
pyproject.toml
CHANGED
|
@@ -32,6 +32,9 @@ target-version = "py311"
|
|
| 32 |
select = ["E", "F", "W", "I"]
|
| 33 |
ignore = []
|
| 34 |
|
|
|
|
|
|
|
|
|
|
| 35 |
[tool.pytest.ini_options]
|
| 36 |
testpaths = ["tests"]
|
| 37 |
addopts = "-v"
|
|
|
|
| 32 |
select = ["E", "F", "W", "I"]
|
| 33 |
ignore = []
|
| 34 |
|
| 35 |
+
[tool.ruff.lint.per-file-ignores]
|
| 36 |
+
"tests/**" = ["E501"]
|
| 37 |
+
|
| 38 |
[tool.pytest.ini_options]
|
| 39 |
testpaths = ["tests"]
|
| 40 |
addopts = "-v"
|
server/curriculum/__init__.py
CHANGED
|
@@ -1,7 +1,35 @@
|
|
| 1 |
"""
|
| 2 |
curriculum — Curriculum controller and scenario registry.
|
| 3 |
|
| 4 |
-
Provides advance_curriculum, select_scenario, and the four
|
| 5 |
-
ScenarioConfig instances (solid_tumor_chemo, autoimmune_biologic,
|
| 6 |
cns_depression, rare_disease_orphan).
|
| 7 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
curriculum — Curriculum controller and scenario registry.
|
| 3 |
|
| 4 |
+
Provides advance_curriculum, select_scenario, EpisodeMetrics, and the four
|
| 5 |
+
initial ScenarioConfig instances (solid_tumor_chemo, autoimmune_biologic,
|
| 6 |
cns_depression, rare_disease_orphan).
|
| 7 |
"""
|
| 8 |
+
|
| 9 |
+
from server.curriculum.controller import (
|
| 10 |
+
EpisodeMetrics,
|
| 11 |
+
advance_curriculum,
|
| 12 |
+
select_scenario,
|
| 13 |
+
)
|
| 14 |
+
from server.curriculum.scenarios import (
|
| 15 |
+
AUTOIMMUNE_BIOLOGIC,
|
| 16 |
+
CNS_DEPRESSION,
|
| 17 |
+
RARE_DISEASE_ORPHAN,
|
| 18 |
+
SCENARIO_LIST,
|
| 19 |
+
SCENARIOS,
|
| 20 |
+
SOLID_TUMOR_CHEMO,
|
| 21 |
+
WARMUP,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"EpisodeMetrics",
|
| 26 |
+
"advance_curriculum",
|
| 27 |
+
"select_scenario",
|
| 28 |
+
"WARMUP",
|
| 29 |
+
"SOLID_TUMOR_CHEMO",
|
| 30 |
+
"AUTOIMMUNE_BIOLOGIC",
|
| 31 |
+
"CNS_DEPRESSION",
|
| 32 |
+
"RARE_DISEASE_ORPHAN",
|
| 33 |
+
"SCENARIOS",
|
| 34 |
+
"SCENARIO_LIST",
|
| 35 |
+
]
|
server/curriculum/controller.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Curriculum controller for the Clinical Trial Designer environment.
|
| 3 |
+
|
| 4 |
+
Exposes:
|
| 5 |
+
- advance_curriculum(tier, metrics) -> int
|
| 6 |
+
- select_scenario(tier, rng) -> ScenarioConfig
|
| 7 |
+
|
| 8 |
+
5-tier mastery logic:
|
| 9 |
+
Tier 0: warmup
|
| 10 |
+
Tier 1: beginner
|
| 11 |
+
Tier 2: intermediate
|
| 12 |
+
Tier 3: advanced
|
| 13 |
+
Tier 4: expert
|
| 14 |
+
|
| 15 |
+
Graduation rules:
|
| 16 |
+
- 70% rolling success rate over recent episodes → advance one tier
|
| 17 |
+
- 90% success rate after at least 3 episodes → fast-track (skip one tier)
|
| 18 |
+
- Max tier is 4 (expert); cannot advance beyond.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
from dataclasses import dataclass, field
|
| 24 |
+
from typing import Sequence
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
from models import ScenarioConfig
|
| 29 |
+
from server.curriculum.scenarios import (
|
| 30 |
+
AUTOIMMUNE_BIOLOGIC,
|
| 31 |
+
CNS_DEPRESSION,
|
| 32 |
+
RARE_DISEASE_ORPHAN,
|
| 33 |
+
SOLID_TUMOR_CHEMO,
|
| 34 |
+
WARMUP,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# ── Constants ────────────────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
# Inclusive bounds of the curriculum tier scale.
MIN_TIER: int = 0
MAX_TIER: int = 4

# How many of the most recent episodes feed the rolling success rate.
ROLLING_WINDOW: int = 10

# Rolling success rate at or above this value graduates the agent one tier.
MASTERY_THRESHOLD: float = 0.70

# Rolling success rate at or above this value, once the history holds at
# least FAST_TRACK_MIN_EPISODES entries, skips a tier (tier + 2).
FAST_TRACK_THRESHOLD: float = 0.90
FAST_TRACK_MIN_EPISODES: int = 3

# Canonical scenario served at each tier (exactly one per tier).
_TIER_SCENARIO: dict[int, ScenarioConfig] = {
    0: WARMUP,
    1: SOLID_TUMOR_CHEMO,
    2: AUTOIMMUNE_BIOLOGIC,
    3: CNS_DEPRESSION,
    4: RARE_DISEASE_ORPHAN,
}

# Human-readable label for each tier.
TIER_NAMES: dict[int, str] = {
    0: "warmup",
    1: "beginner",
    2: "intermediate",
    3: "advanced",
    4: "expert",
}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ── EpisodeMetrics ────────────────────────────────────────────────────────────
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass
class EpisodeMetrics:
    """Outcome of a finished episode together with the recent success record.

    Attributes:
        success: True when the episode ended in a successful trial outcome.
        episode_history: Success flags of recent episodes, oldest first and
            newest last. ``advance_curriculum`` reads only the trailing
            ``ROLLING_WINDOW`` entries, and expects the caller to have
            appended this episode's ``success`` flag *before* passing the
            object in.
    """

    success: bool
    # default_factory gives every instance its own list (no shared default).
    episode_history: list[bool] = field(default_factory=list)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ── Public API ────────────────────────────────────────────────────────────────
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def advance_curriculum(tier: int, metrics: EpisodeMetrics) -> int:
    """Return the curriculum tier to use after the just-completed episode.

    The decision is based on the rolling success rate, i.e. the fraction of
    successes among the last ``ROLLING_WINDOW`` entries of
    ``metrics.episode_history``.

    Args:
        tier: Current curriculum tier. Values outside
            ``[MIN_TIER, MAX_TIER]`` are clamped into that range before any
            graduation rule is applied, mirroring ``select_scenario``.
        metrics: Metrics for the just-completed episode.
            ``metrics.episode_history`` must already contain the current
            episode's success flag as its last element.

    Returns:
        The new tier, always within ``[MIN_TIER, MAX_TIER]``:

        * ``MAX_TIER`` if the (clamped) tier is already the maximum;
        * the unchanged tier if the history is empty or the rolling rate is
          below ``MASTERY_THRESHOLD``;
        * ``tier + 2`` (capped) if the history holds at least
          ``FAST_TRACK_MIN_EPISODES`` entries and the rolling rate reaches
          ``FAST_TRACK_THRESHOLD``;
        * ``tier + 1`` (capped) if the rolling rate reaches
          ``MASTERY_THRESHOLD``.
    """
    # Clamp up-front so an out-of-range input (e.g. a negative tier) can
    # never leak back to the caller as an invalid tier.
    tier = max(MIN_TIER, min(tier, MAX_TIER))
    if tier == MAX_TIER:
        return MAX_TIER

    history = metrics.episode_history
    n_episodes = len(history)
    if n_episodes == 0:
        # Nothing to evaluate yet.
        return tier

    # Only the most recent ROLLING_WINDOW episodes count towards the rate.
    window = history[-ROLLING_WINDOW:]
    rolling_rate = sum(window) / len(window)

    # Fast-track is checked first because it implies normal mastery.
    enough_history = n_episodes >= FAST_TRACK_MIN_EPISODES
    if enough_history and rolling_rate >= FAST_TRACK_THRESHOLD:
        return min(tier + 2, MAX_TIER)

    if rolling_rate >= MASTERY_THRESHOLD:
        return min(tier + 1, MAX_TIER)

    return tier
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def select_scenario(tier: int, rng: np.random.Generator) -> ScenarioConfig:
    """Return the canonical ScenarioConfig for a curriculum tier.

    The lookup goes through the module-level tier → scenario table, in which
    tier 0 maps to the ``WARMUP`` configuration and tiers 1–4 map to the four
    regular scenarios.

    Args:
        tier: Requested curriculum tier. Values below ``MIN_TIER`` or above
            ``MAX_TIER`` are clamped to the nearest valid tier.
        rng: Seeded ``numpy.random.Generator``. It is not consulted at
            present because every tier has exactly one scenario; it is kept
            in the signature for API consistency and so that a future pool
            of same-tier scenarios can be sampled without changing callers.

    Returns:
        The ``ScenarioConfig`` registered for the clamped tier.
    """
    # Cap from above first, then from below.
    capped = min(tier, MAX_TIER)
    effective_tier = max(MIN_TIER, capped)
    return _TIER_SCENARIO[effective_tier]
|
server/curriculum/scenarios.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Scenario registry for the curriculum controller.
|
| 3 |
+
|
| 4 |
+
Defines ScenarioConfig instances for all four scenario IDs plus a tier-0 warmup
|
| 5 |
+
variant of solid_tumor_chemo with an inflated effect size.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from models import ScenarioConfig
|
| 9 |
+
|
| 10 |
+
# Tier 0 — warmup (solid_tumor_chemo with inflated effect size, easier)
|
| 11 |
+
# Warmup configuration (curriculum_tier=0). Shares the oncology disease area
# with SOLID_TUMOR_CHEMO but uses its own scenario_id and a higher
# effect-size range.
WARMUP = ScenarioConfig(
    scenario_id="solid_tumor_chemo_warmup",
    curriculum_tier=0,
    disease_area="oncology",
    effect_size_range=(0.55, 0.85),  # inflated vs tier-1 (0.25–0.55)
    side_effect_rate_range=(0.10, 0.25),
    placebo_response_range=(0.05, 0.15),
    dropout_rate_range=(0.05, 0.10),
    budget_usd=8_000_000.0,
    time_budget_days=365,
    min_sample_size=60,
    description=(
        "Warmup scenario: EGFR+ solid-tumour chemotherapy with an inflated "
        "effect size to help the agent learn basic trial-design mechanics."
    ),
)
|
| 27 |
+
|
| 28 |
+
# Tier 1 — EGFR+ subgroup enrichment
|
| 29 |
+
# Tier-1 oncology configuration; the description frames the task as finding
# the EGFR+ biomarker subgroup.
SOLID_TUMOR_CHEMO = ScenarioConfig(
    scenario_id="solid_tumor_chemo",
    curriculum_tier=1,
    disease_area="oncology",
    effect_size_range=(0.25, 0.55),
    side_effect_rate_range=(0.15, 0.35),
    placebo_response_range=(0.05, 0.15),
    dropout_rate_range=(0.05, 0.15),
    budget_usd=10_000_000.0,
    time_budget_days=540,
    min_sample_size=80,
    description=(
        "EGFR+ solid-tumour chemotherapy. Agent must identify the EGFR+ "
        "biomarker subgroup to unlock the true effect size."
    ),
)
|
| 45 |
+
|
| 46 |
+
# Tier 2 — U-shaped dose-response
|
| 47 |
+
# Tier-2 immunology configuration; the description frames the task as
# dose-escalation on a U-shaped dose-response curve.
AUTOIMMUNE_BIOLOGIC = ScenarioConfig(
    scenario_id="autoimmune_biologic",
    curriculum_tier=2,
    disease_area="immunology",
    effect_size_range=(0.20, 0.45),
    side_effect_rate_range=(0.10, 0.30),
    placebo_response_range=(0.15, 0.30),
    dropout_rate_range=(0.08, 0.18),
    budget_usd=15_000_000.0,
    time_budget_days=720,
    min_sample_size=120,
    description=(
        "Autoimmune biologic with a U-shaped dose-response curve. "
        "Agent must run dose-escalation to find the optimal dose window."
    ),
)
|
| 63 |
+
|
| 64 |
+
# Tier 3 — high placebo response
|
| 65 |
+
# Tier-3 psychiatry configuration; its placebo-response range is the highest
# of the five scenarios defined in this module.
CNS_DEPRESSION = ScenarioConfig(
    scenario_id="cns_depression",
    curriculum_tier=3,
    disease_area="psychiatry",
    effect_size_range=(0.15, 0.35),
    side_effect_rate_range=(0.10, 0.25),
    placebo_response_range=(0.35, 0.55),  # high placebo response
    dropout_rate_range=(0.10, 0.25),
    budget_usd=20_000_000.0,
    time_budget_days=900,
    min_sample_size=200,
    description=(
        "CNS depression trial with a high placebo-response rate. "
        "Agent must power the study to detect a small drug-placebo delta."
    ),
)
|
| 81 |
+
|
| 82 |
+
# Tier 4 — rare disease / tiny n
|
| 83 |
+
# Tier-4 rare-disease configuration; it has the smallest budget and minimum
# sample size, and the longest time budget, of the five scenarios here.
RARE_DISEASE_ORPHAN = ScenarioConfig(
    scenario_id="rare_disease_orphan",
    curriculum_tier=4,
    disease_area="rare_disease",
    effect_size_range=(0.40, 0.80),  # larger effect needed to compensate tiny n
    side_effect_rate_range=(0.05, 0.20),
    placebo_response_range=(0.05, 0.15),
    dropout_rate_range=(0.05, 0.15),
    budget_usd=5_000_000.0,
    time_budget_days=1080,
    min_sample_size=10,  # tiny n — orphan disease
    description=(
        "Rare-disease orphan drug trial with a very small patient population. "
        "Agent must justify statistical validity under FDA orphan-drug rules."
    ),
)
|
| 99 |
+
|
| 100 |
+
# Registry — keyed by scenario_id for O(1) lookup
|
| 101 |
+
# Every scenario, in ascending curriculum-tier order (tier 0 first).
SCENARIO_LIST: list[ScenarioConfig] = [
    WARMUP,
    SOLID_TUMOR_CHEMO,
    AUTOIMMUNE_BIOLOGIC,
    CNS_DEPRESSION,
    RARE_DISEASE_ORPHAN,
]

# Lookup table keyed by scenario_id; entries follow SCENARIO_LIST order.
SCENARIOS: dict[str, ScenarioConfig] = {
    config.scenario_id: config for config in SCENARIO_LIST
}
|
server/episode_manager.py
CHANGED
|
@@ -10,38 +10,45 @@ from __future__ import annotations
|
|
| 10 |
|
| 11 |
import random
|
| 12 |
import uuid
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
from models import (
|
|
|
|
| 15 |
RewardBreakdown,
|
| 16 |
ScenarioConfig,
|
| 17 |
TrialAction,
|
| 18 |
TrialLatentState,
|
| 19 |
TrialObservation,
|
| 20 |
-
TrialResult,
|
| 21 |
TrialState,
|
| 22 |
)
|
|
|
|
|
|
|
| 23 |
from server.logger import EpisodeLogger
|
| 24 |
from server.noise_model import NoiseModel
|
|
|
|
|
|
|
| 25 |
from server.rules.fda_rules import check_fda_compliance
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
scenario_id="solid_tumor_chemo",
|
| 30 |
-
curriculum_tier=0,
|
| 31 |
-
disease_area="NSCLC",
|
| 32 |
-
effect_size_range=(0.3, 0.7),
|
| 33 |
-
side_effect_rate_range=(0.05, 0.20),
|
| 34 |
-
placebo_response_range=(0.10, 0.25),
|
| 35 |
-
dropout_rate_range=(0.05, 0.15),
|
| 36 |
-
budget_usd=1_000_000.0,
|
| 37 |
-
time_budget_days=365,
|
| 38 |
-
min_sample_size=100,
|
| 39 |
-
description="Solid tumor chemotherapy — find EGFR+ subgroup",
|
| 40 |
-
)
|
| 41 |
|
| 42 |
_MAX_STEPS = 100
|
| 43 |
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
class EpisodeManager:
|
| 46 |
"""Orchestrates the reset/step lifecycle for a single clinical trial episode.
|
| 47 |
|
|
@@ -58,22 +65,35 @@ class EpisodeManager:
|
|
| 58 |
self._episode_id: str = ""
|
| 59 |
self._difficulty: float = 0.0
|
| 60 |
self._scenario: ScenarioConfig | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
# ------------------------------------------------------------------
|
| 63 |
# Public API
|
| 64 |
# ------------------------------------------------------------------
|
| 65 |
|
| 66 |
def reset(self, seed: int | None = None) -> TrialObservation:
|
| 67 |
-
"""Initialize a new episode and return the initial TrialObservation.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
resolved_seed = seed if seed is not None else random.randint(0, 2**31 - 1)
|
| 69 |
self._episode_id = str(uuid.uuid4())
|
| 70 |
|
| 71 |
-
# Step 1: Select scenario
|
| 72 |
-
scenario
|
|
|
|
|
|
|
| 73 |
self._scenario = scenario
|
| 74 |
|
| 75 |
-
# Step 2: Apply domain randomization via NoiseModel (
|
|
|
|
| 76 |
noise_model = NoiseModel(seed=resolved_seed)
|
|
|
|
| 77 |
randomized = noise_model.randomize(scenario)
|
| 78 |
|
| 79 |
# Sample concrete hidden values from randomized ranges
|
|
@@ -109,6 +129,7 @@ class EpisodeManager:
|
|
| 109 |
protocol_submitted=False,
|
| 110 |
interim_complete=False,
|
| 111 |
trial_complete=False,
|
|
|
|
| 112 |
episode_phase="literature_review",
|
| 113 |
action_history=[],
|
| 114 |
seed=resolved_seed,
|
|
@@ -117,26 +138,55 @@ class EpisodeManager:
|
|
| 117 |
# Step 4: Build lightweight TrialState for training loop
|
| 118 |
self._state = self._state_from_latent(self._latent, randomized)
|
| 119 |
|
|
|
|
| 120 |
self._clear_cache()
|
|
|
|
| 121 |
|
| 122 |
-
# Step
|
| 123 |
self._logger = EpisodeLogger(
|
| 124 |
-
|
|
|
|
| 125 |
)
|
| 126 |
self._total_reward = 0.0
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
def step(
|
| 132 |
self, action: TrialAction
|
| 133 |
) -> tuple[TrialObservation, RewardBreakdown, bool, dict]:
|
| 134 |
-
"""Advance the episode by one step.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
if self._latent is None or self._scenario is None:
|
| 136 |
raise RuntimeError("No active episode. Call reset() before step().")
|
| 137 |
|
| 138 |
try:
|
| 139 |
-
# Check FDA compliance
|
| 140 |
compliance = check_fda_compliance(action, self._latent)
|
| 141 |
|
| 142 |
if not compliance.valid:
|
|
@@ -146,73 +196,116 @@ class EpisodeManager:
|
|
| 146 |
r_info_gain=0.0,
|
| 147 |
r_efficiency=0.0,
|
| 148 |
r_novelty=0.0,
|
| 149 |
-
r_penalty=0.
|
| 150 |
r_terminal_success=0.0,
|
| 151 |
r_terminal_calibration=0.0,
|
| 152 |
)
|
| 153 |
done = False
|
|
|
|
| 154 |
info: dict = {
|
| 155 |
-
"step_index":
|
| 156 |
"action_valid": False,
|
| 157 |
"violations": compliance.violations,
|
| 158 |
}
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
rule_violations=compliance.violations,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
)
|
|
|
|
| 164 |
if self._logger is not None:
|
| 165 |
-
self._logger.log_step(
|
| 166 |
-
len(self._latent.action_history), action, obs, reward, done
|
| 167 |
-
)
|
| 168 |
return obs, reward, done, info
|
| 169 |
|
| 170 |
-
#
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
"action_history": (
|
| 174 |
-
self._latent.action_history + [action.action_type.value]
|
| 175 |
-
),
|
| 176 |
-
}
|
| 177 |
)
|
|
|
|
| 178 |
|
| 179 |
-
#
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
success=False,
|
| 183 |
-
power=0.8,
|
| 184 |
-
adverse_event_rate=0.1,
|
| 185 |
-
confidence_interval=(0.0, 1.0),
|
| 186 |
-
failure_reason=None,
|
| 187 |
-
)
|
| 188 |
|
| 189 |
-
#
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
r_terminal_calibration=0.0,
|
| 199 |
)
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
step_idx = len(self._latent.action_history)
|
| 202 |
done = step_idx >= _MAX_STEPS or self._latent.trial_complete
|
| 203 |
-
info = {"step_index": step_idx, "action_valid": True}
|
| 204 |
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
-
#
|
| 211 |
-
self._total_reward +=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
if self._logger is not None:
|
| 213 |
self._logger.log_step(step_idx, action, obs, reward, done)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
-
# Log summary on episode end (
|
| 216 |
if done and self._logger is not None:
|
| 217 |
self._logger.log_summary(
|
| 218 |
scenario_id=self._scenario.scenario_id,
|
|
@@ -223,11 +316,22 @@ class EpisodeManager:
|
|
| 223 |
),
|
| 224 |
)
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
return obs, reward, done, info
|
| 227 |
|
| 228 |
except RuntimeError:
|
| 229 |
raise
|
| 230 |
-
except Exception as exc: #
|
| 231 |
reward = RewardBreakdown(
|
| 232 |
r_validity=-1.0,
|
| 233 |
r_ordering=0.0,
|
|
@@ -244,10 +348,50 @@ class EpisodeManager:
|
|
| 244 |
"action_valid": False,
|
| 245 |
"violations": [f"Internal error: {exc}"],
|
| 246 |
}
|
| 247 |
-
|
| 248 |
-
self._latent
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
)
|
| 252 |
return obs, reward, False, info
|
| 253 |
|
|
@@ -276,9 +420,20 @@ class EpisodeManager:
|
|
| 276 |
"""Build the lightweight TrialState from latent state."""
|
| 277 |
step_count = len(latent.action_history)
|
| 278 |
unique_actions = len(set(latent.action_history))
|
| 279 |
-
action_diversity =
|
| 280 |
-
|
| 281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
return TrialState(
|
| 283 |
episode_id=self._episode_id,
|
| 284 |
step_count=step_count,
|
|
@@ -287,37 +442,6 @@ class EpisodeManager:
|
|
| 287 |
curriculum_tier=str(scenario.curriculum_tier),
|
| 288 |
curriculum_stats={},
|
| 289 |
action_diversity=action_diversity,
|
| 290 |
-
phase_compliance_rate=
|
| 291 |
is_resolved=latent.trial_complete,
|
| 292 |
)
|
| 293 |
-
|
| 294 |
-
def _observation_from_latent(
|
| 295 |
-
self,
|
| 296 |
-
latent: TrialLatentState,
|
| 297 |
-
scenario: ScenarioConfig,
|
| 298 |
-
rule_violations: list[str] | None = None,
|
| 299 |
-
) -> TrialObservation:
|
| 300 |
-
"""Build a TrialObservation from latent state — noisy, agent-facing."""
|
| 301 |
-
return TrialObservation(
|
| 302 |
-
scenario_description=scenario.description,
|
| 303 |
-
phase_data={
|
| 304 |
-
"episode_phase": latent.episode_phase,
|
| 305 |
-
"observed_effect_estimate": None,
|
| 306 |
-
"observed_side_effect_rate": None,
|
| 307 |
-
"phase_i_complete": latent.phase_i_complete,
|
| 308 |
-
"interim_complete": latent.interim_complete,
|
| 309 |
-
"protocol_submitted": latent.protocol_submitted,
|
| 310 |
-
},
|
| 311 |
-
resource_status={
|
| 312 |
-
"budget_remaining": latent.budget_remaining,
|
| 313 |
-
"time_remaining_days": latent.time_remaining_days,
|
| 314 |
-
"patients_enrolled": latent.patients_enrolled,
|
| 315 |
-
},
|
| 316 |
-
rule_violations=rule_violations or [],
|
| 317 |
-
available_actions=[], # wired in Push 3 with TransitionEngine
|
| 318 |
-
steps_taken=len(latent.action_history),
|
| 319 |
-
max_steps=_MAX_STEPS,
|
| 320 |
-
hint="", # populated by TrialJudge at junior difficulty (Push 3)
|
| 321 |
-
done=latent.trial_complete,
|
| 322 |
-
reward=0.0, # filled in by step() after reward computation
|
| 323 |
-
)
|
|
|
|
| 10 |
|
| 11 |
import random
|
| 12 |
import uuid
|
| 13 |
+
from datetime import datetime, timezone
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
|
| 17 |
from models import (
|
| 18 |
+
EpisodeTranscript,
|
| 19 |
RewardBreakdown,
|
| 20 |
ScenarioConfig,
|
| 21 |
TrialAction,
|
| 22 |
TrialLatentState,
|
| 23 |
TrialObservation,
|
|
|
|
| 24 |
TrialState,
|
| 25 |
)
|
| 26 |
+
from server.curriculum.controller import select_scenario
|
| 27 |
+
from server.judge import TrialJudge
|
| 28 |
from server.logger import EpisodeLogger
|
| 29 |
from server.noise_model import NoiseModel
|
| 30 |
+
from server.phase_detector import detect_phase
|
| 31 |
+
from server.reward.reward_computer import compute_reward
|
| 32 |
from server.rules.fda_rules import check_fda_compliance
|
| 33 |
+
from server.simulator.output_generator import OutputGenerator
|
| 34 |
+
from server.simulator.transition_engine import TransitionEngine
|
| 35 |
+
from server.simulator.trial_simulator import simulate_trial
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
_MAX_STEPS = 100
|
| 38 |
|
| 39 |
|
| 40 |
+
def _phase_order_correct_at(phase: str, prior_history: list[str]) -> bool:
    """Tell whether ``phase`` may follow the last entry of ``prior_history``.

    A phase is accepted when its position in ``PHASE_ORDER`` is the same as,
    or exactly one step after, the position of the most recent prior phase.
    With an empty history every phase is accepted. Names missing from
    ``PHASE_ORDER`` are treated as position 0.

    Args:
        phase: Phase name being evaluated.
        prior_history: Phase names seen so far, most recent last.

    Returns:
        True if the move from the last prior phase to ``phase`` stays in
        place or advances by one position; False otherwise.
    """
    from server.phase_detector import PHASE_ORDER

    if not prior_history:
        return True

    def _position(name: str) -> int:
        # Unknown names fall back to the first position.
        return PHASE_ORDER.index(name) if name in PHASE_ORDER else 0

    previous_pos = _position(prior_history[-1])
    candidate_pos = _position(phase)
    # Allowed moves: stay in the same phase (0) or advance one phase (1).
    return 0 <= candidate_pos - previous_pos <= 1
|
| 50 |
+
|
| 51 |
+
|
| 52 |
class EpisodeManager:
|
| 53 |
"""Orchestrates the reset/step lifecycle for a single clinical trial episode.
|
| 54 |
|
|
|
|
| 65 |
self._episode_id: str = ""
|
| 66 |
self._difficulty: float = 0.0
|
| 67 |
self._scenario: ScenarioConfig | None = None
|
| 68 |
+
self._phase_history: list[str] = []
|
| 69 |
+
self._noise_model: NoiseModel | None = None
|
| 70 |
+
self._curriculum_tier: int = 0
|
| 71 |
+
self._transition_engine: TransitionEngine = TransitionEngine()
|
| 72 |
+
self._judge: TrialJudge = TrialJudge()
|
| 73 |
|
| 74 |
# ------------------------------------------------------------------
|
| 75 |
# Public API
|
| 76 |
# ------------------------------------------------------------------
|
| 77 |
|
| 78 |
def reset(self, seed: int | None = None) -> TrialObservation:
|
| 79 |
+
"""Initialize a new episode and return the initial TrialObservation.
|
| 80 |
+
|
| 81 |
+
Seeded resets are reproducible: same seed → same scenario selection
|
| 82 |
+
and initial TrialLatentState (Req 8.5, 9.4).
|
| 83 |
+
"""
|
| 84 |
resolved_seed = seed if seed is not None else random.randint(0, 2**31 - 1)
|
| 85 |
self._episode_id = str(uuid.uuid4())
|
| 86 |
|
| 87 |
+
# Step 1: Select scenario via CurriculumController (Req 8.3, 8.5)
|
| 88 |
+
# Use a seeded RNG so scenario selection is reproducible for same seed.
|
| 89 |
+
scenario_rng = np.random.default_rng(resolved_seed)
|
| 90 |
+
scenario = select_scenario(self._curriculum_tier, scenario_rng)
|
| 91 |
self._scenario = scenario
|
| 92 |
|
| 93 |
+
# Step 2: Apply domain randomization via NoiseModel (Req 9.1, 9.2)
|
| 94 |
+
# NoiseModel is seeded so same seed → same randomized config.
|
| 95 |
noise_model = NoiseModel(seed=resolved_seed)
|
| 96 |
+
self._noise_model = noise_model
|
| 97 |
randomized = noise_model.randomize(scenario)
|
| 98 |
|
| 99 |
# Sample concrete hidden values from randomized ranges
|
|
|
|
| 129 |
protocol_submitted=False,
|
| 130 |
interim_complete=False,
|
| 131 |
trial_complete=False,
|
| 132 |
+
adverse_events=0,
|
| 133 |
episode_phase="literature_review",
|
| 134 |
action_history=[],
|
| 135 |
seed=resolved_seed,
|
|
|
|
| 138 |
# Step 4: Build lightweight TrialState for training loop
|
| 139 |
self._state = self._state_from_latent(self._latent, randomized)
|
| 140 |
|
| 141 |
+
# Step 5: Clear power cache (Req 14.3)
|
| 142 |
self._clear_cache()
|
| 143 |
+
self._phase_history = []
|
| 144 |
|
| 145 |
+
# Step 6: Fresh logger (episode_id matches this episode), reward accumulator
|
| 146 |
self._logger = EpisodeLogger(
|
| 147 |
+
episode_id=self._episode_id,
|
| 148 |
+
curriculum_tier=randomized.curriculum_tier,
|
| 149 |
)
|
| 150 |
self._total_reward = 0.0
|
| 151 |
+
# Difficulty scales linearly with curriculum tier: tier 0 → 0.0, tier 4 → 1.0
|
| 152 |
+
self._difficulty = scenario.curriculum_tier / 4.0
|
| 153 |
+
|
| 154 |
+
# Step 7: Return initial TrialObservation via OutputGenerator
|
| 155 |
+
output_gen = OutputGenerator(noise_model)
|
| 156 |
+
return output_gen.generate(
|
| 157 |
+
latent=self._latent,
|
| 158 |
+
trial_state=self._state,
|
| 159 |
+
steps_taken=0,
|
| 160 |
+
max_steps=_MAX_STEPS,
|
| 161 |
+
rule_violations=[],
|
| 162 |
+
done=False,
|
| 163 |
+
reward=0.0,
|
| 164 |
+
scenario_description=scenario.description,
|
| 165 |
+
hint="",
|
| 166 |
+
)
|
| 167 |
|
| 168 |
def step(
|
| 169 |
self, action: TrialAction
|
| 170 |
) -> tuple[TrialObservation, RewardBreakdown, bool, dict]:
|
| 171 |
+
"""Advance the episode by one step.
|
| 172 |
+
|
| 173 |
+
Full pipeline (Req 8.5, 9.4, 7.1):
|
| 174 |
+
1. Validate active episode
|
| 175 |
+
2. check_fda_compliance → ComplianceResult
|
| 176 |
+
3. TransitionEngine.apply_transition() mutates TrialLatentState
|
| 177 |
+
4. OutputGenerator.generate() produces noisy TrialObservation
|
| 178 |
+
5. compute_reward() → RewardBreakdown
|
| 179 |
+
6. PhaseDetector.detect_phase() classifies action
|
| 180 |
+
7. TrialJudge.verify() for hint/feedback
|
| 181 |
+
8. Check terminal condition
|
| 182 |
+
9. Log full EpisodeTranscript to JSONL
|
| 183 |
+
10. Return (obs, reward_breakdown, done, info)
|
| 184 |
+
"""
|
| 185 |
if self._latent is None or self._scenario is None:
|
| 186 |
raise RuntimeError("No active episode. Call reset() before step().")
|
| 187 |
|
| 188 |
try:
|
| 189 |
+
# Step 1: Check FDA compliance (read-only, does not mutate state)
|
| 190 |
compliance = check_fda_compliance(action, self._latent)
|
| 191 |
|
| 192 |
if not compliance.valid:
|
|
|
|
| 196 |
r_info_gain=0.0,
|
| 197 |
r_efficiency=0.0,
|
| 198 |
r_novelty=0.0,
|
| 199 |
+
r_penalty=-0.5 * len(compliance.violations),
|
| 200 |
r_terminal_success=0.0,
|
| 201 |
r_terminal_calibration=0.0,
|
| 202 |
)
|
| 203 |
done = False
|
| 204 |
+
step_idx = len(self._latent.action_history)
|
| 205 |
info: dict = {
|
| 206 |
+
"step_index": step_idx,
|
| 207 |
"action_valid": False,
|
| 208 |
"violations": compliance.violations,
|
| 209 |
}
|
| 210 |
+
# Build observation without mutating latent
|
| 211 |
+
noise_model = self._noise_model or NoiseModel(seed=self._latent.seed)
|
| 212 |
+
output_gen = OutputGenerator(noise_model)
|
| 213 |
+
obs = output_gen.generate(
|
| 214 |
+
latent=self._latent,
|
| 215 |
+
trial_state=self._state
|
| 216 |
+
or self._state_from_latent(self._latent, self._scenario),
|
| 217 |
+
steps_taken=step_idx,
|
| 218 |
+
max_steps=_MAX_STEPS,
|
| 219 |
rule_violations=compliance.violations,
|
| 220 |
+
done=False,
|
| 221 |
+
reward=reward.total,
|
| 222 |
+
scenario_description=self._scenario.description,
|
| 223 |
+
hint="",
|
| 224 |
)
|
| 225 |
+
# Log invalid step
|
| 226 |
if self._logger is not None:
|
| 227 |
+
self._logger.log_step(step_idx, action, obs, reward, done)
|
|
|
|
|
|
|
| 228 |
return obs, reward, done, info
|
| 229 |
|
| 230 |
+
# Step 2: TransitionEngine mutates TrialLatentState
|
| 231 |
+
updated_latent = self._transition_engine.apply_transition(
|
| 232 |
+
self._latent, action
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
)
|
| 234 |
+
self._latent = updated_latent
|
| 235 |
|
| 236 |
+
# Step 3: Detect phase and update phase history
|
| 237 |
+
phase_name, phase_order_correct = detect_phase(action, self._phase_history)
|
| 238 |
+
self._phase_history = self._phase_history + [phase_name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
+
# Step 4: Simulate trial result for reward computation
|
| 241 |
+
result = simulate_trial(self._latent, action)
|
| 242 |
+
|
| 243 |
+
# Step 5: Compute reward (all 8 components)
|
| 244 |
+
reward = compute_reward(
|
| 245 |
+
action=action,
|
| 246 |
+
latent=self._latent,
|
| 247 |
+
result=result,
|
| 248 |
+
phase_history=self._phase_history[:-1], # history before this step
|
|
|
|
| 249 |
)
|
| 250 |
|
| 251 |
+
# Step 6: TrialJudge verification (hint + overconfidence penalty)
|
| 252 |
+
self._state = self._state_from_latent(self._latent, self._scenario)
|
| 253 |
+
judge_result = self._judge.verify(action, self._state, self._latent)
|
| 254 |
+
hint = judge_result.hint or ""
|
| 255 |
+
|
| 256 |
+
# Apply overconfidence penalty to r_penalty
|
| 257 |
+
if judge_result.overconfidence_penalty != 0.0:
|
| 258 |
+
reward = reward.model_copy(
|
| 259 |
+
update={
|
| 260 |
+
"r_penalty": (
|
| 261 |
+
reward.r_penalty + judge_result.overconfidence_penalty
|
| 262 |
+
)
|
| 263 |
+
}
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Step 7: Check terminal condition
|
| 267 |
step_idx = len(self._latent.action_history)
|
| 268 |
done = step_idx >= _MAX_STEPS or self._latent.trial_complete
|
|
|
|
| 269 |
|
| 270 |
+
# Step 8: Generate noisy observation via OutputGenerator
|
| 271 |
+
noise_model = self._noise_model or NoiseModel(seed=self._latent.seed)
|
| 272 |
+
output_gen = OutputGenerator(noise_model)
|
| 273 |
+
obs = output_gen.generate(
|
| 274 |
+
latent=self._latent,
|
| 275 |
+
trial_state=self._state,
|
| 276 |
+
steps_taken=step_idx,
|
| 277 |
+
max_steps=_MAX_STEPS,
|
| 278 |
+
rule_violations=[],
|
| 279 |
+
done=done,
|
| 280 |
+
reward=reward.total,
|
| 281 |
+
scenario_description=self._scenario.description,
|
| 282 |
+
hint=hint,
|
| 283 |
+
)
|
| 284 |
|
| 285 |
+
# Step 9: Accumulate total reward
|
| 286 |
+
self._total_reward += reward.total
|
| 287 |
+
|
| 288 |
+
# Step 10: Log full EpisodeTranscript record to JSONL (Req 7.1)
|
| 289 |
+
transcript = EpisodeTranscript(
|
| 290 |
+
episode_id=self._episode_id,
|
| 291 |
+
step=step_idx,
|
| 292 |
+
action=action,
|
| 293 |
+
observation=obs,
|
| 294 |
+
reward_breakdown=reward.model_dump(),
|
| 295 |
+
total_reward=reward.total,
|
| 296 |
+
phase_detected=phase_name,
|
| 297 |
+
phase_order_correct=phase_order_correct,
|
| 298 |
+
hidden_state_snapshot=self._latent,
|
| 299 |
+
timestamp=datetime.now(timezone.utc).isoformat(),
|
| 300 |
+
)
|
| 301 |
if self._logger is not None:
|
| 302 |
self._logger.log_step(step_idx, action, obs, reward, done)
|
| 303 |
+
# Also write the full EpisodeTranscript as a separate JSONL record
|
| 304 |
+
self._logger._append_jsonl(
|
| 305 |
+
{"type": "transcript", **transcript.model_dump(mode="json")}
|
| 306 |
+
)
|
| 307 |
|
| 308 |
+
# Log summary on episode end (Req 7.2)
|
| 309 |
if done and self._logger is not None:
|
| 310 |
self._logger.log_summary(
|
| 311 |
scenario_id=self._scenario.scenario_id,
|
|
|
|
| 316 |
),
|
| 317 |
)
|
| 318 |
|
| 319 |
+
info = {
|
| 320 |
+
"step_index": step_idx,
|
| 321 |
+
"action_valid": True,
|
| 322 |
+
"phase_detected": phase_name,
|
| 323 |
+
"phase_order_correct": phase_order_correct,
|
| 324 |
+
"judge_passed": judge_result.passed,
|
| 325 |
+
"judge_feedback": judge_result.feedback,
|
| 326 |
+
"judge_hint": hint,
|
| 327 |
+
"overconfidence_penalty": judge_result.overconfidence_penalty,
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
return obs, reward, done, info
|
| 331 |
|
| 332 |
except RuntimeError:
|
| 333 |
raise
|
| 334 |
+
except Exception as exc: # Req 10.4: no unhandled exceptions
|
| 335 |
reward = RewardBreakdown(
|
| 336 |
r_validity=-1.0,
|
| 337 |
r_ordering=0.0,
|
|
|
|
| 348 |
"action_valid": False,
|
| 349 |
"violations": [f"Internal error: {exc}"],
|
| 350 |
}
|
| 351 |
+
noise_model = self._noise_model or NoiseModel(
|
| 352 |
+
seed=self._latent.seed if self._latent else 0
|
| 353 |
+
)
|
| 354 |
+
output_gen = OutputGenerator(noise_model)
|
| 355 |
+
obs = (
|
| 356 |
+
output_gen.generate(
|
| 357 |
+
latent=self._latent,
|
| 358 |
+
trial_state=self._state
|
| 359 |
+
or TrialState(
|
| 360 |
+
episode_id=self._episode_id,
|
| 361 |
+
step_count=step_idx,
|
| 362 |
+
difficulty=self._difficulty,
|
| 363 |
+
scenario_id=self._scenario.scenario_id
|
| 364 |
+
if self._scenario
|
| 365 |
+
else "",
|
| 366 |
+
curriculum_tier="0",
|
| 367 |
+
curriculum_stats={},
|
| 368 |
+
action_diversity=0.0,
|
| 369 |
+
phase_compliance_rate=0.0,
|
| 370 |
+
is_resolved=False,
|
| 371 |
+
),
|
| 372 |
+
steps_taken=step_idx,
|
| 373 |
+
max_steps=_MAX_STEPS,
|
| 374 |
+
rule_violations=[f"Internal error: {exc}"],
|
| 375 |
+
done=False,
|
| 376 |
+
reward=reward.total,
|
| 377 |
+
scenario_description=(
|
| 378 |
+
self._scenario.description if self._scenario else ""
|
| 379 |
+
),
|
| 380 |
+
hint="",
|
| 381 |
+
)
|
| 382 |
+
if self._latent is not None
|
| 383 |
+
else TrialObservation(
|
| 384 |
+
scenario_description="",
|
| 385 |
+
phase_data={},
|
| 386 |
+
resource_status={},
|
| 387 |
+
rule_violations=[f"Internal error: {exc}"],
|
| 388 |
+
available_actions=[],
|
| 389 |
+
steps_taken=step_idx,
|
| 390 |
+
max_steps=_MAX_STEPS,
|
| 391 |
+
hint="",
|
| 392 |
+
done=False,
|
| 393 |
+
reward=0.0,
|
| 394 |
+
)
|
| 395 |
)
|
| 396 |
return obs, reward, False, info
|
| 397 |
|
|
|
|
| 420 |
"""Build the lightweight TrialState from latent state."""
|
| 421 |
step_count = len(latent.action_history)
|
| 422 |
unique_actions = len(set(latent.action_history))
|
| 423 |
+
action_diversity = unique_actions / step_count if step_count > 0 else 0.0
|
| 424 |
+
|
| 425 |
+
# Compute phase compliance rate from phase history
|
| 426 |
+
phase_steps = len(self._phase_history)
|
| 427 |
+
if phase_steps > 0:
|
| 428 |
+
correct_count = sum(
|
| 429 |
+
1
|
| 430 |
+
for i, ph in enumerate(self._phase_history)
|
| 431 |
+
if _phase_order_correct_at(ph, self._phase_history[:i])
|
| 432 |
+
)
|
| 433 |
+
phase_compliance_rate = correct_count / phase_steps
|
| 434 |
+
else:
|
| 435 |
+
phase_compliance_rate = 0.0
|
| 436 |
+
|
| 437 |
return TrialState(
|
| 438 |
episode_id=self._episode_id,
|
| 439 |
step_count=step_count,
|
|
|
|
| 442 |
curriculum_tier=str(scenario.curriculum_tier),
|
| 443 |
curriculum_stats={},
|
| 444 |
action_diversity=action_diversity,
|
| 445 |
+
phase_compliance_rate=phase_compliance_rate,
|
| 446 |
is_resolved=latent.trial_complete,
|
| 447 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server/judge.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Trial Judge — multi-layer verification for clinical trial design decisions.
|
| 3 |
+
|
| 4 |
+
Layer 1 (programmatic, authoritative, never overridden):
|
| 5 |
+
- power >= 0.80
|
| 6 |
+
- p_value < 0.05
|
| 7 |
+
- FDA compliance passes
|
| 8 |
+
- budget_remaining > 0
|
| 9 |
+
|
| 10 |
+
Layer 2 (persona-scaled LLM stub):
|
| 11 |
+
- junior (difficulty < 0.4): gives hints, lenient feedback
|
| 12 |
+
- senior (0.4–0.7): balanced feedback
|
| 13 |
+
- principal (> 0.7): strict, no hints
|
| 14 |
+
|
| 15 |
+
Overconfidence penalty: -0.5 per Layer 1 violation on a high-confidence wrong claim
|
| 16 |
+
(action.confidence >= 0.8 and the claim is incorrect per Layer 1).
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from pydantic import BaseModel
|
| 22 |
+
|
| 23 |
+
from models import TrialAction, TrialLatentState, TrialState
|
| 24 |
+
from server.rules.fda_rules import check_fda_compliance
|
| 25 |
+
from server.simulator.power_calculator import calculate_power
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Result model
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class JudgeResult(BaseModel):
    """Output of TrialJudge.verify().

    Bundles the authoritative Layer 1 verdict with the persona-scaled
    Layer 2 text generated for the same action.
    """

    # True only when Layer 1 recorded no violations.
    passed: bool
    # Layer 1 violation messages, in the order the checks ran.
    violations: list[str]
    # Persona-specific feedback text.
    feedback: str
    # Optional guidance string; None when the persona gives no hint.
    hint: str | None
    # 0.0, or a negative value when a high-confidence action failed Layer 1.
    overconfidence_penalty: float
    # Persona name chosen from the episode difficulty
    # ("junior", "senior" or "principal").
    persona: str
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# Persona thresholds
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
# Difficulty strictly below this value selects the "junior" persona.
_JUNIOR_MAX = 0.4
# Difficulty up to and including this value selects "senior"; anything
# above it selects "principal".
_SENIOR_MAX = 0.7
# action.confidence at or above this value counts as a high-confidence claim.
_HIGH_CONFIDENCE_THRESHOLD = 0.8
# Penalty applied once per Layer 1 violation on a high-confidence action.
_OVERCONFIDENCE_PENALTY = -0.5
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _select_persona(difficulty: float) -> str:
    """Map an episode difficulty onto a judge persona name.

    Returns "junior" below ``_JUNIOR_MAX``, "senior" up to and including
    ``_SENIOR_MAX``, and "principal" for everything else.
    """
    # Tiers are listed in ascending order; the first matching one wins.
    tiers = (
        (difficulty < _JUNIOR_MAX, "junior"),
        (difficulty <= _SENIOR_MAX, "senior"),
    )
    for matches, persona in tiers:
        if matches:
            return persona
    return "principal"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# Layer 2: rule-based LLM stub
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _generate_feedback(
    persona: str,
    violations: list[str],
    passed: bool,
    action: TrialAction,
    latent: TrialLatentState,
) -> tuple[str, str | None]:
    """Build the persona-specific (feedback, hint) pair for one action.

    Rule-based stand-in for a future LLM call: every string is assembled
    from fixed templates. Any persona other than "junior" or "senior" gets
    the "principal" wording. Only the junior persona ever receives a hint.

    Args:
        persona: Persona name selected for the episode.
        violations: Layer 1 violation messages (may be empty).
        passed: Whether Layer 1 accepted the action.
        action: The evaluated action; its action type names the subject.
        latent: Latent state, forwarded to the hint builder on failure.

    Returns:
        A ``(feedback, hint)`` tuple; ``hint`` is ``None`` when no hint
        is offered.
    """
    label = action.action_type.value.replace("_", " ")

    if passed:
        if persona == "junior":
            return (
                f"Good work on '{label}'! Your trial design looks solid. "
                "Power and significance thresholds are met. Keep it up!",
                "Tip: continue building on this foundation — "
                "consider biomarker stratification next to improve precision.",
            )
        if persona == "senior":
            return (
                f"'{label}' passes all programmatic checks. "
                "Statistical power and p-value criteria are satisfied. "
                "Proceed to the next design step.",
                None,
            )
        # principal
        return (
            f"'{label}' meets minimum criteria. "
            "Ensure alpha-spending and interim analysis boundaries "
            "are pre-specified before submission.",
            None,
        )

    # Failure path: summarise every violation in a single sentence fragment.
    summary = "; ".join(violations) if violations else "unknown issue"

    if persona == "junior":
        message = (
            f"'{label}' did not pass verification. "
            f"Issues found: {summary}. "
            "Review the requirements and try again."
        )
        return message, _build_hint_for_violations(violations, latent)
    if persona == "senior":
        return (
            f"'{label}' failed verification. "
            f"Violations: {summary}. "
            "Address these before proceeding.",
            None,
        )
    # principal
    return (
        f"'{label}' is non-compliant. "
        f"Violations: {summary}. "
        "No further guidance will be provided — resolve independently.",
        None,
    )
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _build_hint_for_violations(
    violations: list[str], latent: TrialLatentState
) -> str | None:
    """Turn the first violation message into a hint for the junior persona.

    Only ``violations[0]`` is inspected. Its lower-cased text is matched
    against keyword groups in a fixed priority order and the first group
    that matches decides the hint. If nothing matches, the violation text
    itself is echoed back as the hint.

    Args:
        violations: Layer 1 violation messages.
        latent: Latent state; read only for the budget and phase hints.

    Returns:
        The hint string, or ``None`` when there are no violations.
    """
    if not violations:
        return None

    headline = violations[0]
    needle = headline.lower()

    # (keywords, renderer) pairs in priority order. Renderers are callables
    # so that ``latent`` is touched only when its rule actually fires.
    rules = (
        (
            ("power",),
            lambda: (
                "Hint: current power is below 0.80. "
                "Try increasing the sample size — "
                "more patients enrolled improves statistical power."
            ),
        ),
        (
            ("p-value", "p_value", "significance"),
            lambda: (
                "Hint: the p-value threshold of 0.05 is not met. "
                "Consider a larger effect size or more patients."
            ),
        ),
        (
            ("budget",),
            lambda: (
                f"Hint: budget is exhausted (remaining: {latent.budget_remaining:.2f}). "
                "Look for cost-saving measures or request a protocol amendment."
            ),
        ),
        (
            ("fda", "compliance", "permitted"),
            lambda: (
                "Hint: this action is not allowed in the current phase "
                f"('{latent.episode_phase}'). "
                "Check the transition table for permitted actions."
            ),
        ),
        (
            ("sample size",),
            lambda: "Hint: the minimum regulatory sample size is 30 participants.",
        ),
        (
            ("protocol",),
            lambda: "Hint: submit the protocol before attempting FDA review.",
        ),
        (
            ("phase i",),
            lambda: "Hint: complete Phase I before submitting to FDA review.",
        ),
        (
            ("interim",),
            lambda: "Hint: run an interim analysis before the primary analysis.",
        ),
        (
            ("patients", "enrolled"),
            lambda: "Hint: enroll patients before running analyses.",
        ),
    )

    for keywords, render in rules:
        if any(word in needle for word in keywords):
            return render()

    # Generic fallback
    return f"Hint: {headline}"
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# ---------------------------------------------------------------------------
|
| 178 |
+
# Main judge class
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class TrialJudge:
    """Multi-layer trial design verifier.

    Layer 1 is programmatic and authoritative — its result is never overridden.
    Layer 2 is persona-scaled and provides human-readable feedback and hints.
    """

    @staticmethod
    def _two_sided_p_value(effect_size: float, n: int) -> float:
        """Return the two-sided normal-approximation p-value.

        The total sample ``n`` is split evenly across two arms, the standard
        error is ``1 / sqrt(n / 2)`` and the z statistic is
        ``effect_size / se``.

        Args:
            effect_size: Effect size the statistic is derived from.
            n: Total number of enrolled patients.

        Returns:
            The p-value, or 1.0 when ``n`` is not positive or the effect
            size is exactly zero. A NaN effect size yields NaN.
        """
        # Function-scope imports, as in the original implementation.
        import math

        from scipy.stats import norm

        if n <= 0 or effect_size == 0.0:
            return 1.0
        se = 1.0 / math.sqrt(n / 2.0)
        z_stat = effect_size / se
        return float(2.0 * norm.sf(abs(z_stat)))

    def verify(
        self,
        action: TrialAction,
        state: TrialState,
        latent: TrialLatentState,
    ) -> JudgeResult:
        """Verify the action against both programmatic and persona layers.

        Args:
            action: The agent's action to evaluate.
            state: Lightweight training-loop metadata (carries difficulty).
            latent: Hidden ground-truth + episode tracking state.

        Returns:
            JudgeResult with pass/fail, violations, feedback, hint, and penalty.
        """
        violations: list[str] = []

        # ------------------------------------------------------------------
        # Layer 1: Programmatic checks (authoritative, never overridden)
        # ------------------------------------------------------------------

        # 1a. Budget check
        if latent.budget_remaining <= 0:
            violations.append(
                f"Budget exhausted: budget_remaining={latent.budget_remaining:.2f} "
                f"(must be > 0)."
            )

        # 1b. Statistical power check.
        # n is clamped to at least 1, so the statistics below are always
        # evaluated even before any patient has been enrolled.
        n = max(latent.patients_enrolled, 1)
        power = calculate_power(latent.true_effect_size, n)
        # Written as "not (x >= threshold)" so that a NaN power is reported
        # as a violation instead of slipping through a "<" comparison.
        if not (power >= 0.80):
            violations.append(
                f"Insufficient statistical power: {power:.3f} < 0.80 "
                f"(effect_size={latent.true_effect_size:.3f}, n={n})."
            )

        # 1c. p-value check, derived from the effect size and n.
        # NOTE(review): the original comment says this mirrors the
        # simulator's normal approximation; confirm against the simulator.
        p_value = self._two_sided_p_value(latent.true_effect_size, n)
        # Same NaN-safe form as the power check.
        if not (p_value < 0.05):
            violations.append(
                f"p-value not significant: {p_value:.4f} >= 0.05 "
                f"(n={n}, effect_size={latent.true_effect_size:.3f})."
            )

        # 1d. FDA compliance check
        compliance = check_fda_compliance(action, latent)
        if not compliance.valid:
            violations.extend(compliance.violations)

        passed = len(violations) == 0

        # ------------------------------------------------------------------
        # Overconfidence penalty
        # ------------------------------------------------------------------
        # A "high-confidence wrong claim" is when the agent's confidence is
        # >= 0.8 but Layer 1 found violations (the claim is incorrect).
        overconfidence_penalty = 0.0
        if not passed and action.confidence >= _HIGH_CONFIDENCE_THRESHOLD:
            # One penalty per Layer 1 violation.
            overconfidence_penalty = _OVERCONFIDENCE_PENALTY * len(violations)

        # ------------------------------------------------------------------
        # Layer 2: Persona-scaled feedback (never overrides Layer 1 result)
        # ------------------------------------------------------------------
        persona = _select_persona(state.difficulty)
        feedback, hint = _generate_feedback(persona, violations, passed, action, latent)

        return JudgeResult(
            passed=passed,
            violations=violations,
            feedback=feedback,
            hint=hint,
            overconfidence_penalty=overconfidence_penalty,
            persona=persona,
        )
|
server/logger.py
CHANGED
|
@@ -31,9 +31,7 @@ class EpisodeLogger:
|
|
| 31 |
episode_id: str | None = None,
|
| 32 |
curriculum_tier: int = 0,
|
| 33 |
) -> None:
|
| 34 |
-
self._log_path: Path =
|
| 35 |
-
log_path if log_path is not None else settings.log_path
|
| 36 |
-
)
|
| 37 |
self._episode_id: str = (
|
| 38 |
episode_id if episode_id is not None else str(uuid.uuid4())
|
| 39 |
)
|
|
|
|
| 31 |
episode_id: str | None = None,
|
| 32 |
curriculum_tier: int = 0,
|
| 33 |
) -> None:
|
| 34 |
+
self._log_path: Path = log_path if log_path is not None else settings.log_path
|
|
|
|
|
|
|
| 35 |
self._episode_id: str = (
|
| 36 |
episode_id if episode_id is not None else str(uuid.uuid4())
|
| 37 |
)
|
server/noise_model.py
CHANGED
|
@@ -36,6 +36,11 @@ class NoiseModel:
|
|
| 36 |
self._seed = seed
|
| 37 |
self._rng: np.random.Generator = np.random.default_rng(seed)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def randomize(self, config: ScenarioConfig) -> ScenarioConfig:
|
| 40 |
"""Return a new ScenarioConfig with domain-randomized parameters.
|
| 41 |
|
|
|
|
| 36 |
self._seed = seed
|
| 37 |
self._rng: np.random.Generator = np.random.default_rng(seed)
|
| 38 |
|
| 39 |
+
@property
def rng(self) -> np.random.Generator:
    """Public access to the seeded Generator.

    Returns the Generator object held on the model itself (not a copy).
    """
    return self._rng
|
| 43 |
+
|
| 44 |
def randomize(self, config: ScenarioConfig) -> ScenarioConfig:
|
| 45 |
"""Return a new ScenarioConfig with domain-randomized parameters.
|
| 46 |
|
server/phase_detector.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phase Detector — classifies TrialActions into clinical workflow phases.
|
| 3 |
+
|
| 4 |
+
Clinical workflow phase order:
|
| 5 |
+
literature_review → hypothesis → design → enrollment →
|
| 6 |
+
monitoring → analysis → submission
|
| 7 |
+
|
| 8 |
+
Phase-order bonus: +0.2 for correct order (no regression, no skips)
|
| 9 |
+
Skip penalty: -0.3 per skipped phase
|
| 10 |
+
|
| 11 |
+
Requirements: 8.5, 9.4
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from models import ActionType, TrialAction
|
| 17 |
+
|
| 18 |
+
# Ordered list of clinical workflow phases. A phase's position in this list
# is its rank: detect_phase() and compute_phase_ordering_reward() compare
# ranks to spot regressions and skipped phases.
PHASE_ORDER: list[str] = [
    "literature_review",
    "hypothesis",
    "design",
    "enrollment",
    "monitoring",
    "analysis",
    "submission",
]

# Reward constants
# Returned when the phase transition is in order.
PHASE_BONUS: float = 0.2
# Multiplied by the number of skipped phases.
PHASE_SKIP_PENALTY: float = -0.3

# Mapping from ActionType to phase name.
# No ActionType maps to "literature_review"; it is the fallback phase used
# for any action type that is missing from this table.
_ACTION_TO_PHASE: dict[ActionType, str] = {
    # hypothesis
    ActionType.ESTIMATE_EFFECT_SIZE: "hypothesis",
    ActionType.ADD_BIOMARKER_STRATIFICATION: "hypothesis",
    # design
    ActionType.SET_PRIMARY_ENDPOINT: "design",
    ActionType.SET_SAMPLE_SIZE: "design",
    ActionType.SET_INCLUSION_CRITERIA: "design",
    ActionType.SET_EXCLUSION_CRITERIA: "design",
    ActionType.SET_DOSING_SCHEDULE: "design",
    ActionType.SET_CONTROL_ARM: "design",
    ActionType.SET_RANDOMIZATION_RATIO: "design",
    ActionType.SET_BLINDING: "design",
    ActionType.REQUEST_PROTOCOL_AMENDMENT: "design",
    # enrollment
    ActionType.ENROLL_PATIENTS: "enrollment",
    # monitoring
    ActionType.RUN_DOSE_ESCALATION: "monitoring",
    ActionType.OBSERVE_SAFETY_SIGNAL: "monitoring",
    ActionType.RUN_INTERIM_ANALYSIS: "monitoring",
    ActionType.MODIFY_SAMPLE_SIZE: "monitoring",
    # analysis
    ActionType.RUN_PRIMARY_ANALYSIS: "analysis",
    ActionType.SYNTHESIZE_CONCLUSION: "analysis",
    # submission
    ActionType.SUBMIT_TO_FDA_REVIEW: "submission",
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def detect_phase(action: TrialAction, history: list[str]) -> tuple[str, bool]:
    """Classify a TrialAction into a clinical workflow phase.

    Args:
        action: The agent's action for this step.
        history: Phase names recorded for the earlier steps of the episode.

    Returns:
        ``(phase_name, phase_order_correct)``. The flag is True when this is
        the first action of the episode, or when the detected phase is the
        same as, or exactly one position after, the most recent phase in
        ``history``. Moving backwards or jumping over a phase yields False.
    """
    phase_name = _ACTION_TO_PHASE.get(action.action_type, "literature_review")

    if not history:
        # First action — any phase is valid
        return phase_name, True

    def rank(phase: str) -> int:
        # Phases missing from PHASE_ORDER are ranked like the first phase.
        try:
            return PHASE_ORDER.index(phase)
        except ValueError:
            return 0

    # Negative: regression. Greater than one: at least one phase skipped.
    advance = rank(phase_name) - rank(history[-1])
    return phase_name, advance in (0, 1)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def compute_phase_ordering_reward(action: TrialAction, history: list[str]) -> float:
    """Compute the r_ordering reward component using phase detection.

    Args:
        action: The agent's action for this step.
        history: Phase names recorded for the earlier steps of the episode.

    Returns:
        ``PHASE_BONUS`` for the first action of an episode and for staying
        in the same phase or advancing by exactly one;
        ``PHASE_SKIP_PENALTY`` times the number of skipped phases when the
        action jumps ahead; ``0.0`` when the action moves backwards.
    """
    current_phase = _ACTION_TO_PHASE.get(action.action_type, "literature_review")

    if not history:
        return PHASE_BONUS

    # Rank of each phase; names outside PHASE_ORDER fall back to rank 0.
    positions = {name: idx for idx, name in enumerate(PHASE_ORDER)}
    gap = positions.get(current_phase, 0) - positions.get(history[-1], 0)

    if gap < 0:
        # Regression — no bonus, no skip penalty
        return 0.0
    if gap > 1:
        # gap - 1 phases lie strictly between the previous and current phase.
        return PHASE_SKIP_PENALTY * (gap - 1)
    return PHASE_BONUS
|
server/reward/reward_computer.py
CHANGED
|
@@ -18,6 +18,7 @@ from models import (
|
|
| 18 |
TrialLatentState,
|
| 19 |
TrialResult,
|
| 20 |
)
|
|
|
|
| 21 |
from server.rules.fda_rules import check_fda_compliance
|
| 22 |
|
| 23 |
# Reward magnitude constants
|
|
@@ -29,13 +30,13 @@ _TERMINAL_CALIBRATION = 5.0
|
|
| 29 |
_INFO_GAIN_BASE = 0.5
|
| 30 |
_EFFICIENCY_SCALE = 2.0
|
| 31 |
_NOVELTY_BASE = 0.2
|
| 32 |
-
_ORDERING_BONUS = 0.2
|
| 33 |
|
| 34 |
|
| 35 |
def compute_reward(
|
| 36 |
action: TrialAction,
|
| 37 |
latent: TrialLatentState,
|
| 38 |
result: TrialResult,
|
|
|
|
| 39 |
) -> RewardBreakdown:
|
| 40 |
"""Compute all eight reward components for a single step.
|
| 41 |
|
|
@@ -46,6 +47,7 @@ def compute_reward(
|
|
| 46 |
action: The agent's action.
|
| 47 |
latent: Hidden ground-truth + episode tracking state.
|
| 48 |
result: The simulated trial result.
|
|
|
|
| 49 |
|
| 50 |
Returns:
|
| 51 |
A RewardBreakdown with all eight keys populated.
|
|
@@ -54,11 +56,9 @@ def compute_reward(
|
|
| 54 |
|
| 55 |
r_validity = _VALIDITY_VALID if compliance.valid else _VALIDITY_INVALID
|
| 56 |
r_penalty = (
|
| 57 |
-
_PENALTY_INVALID * len(compliance.violations)
|
| 58 |
-
if not compliance.valid
|
| 59 |
-
else 0.0
|
| 60 |
)
|
| 61 |
-
r_ordering =
|
| 62 |
r_info_gain = _info_gain_reward(action, result)
|
| 63 |
r_efficiency = _efficiency_reward(latent)
|
| 64 |
r_novelty = _novelty_reward(action, latent)
|
|
@@ -81,18 +81,11 @@ def compute_reward(
|
|
| 81 |
# Component helpers
|
| 82 |
# ---------------------------------------------------------------------------
|
| 83 |
|
| 84 |
-
def _ordering_reward(action: TrialAction, latent: TrialLatentState) -> float:
|
| 85 |
-
"""Bonus for actions that match the expected clinical workflow phase."""
|
| 86 |
-
from server.rules.fda_rules import TRANSITION_TABLE
|
| 87 |
-
permitted = TRANSITION_TABLE.get(latent.episode_phase, set())
|
| 88 |
-
if action.action_type in permitted:
|
| 89 |
-
return _ORDERING_BONUS
|
| 90 |
-
return 0.0
|
| 91 |
-
|
| 92 |
|
| 93 |
def _info_gain_reward(action: TrialAction, result: TrialResult) -> float:
|
| 94 |
"""Reward for information-gathering actions that produce useful results."""
|
| 95 |
from models import ActionType
|
|
|
|
| 96 |
info_actions = {
|
| 97 |
ActionType.ESTIMATE_EFFECT_SIZE,
|
| 98 |
ActionType.OBSERVE_SAFETY_SIGNAL,
|
|
@@ -110,9 +103,7 @@ def _efficiency_reward(latent: TrialLatentState) -> float:
|
|
| 110 |
initial_budget = 1_000_000.0
|
| 111 |
if initial_budget <= 0:
|
| 112 |
return 0.0
|
| 113 |
-
budget_fraction = min(
|
| 114 |
-
max(latent.budget_remaining / initial_budget, 0.0), 1.0
|
| 115 |
-
)
|
| 116 |
return _EFFICIENCY_SCALE * budget_fraction
|
| 117 |
|
| 118 |
|
|
@@ -123,9 +114,7 @@ def _novelty_reward(action: TrialAction, latent: TrialLatentState) -> float:
|
|
| 123 |
return 0.0
|
| 124 |
|
| 125 |
|
| 126 |
-
def _terminal_success_reward(
|
| 127 |
-
latent: TrialLatentState, result: TrialResult
|
| 128 |
-
) -> float:
|
| 129 |
"""Positive reward when the episode ends with a successful trial (req 6.4)."""
|
| 130 |
if latent.trial_complete and result.success and result.failure_reason is None:
|
| 131 |
return _TERMINAL_SUCCESS
|
|
@@ -150,6 +139,6 @@ def _terminal_calibration_reward(
|
|
| 150 |
centre_error = abs(ci_centre - true_effect)
|
| 151 |
calibration_score = max(0.0, 1.0 - centre_error)
|
| 152 |
width_penalty = min(ci_width, 1.0)
|
| 153 |
-
calibration_score *=
|
| 154 |
|
| 155 |
return _TERMINAL_CALIBRATION * calibration_score
|
|
|
|
| 18 |
TrialLatentState,
|
| 19 |
TrialResult,
|
| 20 |
)
|
| 21 |
+
from server.phase_detector import compute_phase_ordering_reward
|
| 22 |
from server.rules.fda_rules import check_fda_compliance
|
| 23 |
|
| 24 |
# Reward magnitude constants
|
|
|
|
| 30 |
_INFO_GAIN_BASE = 0.5
|
| 31 |
_EFFICIENCY_SCALE = 2.0
|
| 32 |
_NOVELTY_BASE = 0.2
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def compute_reward(
|
| 36 |
action: TrialAction,
|
| 37 |
latent: TrialLatentState,
|
| 38 |
result: TrialResult,
|
| 39 |
+
phase_history: list[str] | None = None,
|
| 40 |
) -> RewardBreakdown:
|
| 41 |
"""Compute all eight reward components for a single step.
|
| 42 |
|
|
|
|
| 47 |
action: The agent's action.
|
| 48 |
latent: Hidden ground-truth + episode tracking state.
|
| 49 |
result: The simulated trial result.
|
| 50 |
+
phase_history: List of phase names from previous steps (for r_ordering).
|
| 51 |
|
| 52 |
Returns:
|
| 53 |
A RewardBreakdown with all eight keys populated.
|
|
|
|
| 56 |
|
| 57 |
r_validity = _VALIDITY_VALID if compliance.valid else _VALIDITY_INVALID
|
| 58 |
r_penalty = (
|
| 59 |
+
_PENALTY_INVALID * len(compliance.violations) if not compliance.valid else 0.0
|
|
|
|
|
|
|
| 60 |
)
|
| 61 |
+
r_ordering = compute_phase_ordering_reward(action, phase_history or [])
|
| 62 |
r_info_gain = _info_gain_reward(action, result)
|
| 63 |
r_efficiency = _efficiency_reward(latent)
|
| 64 |
r_novelty = _novelty_reward(action, latent)
|
|
|
|
| 81 |
# Component helpers
|
| 82 |
# ---------------------------------------------------------------------------
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
def _info_gain_reward(action: TrialAction, result: TrialResult) -> float:
|
| 86 |
"""Reward for information-gathering actions that produce useful results."""
|
| 87 |
from models import ActionType
|
| 88 |
+
|
| 89 |
info_actions = {
|
| 90 |
ActionType.ESTIMATE_EFFECT_SIZE,
|
| 91 |
ActionType.OBSERVE_SAFETY_SIGNAL,
|
|
|
|
| 103 |
initial_budget = 1_000_000.0
|
| 104 |
if initial_budget <= 0:
|
| 105 |
return 0.0
|
| 106 |
+
budget_fraction = min(max(latent.budget_remaining / initial_budget, 0.0), 1.0)
|
|
|
|
|
|
|
| 107 |
return _EFFICIENCY_SCALE * budget_fraction
|
| 108 |
|
| 109 |
|
|
|
|
| 114 |
return 0.0
|
| 115 |
|
| 116 |
|
| 117 |
+
def _terminal_success_reward(latent: TrialLatentState, result: TrialResult) -> float:
|
|
|
|
|
|
|
| 118 |
"""Positive reward when the episode ends with a successful trial (req 6.4)."""
|
| 119 |
if latent.trial_complete and result.success and result.failure_reason is None:
|
| 120 |
return _TERMINAL_SUCCESS
|
|
|
|
| 139 |
centre_error = abs(ci_centre - true_effect)
|
| 140 |
calibration_score = max(0.0, 1.0 - centre_error)
|
| 141 |
width_penalty = min(ci_width, 1.0)
|
| 142 |
+
calibration_score *= 1.0 - width_penalty * 0.5
|
| 143 |
|
| 144 |
return _TERMINAL_CALIBRATION * calibration_score
|
server/reward/shaping.py
CHANGED
|
@@ -32,9 +32,7 @@ def _budget_efficiency(
|
|
| 32 |
return min(max(latent.budget_remaining / initial_budget, 0.0), 1.0)
|
| 33 |
|
| 34 |
|
| 35 |
-
def potential(
|
| 36 |
-
latent: TrialLatentState, initial_budget: float = 1_000_000.0
|
| 37 |
-
) -> float:
|
| 38 |
"""φ(s) = milestone_completion × budget_efficiency."""
|
| 39 |
return _milestone_completion(latent) * _budget_efficiency(latent, initial_budget)
|
| 40 |
|
|
|
|
| 32 |
return min(max(latent.budget_remaining / initial_budget, 0.0), 1.0)
|
| 33 |
|
| 34 |
|
| 35 |
+
def potential(latent: TrialLatentState, initial_budget: float = 1_000_000.0) -> float:
|
|
|
|
|
|
|
| 36 |
"""φ(s) = milestone_completion × budget_efficiency."""
|
| 37 |
return _milestone_completion(latent) * _budget_efficiency(latent, initial_budget)
|
| 38 |
|
server/rules/prerequisite_rules.py
CHANGED
|
@@ -20,9 +20,7 @@ _HISTORY_PREREQUISITES: dict[ActionType, list[ActionType]] = {
|
|
| 20 |
}
|
| 21 |
|
| 22 |
|
| 23 |
-
def check_prerequisites(
|
| 24 |
-
action: TrialAction, latent: TrialLatentState
|
| 25 |
-
) -> list[str]:
|
| 26 |
"""Return a list of prerequisite violation strings for *action* given *latent*.
|
| 27 |
|
| 28 |
Returns an empty list when all prerequisites are satisfied.
|
|
|
|
| 20 |
}
|
| 21 |
|
| 22 |
|
| 23 |
+
def check_prerequisites(action: TrialAction, latent: TrialLatentState) -> list[str]:
|
|
|
|
|
|
|
| 24 |
"""Return a list of prerequisite violation strings for *action* given *latent*.
|
| 25 |
|
| 26 |
Returns an empty list when all prerequisites are satisfied.
|
server/simulator/__init__.py
CHANGED
|
@@ -2,5 +2,5 @@
|
|
| 2 |
simulator — Trial outcome simulation and power calculation.
|
| 3 |
|
| 4 |
Provides simulate_trial, calculate_power (with episode-scoped cache),
|
| 5 |
-
compute_reward, and the seeded hidden-state generator.
|
| 6 |
"""
|
|
|
|
| 2 |
simulator — Trial outcome simulation and power calculation.
|
| 3 |
|
| 4 |
Provides simulate_trial, calculate_power (with episode-scoped cache),
|
| 5 |
+
compute_reward, TransitionEngine, and the seeded hidden-state generator.
|
| 6 |
"""
|
server/simulator/output_generator.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OutputGenerator — produces a noisy TrialObservation from a TrialLatentState.
|
| 3 |
+
|
| 4 |
+
Follows the Bio Experiment pattern: TransitionEngine updates hidden state,
|
| 5 |
+
OutputGenerator produces noisy observations from it. Agent never sees clean
|
| 6 |
+
hidden values.
|
| 7 |
+
|
| 8 |
+
Key responsibilities:
|
| 9 |
+
- Inject measurement noise and site variability via NoiseModel's seeded RNG
|
| 10 |
+
- Populate phase_data with noisy (not raw) experimental results
|
| 11 |
+
- Populate resource_status from latent state resource fields
|
| 12 |
+
- Populate available_actions based on current milestone flags and phase
|
| 13 |
+
- Never expose true_effect_size, true_side_effect_rate, or other hidden values
|
| 14 |
+
directly — always add noise before returning to the agent
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from models import ActionType, TrialLatentState, TrialObservation, TrialState
|
| 22 |
+
from server.noise_model import NoiseModel
|
| 23 |
+
from server.rules.fda_rules import TRANSITION_TABLE
|
| 24 |
+
from server.rules.prerequisite_rules import _HISTORY_PREREQUISITES
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class OutputGenerator:
    """Turn a hidden TrialLatentState into the noisy observation shown to the agent.

    Experimental quantities (effect size, adverse-event rate, placebo response,
    dose response, dropout) are perturbed with draws from the noise model's RNG
    before being placed in the observation. Milestone flags and resource
    counters are copied through as-is.

    Args:
        noise_model: NoiseModel whose ``rng`` attribute supplies the noise draws.
    """

    def __init__(self, noise_model: NoiseModel) -> None:
        self._noise_model = noise_model

    def generate(
        self,
        latent: TrialLatentState,
        trial_state: TrialState,
        *,
        steps_taken: int,
        max_steps: int,
        rule_violations: list[str],
        done: bool,
        reward: float,
        scenario_description: str,
        hint: str = "",
    ) -> TrialObservation:
        """Assemble the TrialObservation for the current step.

        Args:
            latent: Hidden state to observe.
            trial_state: Episode metadata; accepted for interface
                compatibility and not read by this method.
            steps_taken: Number of steps taken so far in the episode.
            max_steps: Maximum steps allowed in the episode.
            rule_violations: Rule violation strings to report for this step.
            done: Whether the episode is finished.
            reward: Reward signal for this step.
            scenario_description: Human-readable scenario description.
            hint: Optional hint text, passed through unchanged.

        Returns:
            A TrialObservation whose phase_data, resource_status and
            available_actions are derived from *latent*; the remaining
            fields are the keyword arguments passed straight through.
        """
        # Keyword arguments are evaluated left to right, so the helpers run in
        # the order phase data -> resources -> available actions. Only the
        # phase-data helper consumes RNG draws.
        return TrialObservation(
            scenario_description=scenario_description,
            phase_data=self._build_phase_data(latent, self._noise_model.rng),
            resource_status=self._build_resource_status(latent),
            rule_violations=rule_violations,
            available_actions=self._build_available_actions(latent),
            steps_taken=steps_taken,
            max_steps=max_steps,
            hint=hint,
            done=done,
            reward=reward,
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _build_phase_data(
        self,
        latent: TrialLatentState,
        rng: "np.random.Generator",
    ) -> dict:
        """Build the phase_data dict, adding noise to every hidden quantity.

        ``latent.measurement_noise`` scales the noise on effect size, placebo
        response, dose response and (halved) dropout; ``latent.site_variability``
        scales the noise on the adverse-event rate. Both are floored at 1e-6.
        The sequence of ``rng.normal`` calls is fixed (effect, CI half-width,
        adverse events, placebo, each dose, dropout) so that a seeded RNG
        yields reproducible observations.
        """
        sigma_measure = max(latent.measurement_noise, 1e-6)
        sigma_site = max(latent.site_variability, 1e-6)
        history = latent.action_history

        def _as_rate(value) -> float:
            # Clamp a noisy estimate into [0, 1] and round for presentation.
            return round(float(np.clip(float(value), 0.0, 1.0)), 4)

        data: dict = {
            "current_phase": latent.episode_phase,
            "patients_enrolled": latent.patients_enrolled,
        }
        # Milestones are plain observable flags, copied without noise.
        for flag in (
            "phase_i_complete",
            "mtd_identified",
            "effect_estimated",
            "protocol_submitted",
            "interim_complete",
            "trial_complete",
        ):
            data[flag] = getattr(latent, flag)

        # Effect size and its confidence interval appear once the
        # effect_estimated milestone is set.
        if latent.effect_estimated:
            effect = float(latent.true_effect_size + rng.normal(0.0, sigma_measure))
            # Half-width is itself noisy, centred on twice the noise level and
            # never allowed below 0.01.
            half_width = max(
                float(rng.normal(sigma_measure * 2, sigma_measure * 0.5)), 0.01
            )
            data["observed_effect_size"] = round(effect, 4)
            data["effect_size_ci"] = (
                round(effect - half_width, 4),
                round(effect + half_width, 4),
            )

        # Adverse-event rate appears after Phase I is complete or after a
        # safety-signal observation has been recorded in the history.
        safety_seen = (
            latent.phase_i_complete
            or ActionType.OBSERVE_SAFETY_SIGNAL.value in history
        )
        if safety_seen:
            data["observed_adverse_event_rate"] = _as_rate(
                latent.true_side_effect_rate + rng.normal(0.0, sigma_site)
            )

        # Placebo response appears after the interim or primary analysis.
        if latent.interim_complete or latent.trial_complete:
            data["observed_placebo_response"] = _as_rate(
                latent.placebo_response_rate + rng.normal(0.0, sigma_measure)
            )

        # Dose-response curve appears after Phase I, one draw per dose level
        # in the mapping's iteration order.
        if latent.phase_i_complete and latent.true_dose_response:
            data["observed_dose_response"] = {
                str(dose): _as_rate(response + rng.normal(0.0, sigma_measure))
                for dose, response in latent.true_dose_response.items()
            }

        # Dropout estimate appears once at least one patient is enrolled.
        if latent.patients_enrolled > 0:
            data["observed_dropout_rate"] = _as_rate(
                latent.dropout_rate + rng.normal(0.0, sigma_measure * 0.5)
            )

        # After biomarker stratification the responder population label is
        # passed through without noise.
        if ActionType.ADD_BIOMARKER_STRATIFICATION.value in history:
            data["responder_population_hint"] = latent.true_responder_population

        return data

    def _build_resource_status(self, latent: TrialLatentState) -> dict:
        """Copy the budget, time and enrollment counters out of *latent*."""
        resource_fields = ("budget_remaining", "time_remaining_days", "patients_enrolled")
        return {field: getattr(latent, field) for field in resource_fields}

    def _build_available_actions(self, latent: TrialLatentState) -> list[str]:
        """List the action strings the agent may take in the current phase.

        Starts from the actions TRANSITION_TABLE permits for
        ``latent.episode_phase`` (none if the phase is unknown), drops those
        whose prerequisites are unmet, and returns the rest sorted by value.
        """
        permitted: set[ActionType] = TRANSITION_TABLE.get(latent.episode_phase, set())
        ordered = sorted(permitted, key=lambda candidate: candidate.value)
        return [
            candidate.value
            for candidate in ordered
            if self._prerequisites_met(candidate, latent)
        ]

    def _prerequisites_met(
        self, action_type: ActionType, latent: TrialLatentState
    ) -> bool:
        """Return True when *action_type* has no unmet prerequisite in *latent*."""
        history = latent.action_history

        # History-based prerequisites: every required action must already
        # appear in the action history.
        required = _HISTORY_PREREQUISITES.get(action_type, [])
        if any(dep.value not in history for dep in required):
            return False

        # State-flag prerequisites (the original notes these mirror
        # prerequisite_rules.py).
        if (
            action_type == ActionType.REQUEST_PROTOCOL_AMENDMENT
            and not latent.protocol_submitted
        ):
            return False

        # NOTE(review): FDA submission is gated on protocol_submitted being
        # already set — confirm which component sets that flag beforehand.
        if action_type == ActionType.SUBMIT_TO_FDA_REVIEW and not (
            latent.protocol_submitted and latent.phase_i_complete
        ):
            return False

        if (
            action_type == ActionType.RUN_PRIMARY_ANALYSIS
            and not latent.interim_complete
        ):
            return False

        if (
            action_type == ActionType.RUN_INTERIM_ANALYSIS
            and latent.patients_enrolled <= 0
        ):
            return False

        if (
            action_type == ActionType.MODIFY_SAMPLE_SIZE
            and ActionType.SET_SAMPLE_SIZE.value not in history
        ):
            return False

        if (
            action_type == ActionType.SYNTHESIZE_CONCLUSION
            and not latent.trial_complete
        ):
            return False

        return True
|
server/simulator/transition_engine.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TransitionEngine — applies each action to a copy of TrialLatentState and returns the updated copy.
|
| 3 |
+
|
| 4 |
+
Follows the Bio Experiment pattern: TransitionEngine updates hidden state,
|
| 5 |
+
OutputGenerator produces noisy observations from it. Agent never sees clean
|
| 6 |
+
hidden values.
|
| 7 |
+
|
| 8 |
+
Key responsibilities:
|
| 9 |
+
- Enroll patients (ENROLL_PATIENTS)
|
| 10 |
+
- Spend budget and advance time
|
| 11 |
+
- Record adverse events
|
| 12 |
+
- Set milestone flags (phase_i_complete, mtd_identified, effect_estimated,
|
| 13 |
+
protocol_submitted, interim_complete, trial_complete)
|
| 14 |
+
- Degrade data quality on soft violations
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import random
|
| 20 |
+
|
| 21 |
+
from models import ActionType, TrialAction, TrialLatentState
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TransitionEngine:
    """Derive the next TrialLatentState from the current one and an agent action.

    The input state is never modified: every call works on a deep copy.
    Randomness comes from a ``random.Random`` seeded with ``latent.seed`` XOR
    the step index, so the same seed and action sequence produce the same
    transitions (reproducibility requirement 9.2).
    """

    # Fixed budget cost charged per action type.
    _ACTION_COSTS: dict[ActionType, float] = {
        ActionType.SET_PRIMARY_ENDPOINT: 5_000.0,
        ActionType.SET_SAMPLE_SIZE: 2_000.0,
        ActionType.SET_INCLUSION_CRITERIA: 3_000.0,
        ActionType.SET_EXCLUSION_CRITERIA: 3_000.0,
        ActionType.SET_DOSING_SCHEDULE: 10_000.0,
        ActionType.SET_CONTROL_ARM: 5_000.0,
        ActionType.SET_RANDOMIZATION_RATIO: 2_000.0,
        ActionType.SET_BLINDING: 4_000.0,
        ActionType.RUN_DOSE_ESCALATION: 50_000.0,
        ActionType.OBSERVE_SAFETY_SIGNAL: 15_000.0,
        ActionType.ESTIMATE_EFFECT_SIZE: 20_000.0,
        ActionType.RUN_INTERIM_ANALYSIS: 30_000.0,
        ActionType.MODIFY_SAMPLE_SIZE: 5_000.0,
        ActionType.ADD_BIOMARKER_STRATIFICATION: 25_000.0,
        ActionType.SUBMIT_TO_FDA_REVIEW: 100_000.0,
        ActionType.REQUEST_PROTOCOL_AMENDMENT: 15_000.0,
        ActionType.RUN_PRIMARY_ANALYSIS: 50_000.0,
        ActionType.SYNTHESIZE_CONCLUSION: 10_000.0,
        ActionType.ENROLL_PATIENTS: 0.0,  # cost computed per patient
    }

    # Fixed calendar time (in days) consumed per action type.
    _ACTION_TIME_DAYS: dict[ActionType, int] = {
        ActionType.SET_PRIMARY_ENDPOINT: 7,
        ActionType.SET_SAMPLE_SIZE: 3,
        ActionType.SET_INCLUSION_CRITERIA: 5,
        ActionType.SET_EXCLUSION_CRITERIA: 5,
        ActionType.SET_DOSING_SCHEDULE: 14,
        ActionType.SET_CONTROL_ARM: 7,
        ActionType.SET_RANDOMIZATION_RATIO: 3,
        ActionType.SET_BLINDING: 5,
        ActionType.RUN_DOSE_ESCALATION: 90,
        ActionType.OBSERVE_SAFETY_SIGNAL: 30,
        ActionType.ESTIMATE_EFFECT_SIZE: 45,
        ActionType.RUN_INTERIM_ANALYSIS: 60,
        ActionType.MODIFY_SAMPLE_SIZE: 7,
        ActionType.ADD_BIOMARKER_STRATIFICATION: 30,
        ActionType.SUBMIT_TO_FDA_REVIEW: 180,
        ActionType.REQUEST_PROTOCOL_AMENDMENT: 30,
        ActionType.RUN_PRIMARY_ANALYSIS: 90,
        ActionType.SYNTHESIZE_CONCLUSION: 14,
        ActionType.ENROLL_PATIENTS: 0,  # time computed per patient
    }

    # Flat per-patient enrollment cost and duration used by ENROLL_PATIENTS.
    _COST_PER_PATIENT: float = 10_000.0
    _DAYS_PER_PATIENT: float = 2.0

    def __init__(self) -> None:
        """Initialize the TransitionEngine (it holds no per-instance state)."""

    def apply_transition(
        self, latent: TrialLatentState, action: TrialAction
    ) -> TrialLatentState:
        """Apply *action* to *latent* and return the updated state.

        Does NOT mutate the input latent state — all changes are made on a
        deep copy, which is returned.

        Args:
            latent: Current hidden state.
            action: Agent action to apply. For ENROLL_PATIENTS the
                ``n_patients`` parameter is coerced to a non-negative int;
                missing, ``None`` or negative values enroll nobody.

        Returns:
            A new TrialLatentState with the action appended to the history,
            budget/time charged, milestone flags set, and any soft-violation
            or adverse-event degradation applied.

        Raises:
            ValueError: If ``n_patients`` is supplied but is not numeric.
        """
        # Work on a private copy so the caller's state stays untouched.
        updated = latent.model_copy(deep=True)

        # Update action history
        updated.action_history.append(action.action_type.value)

        # Step-specific RNG: seed XOR the 1-based step index keeps each step's
        # draw independent yet reproducible.
        step_index = len(updated.action_history)
        rng = random.Random(latent.seed ^ step_index)

        # --- Budget and time consumption ---
        base_cost = self._ACTION_COSTS.get(action.action_type, 0.0)
        base_time = self._ACTION_TIME_DAYS.get(action.action_type, 0)

        if action.action_type == ActionType.ENROLL_PATIENTS:
            # Coerce to a non-negative whole number. Without this a negative
            # request would refund budget and time and shrink the enrolled
            # count, and a fractional one would turn the count into a float.
            requested = action.parameters.get("n_patients", 0)
            n_patients = max(int(requested or 0), 0)
            base_cost = n_patients * self._COST_PER_PATIENT
            base_time = int(n_patients * self._DAYS_PER_PATIENT)
            updated.patients_enrolled += n_patients

        updated.budget_remaining -= base_cost
        updated.time_remaining_days -= base_time

        # --- Milestone flag updates ---
        if action.action_type == ActionType.RUN_DOSE_ESCALATION:
            updated.phase_i_complete = True
            updated.mtd_identified = True

        if action.action_type == ActionType.ESTIMATE_EFFECT_SIZE:
            updated.effect_estimated = True

        if action.action_type == ActionType.SUBMIT_TO_FDA_REVIEW:
            updated.protocol_submitted = True

        if action.action_type == ActionType.RUN_INTERIM_ANALYSIS:
            updated.interim_complete = True

        if action.action_type == ActionType.RUN_PRIMARY_ANALYSIS:
            updated.trial_complete = True

        # --- Soft violation: degrade data quality ---
        # Low-confidence actions (< 0.5) inflate measurement noise by up to
        # 1.5x, capped at 0.5.
        if action.confidence < 0.5:
            degradation_factor = 1.0 + (0.5 - action.confidence)
            updated.measurement_noise = min(
                updated.measurement_noise * degradation_factor, 0.5
            )

        # Overspent budget raises site variability by 20%, capped at 0.5.
        if updated.budget_remaining < 0:
            updated.site_variability = min(updated.site_variability * 1.2, 0.5)

        # Overrun schedule raises the dropout rate by 15%, capped at 0.8.
        if updated.time_remaining_days < 0:
            updated.dropout_rate = min(updated.dropout_rate * 1.15, 0.8)

        # --- Adverse event recording (stochastic) ---
        # On patient-facing actions, a single draw against
        # true_side_effect_rate decides whether one adverse event is recorded.
        if action.action_type in {
            ActionType.ENROLL_PATIENTS,
            ActionType.OBSERVE_SAFETY_SIGNAL,
            ActionType.RUN_DOSE_ESCALATION,
        }:
            if rng.random() < updated.true_side_effect_rate:
                updated.adverse_events += 1
                # Each adverse event nudges site variability up, capped at 0.5.
                updated.site_variability = min(updated.site_variability + 0.02, 0.5)

        return updated
|
server/simulator/trial_simulator.py
CHANGED
|
@@ -84,6 +84,7 @@ def simulate_trial(
|
|
| 84 |
se = 1.0 / math.sqrt(n_per_arm)
|
| 85 |
z_stat = observed_effect / se if se > 0 else 0.0
|
| 86 |
from scipy.stats import norm
|
|
|
|
| 87 |
p_value = float(2.0 * norm.sf(abs(z_stat)))
|
| 88 |
else:
|
| 89 |
p_value = 1.0
|
|
@@ -93,6 +94,7 @@ def simulate_trial(
|
|
| 93 |
|
| 94 |
if n_per_arm > 0:
|
| 95 |
from scipy.stats import norm as _norm
|
|
|
|
| 96 |
z_95 = _norm.ppf(0.975)
|
| 97 |
se = 1.0 / math.sqrt(n_per_arm)
|
| 98 |
ci_low = observed_effect - z_95 * se
|
|
@@ -104,9 +106,7 @@ def simulate_trial(
|
|
| 104 |
0.0,
|
| 105 |
latent.site_variability if latent.site_variability > 0 else 0.01,
|
| 106 |
)
|
| 107 |
-
adverse_event_rate = min(
|
| 108 |
-
max(latent.true_side_effect_rate + ae_noise, 0.0), 1.0
|
| 109 |
-
)
|
| 110 |
|
| 111 |
return TrialResult(
|
| 112 |
p_value=p_value,
|
|
|
|
| 84 |
se = 1.0 / math.sqrt(n_per_arm)
|
| 85 |
z_stat = observed_effect / se if se > 0 else 0.0
|
| 86 |
from scipy.stats import norm
|
| 87 |
+
|
| 88 |
p_value = float(2.0 * norm.sf(abs(z_stat)))
|
| 89 |
else:
|
| 90 |
p_value = 1.0
|
|
|
|
| 94 |
|
| 95 |
if n_per_arm > 0:
|
| 96 |
from scipy.stats import norm as _norm
|
| 97 |
+
|
| 98 |
z_95 = _norm.ppf(0.975)
|
| 99 |
se = 1.0 / math.sqrt(n_per_arm)
|
| 100 |
ci_low = observed_effect - z_95 * se
|
|
|
|
| 106 |
0.0,
|
| 107 |
latent.site_variability if latent.site_variability > 0 else 0.01,
|
| 108 |
)
|
| 109 |
+
adverse_event_rate = min(max(latent.true_side_effect_rate + ae_noise, 0.0), 1.0)
|
|
|
|
|
|
|
| 110 |
|
| 111 |
return TrialResult(
|
| 112 |
p_value=p_value,
|
tests/test_curriculum_controller.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for server/curriculum/controller.py
|
| 3 |
+
|
| 4 |
+
Verifies:
|
| 5 |
+
- advance_curriculum mastery logic (70% → graduate, 90% → fast-track)
|
| 6 |
+
- select_scenario tier mapping
|
| 7 |
+
- Edge cases (empty history, max tier, clamping)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from server.curriculum.controller import (
|
| 13 |
+
MAX_TIER,
|
| 14 |
+
EpisodeMetrics,
|
| 15 |
+
advance_curriculum,
|
| 16 |
+
select_scenario,
|
| 17 |
+
)
|
| 18 |
+
from server.curriculum.scenarios import (
|
| 19 |
+
AUTOIMMUNE_BIOLOGIC,
|
| 20 |
+
CNS_DEPRESSION,
|
| 21 |
+
RARE_DISEASE_ORPHAN,
|
| 22 |
+
SOLID_TUMOR_CHEMO,
|
| 23 |
+
WARMUP,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# ── advance_curriculum tests ──────────────────────────────────────────────────
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_advance_curriculum_empty_history():
    """With no recorded episodes the tier is returned unchanged."""
    no_history = EpisodeMetrics(success=True, episode_history=[])
    for tier in (0, 2):
        assert advance_curriculum(tier, no_history) == tier
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_advance_curriculum_no_mastery():
    """A success rate under 70% keeps the agent on its current tier."""
    # Alternating wins and losses: 5/10 = 50%, below the graduation threshold.
    alternating = [True, False] * 5
    stalled = EpisodeMetrics(success=False, episode_history=alternating)
    assert advance_curriculum(1, stalled) == 1
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_advance_curriculum_normal_graduation():
    """A rolling success rate of 70% promotes the agent by exactly one tier."""
    # 7 successes followed by 3 failures = 70%.
    seven_of_ten = [True] * 7 + [False] * 3
    metrics = EpisodeMetrics(success=False, episode_history=seven_of_ten)
    for start_tier, expected_tier in ((0, 1), (2, 3)):
        assert advance_curriculum(start_tier, metrics) == expected_tier
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_advance_curriculum_fast_track():
    """A 90% success rate over at least 3 episodes advances by two tiers."""
    # 9 successes followed by 1 failure = 90%.
    nine_of_ten = [True] * 9 + [False]
    metrics = EpisodeMetrics(success=False, episode_history=nine_of_ten)
    # (start tier, expected tier): each case skips the tier in between.
    for start_tier, expected_tier in ((0, 2), (1, 3)):
        assert advance_curriculum(start_tier, metrics) == expected_tier
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def test_advance_curriculum_fast_track_requires_min_episodes():
    """Fast-tracking needs a history of at least 3 episodes."""
    # Two perfect episodes: too short for fast-track, but 100% >= 70%
    # still earns a normal one-tier graduation.
    two_wins = EpisodeMetrics(success=True, episode_history=[True] * 2)
    assert advance_curriculum(0, two_wins) == 1

    # Three perfect episodes: the minimum length is met, so the tier is skipped.
    three_wins = EpisodeMetrics(success=True, episode_history=[True] * 3)
    assert advance_curriculum(0, three_wins) == 2
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def test_advance_curriculum_max_tier_clamp():
    """The returned tier never exceeds MAX_TIER, even with a perfect record."""
    perfect = EpisodeMetrics(success=True, episode_history=[True for _ in range(10)])
    # Already at the top: nothing to advance to.
    assert advance_curriculum(MAX_TIER, perfect) == MAX_TIER
    # One below the top: a two-tier fast-track is clamped to MAX_TIER.
    assert advance_curriculum(MAX_TIER - 1, perfect) == MAX_TIER
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_advance_curriculum_rolling_window():
    """Only the most recent 10 episodes count for the rolling rate."""
    # 20 episodes in total: ten old failures, then nine successes and a failure.
    # The last ten alone are 9/10 = 90%, which should trigger a fast-track.
    old_failures = [False] * 10
    recent_run = [True] * 9 + [False]
    metrics = EpisodeMetrics(success=False, episode_history=old_failures + recent_run)
    assert advance_curriculum(0, metrics) == 2
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_advance_curriculum_exactly_70_percent():
    """The 70% graduation threshold is inclusive."""
    boundary_history = [True for _ in range(7)] + [False for _ in range(3)]
    metrics = EpisodeMetrics(success=False, episode_history=boundary_history)
    next_tier = advance_curriculum(1, metrics)
    assert next_tier == 2
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def test_advance_curriculum_exactly_90_percent():
    """The 90% fast-track threshold is inclusive once 3+ episodes exist."""
    boundary_history = [True for _ in range(9)] + [False]
    metrics = EpisodeMetrics(success=False, episode_history=boundary_history)
    next_tier = advance_curriculum(0, metrics)
    assert next_tier == 2
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ── select_scenario tests ─────────────────────────────────────────────────────
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_select_scenario_tier_mapping():
    """Tiers 0 through 4 each resolve to their dedicated ScenarioConfig."""
    rng = np.random.default_rng(42)
    # The list index is the tier number; the same rng is reused for every call.
    scenarios_by_tier = [
        WARMUP,
        SOLID_TUMOR_CHEMO,
        AUTOIMMUNE_BIOLOGIC,
        CNS_DEPRESSION,
        RARE_DISEASE_ORPHAN,
    ]
    for tier, expected in enumerate(scenarios_by_tier):
        assert select_scenario(tier, rng) == expected
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def test_select_scenario_clamping():
    """Tiers outside the valid range resolve to the nearest valid tier."""
    rng = np.random.default_rng(42)
    cases = (
        # Below the lowest tier → the tier-0 scenario.
        (-1, WARMUP),
        (-100, WARMUP),
        # Above the highest tier → the tier-4 scenario.
        (5, RARE_DISEASE_ORPHAN),
        (100, RARE_DISEASE_ORPHAN),
    )
    for requested_tier, expected in cases:
        assert select_scenario(requested_tier, rng) == expected
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_select_scenario_deterministic():
    """Two generators seeded identically yield the same scenario for a tier."""
    first_rng = np.random.default_rng(42)
    second_rng = np.random.default_rng(42)
    first_pick = select_scenario(2, first_rng)
    second_pick = select_scenario(2, second_rng)
    assert first_pick == second_pick
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ── Integration test: full curriculum progression ─────────────────────────────
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def test_full_curriculum_progression():
    """Walk the curriculum from tier 0 up to tier 4 and confirm it stops there."""
    tier = 0

    # Tier 0 → 1: 70% success earns a normal graduation.
    history: list[bool] = [True] * 7 + [False] * 3
    tier = advance_curriculum(
        tier, EpisodeMetrics(success=False, episode_history=history)
    )
    assert tier == 1

    # Tier 1 → 3: 90% success earns a fast-track that skips tier 2.
    history = [True] * 9 + [False]
    tier = advance_curriculum(
        tier, EpisodeMetrics(success=False, episode_history=history)
    )
    assert tier == 3

    # Tier 3 → 4: normal graduation again.
    history = [True] * 7 + [False] * 3
    tier = advance_curriculum(
        tier, EpisodeMetrics(success=False, episode_history=history)
    )
    assert tier == 4

    # Tier 4 → 4: already at the maximum, so a perfect record changes nothing.
    history = [True] * 10
    tier = advance_curriculum(
        tier, EpisodeMetrics(success=True, episode_history=history)
    )
    assert tier == 4
|
tests/test_episode_logger_wiring.py
CHANGED
|
@@ -38,9 +38,7 @@ class TestLoggerCreatedOnReset:
|
|
| 38 |
def test_logger_exists_after_reset(self, manager: EpisodeManager) -> None:
|
| 39 |
assert manager._logger is not None
|
| 40 |
|
| 41 |
-
def test_logger_replaced_on_second_reset(
|
| 42 |
-
self, manager: EpisodeManager
|
| 43 |
-
) -> None:
|
| 44 |
first_id = manager._logger.episode_id
|
| 45 |
manager.reset()
|
| 46 |
second_id = manager._logger.episode_id
|
|
@@ -53,9 +51,7 @@ class TestLoggerCreatedOnReset:
|
|
| 53 |
class TestLogStepCalledOnStep:
|
| 54 |
"""Requirement 7.1: log_step() is called for every step."""
|
| 55 |
|
| 56 |
-
def test_log_step_called_for_invalid_action(
|
| 57 |
-
self, manager: EpisodeManager
|
| 58 |
-
) -> None:
|
| 59 |
mock_logger = MagicMock()
|
| 60 |
manager._logger = mock_logger
|
| 61 |
|
|
|
|
| 38 |
def test_logger_exists_after_reset(self, manager: EpisodeManager) -> None:
|
| 39 |
assert manager._logger is not None
|
| 40 |
|
| 41 |
+
def test_logger_replaced_on_second_reset(self, manager: EpisodeManager) -> None:
|
|
|
|
|
|
|
| 42 |
first_id = manager._logger.episode_id
|
| 43 |
manager.reset()
|
| 44 |
second_id = manager._logger.episode_id
|
|
|
|
| 51 |
class TestLogStepCalledOnStep:
|
| 52 |
"""Requirement 7.1: log_step() is called for every step."""
|
| 53 |
|
| 54 |
+
def test_log_step_called_for_invalid_action(self, manager: EpisodeManager) -> None:
|
|
|
|
|
|
|
| 55 |
mock_logger = MagicMock()
|
| 56 |
manager._logger = mock_logger
|
| 57 |
|
tests/test_episode_manager_compliance.py
CHANGED
|
@@ -33,17 +33,13 @@ def manager() -> EpisodeManager:
|
|
| 33 |
class TestInvalidActionReturnsNegativeRValidity:
|
| 34 |
"""Requirement 10.1: invalid actions → negative r_validity, latent unchanged."""
|
| 35 |
|
| 36 |
-
def test_invalid_action_r_validity_negative(
|
| 37 |
-
self, manager: EpisodeManager
|
| 38 |
-
) -> None:
|
| 39 |
# SUBMIT_TO_FDA_REVIEW not permitted in literature_review phase
|
| 40 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 41 |
_, reward, _, _ = manager.step(action)
|
| 42 |
assert reward.r_validity < 0, "r_validity must be negative for invalid action"
|
| 43 |
|
| 44 |
-
def test_invalid_action_state_unchanged(
|
| 45 |
-
self, manager: EpisodeManager
|
| 46 |
-
) -> None:
|
| 47 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 48 |
history_before = list(manager._latent.action_history)
|
| 49 |
step_before = len(history_before)
|
|
@@ -62,9 +58,7 @@ class TestInvalidActionReturnsNegativeRValidity:
|
|
| 62 |
assert len(obs.rule_violations) > 0
|
| 63 |
assert len(info["violations"]) > 0
|
| 64 |
|
| 65 |
-
def test_invalid_action_done_is_false(
|
| 66 |
-
self, manager: EpisodeManager
|
| 67 |
-
) -> None:
|
| 68 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 69 |
_, _, done, _ = manager.step(action)
|
| 70 |
assert done is False
|
|
@@ -108,9 +102,7 @@ class TestNoUnhandledExceptions:
|
|
| 108 |
with pytest.raises(RuntimeError, match="No active episode"):
|
| 109 |
em.step(action)
|
| 110 |
|
| 111 |
-
def test_multiple_invalid_steps_do_not_raise(
|
| 112 |
-
self, manager: EpisodeManager
|
| 113 |
-
) -> None:
|
| 114 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 115 |
for _ in range(5):
|
| 116 |
_, reward, _, _ = manager.step(action)
|
|
|
|
| 33 |
class TestInvalidActionReturnsNegativeRValidity:
|
| 34 |
"""Requirement 10.1: invalid actions → negative r_validity, latent unchanged."""
|
| 35 |
|
| 36 |
+
def test_invalid_action_r_validity_negative(self, manager: EpisodeManager) -> None:
|
|
|
|
|
|
|
| 37 |
# SUBMIT_TO_FDA_REVIEW not permitted in literature_review phase
|
| 38 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 39 |
_, reward, _, _ = manager.step(action)
|
| 40 |
assert reward.r_validity < 0, "r_validity must be negative for invalid action"
|
| 41 |
|
| 42 |
+
def test_invalid_action_state_unchanged(self, manager: EpisodeManager) -> None:
|
|
|
|
|
|
|
| 43 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 44 |
history_before = list(manager._latent.action_history)
|
| 45 |
step_before = len(history_before)
|
|
|
|
| 58 |
assert len(obs.rule_violations) > 0
|
| 59 |
assert len(info["violations"]) > 0
|
| 60 |
|
| 61 |
+
def test_invalid_action_done_is_false(self, manager: EpisodeManager) -> None:
|
|
|
|
|
|
|
| 62 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 63 |
_, _, done, _ = manager.step(action)
|
| 64 |
assert done is False
|
|
|
|
| 102 |
with pytest.raises(RuntimeError, match="No active episode"):
|
| 103 |
em.step(action)
|
| 104 |
|
| 105 |
+
def test_multiple_invalid_steps_do_not_raise(self, manager: EpisodeManager) -> None:
|
|
|
|
|
|
|
| 106 |
action = _make_action(ActionType.SUBMIT_TO_FDA_REVIEW)
|
| 107 |
for _ in range(5):
|
| 108 |
_, reward, _, _ = manager.step(action)
|
tests/test_judge.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for server/judge.py — TrialJudge multi-layer verification.
|
| 3 |
+
|
| 4 |
+
Covers:
|
| 5 |
+
- Layer 1 programmatic checks (power, p-value, FDA compliance, budget)
|
| 6 |
+
- Layer 2 persona selection (junior/senior/principal)
|
| 7 |
+
- Overconfidence penalty
|
| 8 |
+
- Hint generation for junior persona
|
| 9 |
+
- No unhandled exceptions on any valid input (req 10.4)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import pytest
|
| 15 |
+
|
| 16 |
+
from models import ActionType, TrialAction, TrialLatentState, TrialState
|
| 17 |
+
from server.judge import JudgeResult, TrialJudge, _select_persona
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Fixtures
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _make_latent(**overrides) -> TrialLatentState:
    """Build a TrialLatentState for judge tests.

    The baseline has every milestone flag set to True, the phase set to
    "analysis", and a positive budget. Any keyword argument replaces the
    baseline field of the same name.
    """
    baseline = {
        "true_effect_size": 0.8,
        "true_side_effect_rate": 0.05,
        "true_responder_population": "all",
        "true_responder_criteria": [],
        "true_dose_response": {},
        "true_mechanism": "unknown",
        "placebo_response_rate": 0.1,
        "dropout_rate": 0.05,
        "site_variability": 0.0,
        "measurement_noise": 0.0,
        "budget_remaining": 500_000.0,
        "time_remaining_days": 300,
        "patients_enrolled": 200,
        "phase_i_complete": True,
        "mtd_identified": True,
        "effect_estimated": True,
        "protocol_submitted": True,
        "interim_complete": True,
        "trial_complete": True,
        "adverse_events": 0,
        "episode_phase": "analysis",
        "action_history": ["run_primary_analysis"],
        "seed": 42,
    }
    # Later keys win, so caller-supplied overrides replace baseline values.
    return TrialLatentState(**{**baseline, **overrides})
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _make_state(difficulty: float = 0.3) -> TrialState:
    """Build a TrialState for judge tests; only ``difficulty`` is configurable."""
    fields = {
        "episode_id": "test-ep",
        "step_count": 5,
        "difficulty": difficulty,
        "scenario_id": "solid_tumor_chemo",
        "curriculum_tier": "0",
        "curriculum_stats": {},
        "action_diversity": 0.8,
        "phase_compliance_rate": 1.0,
        "is_resolved": False,
    }
    return TrialState(**fields)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _make_action(
    action_type: ActionType = ActionType.RUN_PRIMARY_ANALYSIS,
    confidence: float = 0.5,
    **params,
) -> TrialAction:
    """Build a TrialAction for judge tests.

    Extra keyword arguments become the action's ``parameters`` dict; the
    justification is always the fixed string "test".
    """
    action = TrialAction(
        action_type=action_type,
        parameters=params,
        justification="test",
        confidence=confidence,
    )
    return action
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
# Persona selection
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def test_persona_junior():
    """Difficulties 0.0 and 0.39 both select the junior persona."""
    for difficulty in (0.0, 0.39):
        assert _select_persona(difficulty) == "junior"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def test_persona_senior():
    """Difficulties 0.4 and 0.7 both select the senior persona."""
    for difficulty in (0.4, 0.7):
        assert _select_persona(difficulty) == "senior"
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_persona_principal():
    """Difficulties 0.71 and 1.0 both select the principal persona."""
    for difficulty in (0.71, 1.0):
        assert _select_persona(difficulty) == "principal"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ---------------------------------------------------------------------------
|
| 102 |
+
# Layer 1: budget check
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def test_budget_exhausted_fails():
    """A remaining budget of exactly zero fails with a budget violation."""
    judge = TrialJudge()
    broke = _make_latent(budget_remaining=0.0)
    verdict = judge.verify(_make_action(), _make_state(), broke)
    assert not verdict.passed
    budget_messages = [msg for msg in verdict.violations if "budget" in msg.lower()]
    assert budget_messages
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def test_budget_negative_fails():
    """A negative remaining budget fails with a budget violation."""
    judge = TrialJudge()
    overdrawn = _make_latent(budget_remaining=-100.0)
    verdict = judge.verify(_make_action(), _make_state(), overdrawn)
    assert not verdict.passed
    budget_messages = [msg for msg in verdict.violations if "budget" in msg.lower()]
    assert budget_messages
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_budget_positive_passes_budget_check():
    """Any positive budget, however small, raises no budget violation."""
    judge = TrialJudge()
    barely_funded = _make_latent(budget_remaining=1.0)
    verdict = judge.verify(_make_action(), _make_state(), barely_funded)
    # Other checks may still fail; only the budget message must be absent.
    assert all("budget" not in msg.lower() for msg in verdict.violations)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
# Layer 1: power check
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def test_low_power_fails():
    """A tiny effect with few patients fails with a power violation."""
    judge = TrialJudge()
    underpowered = _make_latent(true_effect_size=0.01, patients_enrolled=10)
    verdict = judge.verify(_make_action(), _make_state(), underpowered)
    assert not verdict.passed
    power_messages = [msg for msg in verdict.violations if "power" in msg.lower()]
    assert power_messages
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def test_sufficient_power_no_power_violation():
    """A large effect with many patients raises no power violation."""
    judge = TrialJudge()
    well_powered = _make_latent(true_effect_size=1.5, patients_enrolled=500)
    verdict = judge.verify(_make_action(), _make_state(), well_powered)
    assert all("power" not in msg.lower() for msg in verdict.violations)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ---------------------------------------------------------------------------
|
| 153 |
+
# Layer 1: p-value check
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def test_nonsignificant_pvalue_fails():
    """A true effect of zero fails with a p-value violation."""
    judge = TrialJudge()
    # The original author notes that a zero effect gives a p-value of 1.0.
    null_effect = _make_latent(true_effect_size=0.0, patients_enrolled=100)
    verdict = judge.verify(_make_action(), _make_state(), null_effect)
    assert not verdict.passed
    pvalue_messages = [msg for msg in verdict.violations if "p-value" in msg.lower()]
    assert pvalue_messages
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def test_significant_pvalue_no_pvalue_violation():
    """A large effect with many patients raises no p-value violation."""
    judge = TrialJudge()
    strong_signal = _make_latent(true_effect_size=2.0, patients_enrolled=1000)
    verdict = judge.verify(_make_action(), _make_state(), strong_signal)
    assert all("p-value" not in msg.lower() for msg in verdict.violations)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
# Layer 1: FDA compliance
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def test_fda_violation_propagated():
    """An action not permitted in the current phase fails with a violation."""
    judge = TrialJudge()
    # Submitting to FDA review during literature review is out of phase.
    early_phase = _make_latent(episode_phase="literature_review")
    premature_submission = _make_action(action_type=ActionType.SUBMIT_TO_FDA_REVIEW)
    verdict = judge.verify(premature_submission, _make_state(), early_phase)
    assert not verdict.passed
    assert len(verdict.violations) >= 1
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
# Overconfidence penalty
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def test_overconfidence_penalty_applied_when_high_confidence_and_wrong():
    """A failing action submitted at confidence 0.9 gets a negative penalty."""
    judge = TrialJudge()
    # A zero budget guarantees at least one violation.
    failing_latent = _make_latent(budget_remaining=0.0)
    confident_action = _make_action(confidence=0.9)
    verdict = judge.verify(confident_action, _make_state(), failing_latent)
    assert not verdict.passed
    assert verdict.overconfidence_penalty < 0.0
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def test_no_overconfidence_penalty_when_low_confidence():
    """A failing action submitted at confidence 0.5 gets no penalty."""
    judge = TrialJudge()
    # A zero budget guarantees a violation is present.
    failing_latent = _make_latent(budget_remaining=0.0)
    cautious_action = _make_action(confidence=0.5)
    verdict = judge.verify(cautious_action, _make_state(), failing_latent)
    assert verdict.overconfidence_penalty == 0.0
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def test_no_overconfidence_penalty_when_passed():
    """A passing verification must carry no overconfidence penalty.

    The scenario uses a large effect, many patients, a healthy budget and a
    completed trial in the analysis phase. If the judge still rejects it, the
    test is skipped with the violations shown rather than passing silently
    without having asserted anything.
    """
    judge = TrialJudge()
    # Use large effect + many patients to pass power/p-value, valid phase/action
    latent = _make_latent(
        true_effect_size=2.0,
        patients_enrolled=1000,
        budget_remaining=500_000.0,
        episode_phase="analysis",
        interim_complete=True,
        trial_complete=True,
    )
    action = _make_action(action_type=ActionType.RUN_PRIMARY_ANALYSIS, confidence=0.95)
    result = judge.verify(action, _make_state(), latent)
    if not result.passed:
        # Surface the unmet precondition instead of reporting a vacuous pass.
        pytest.skip(f"scenario did not pass verification: {result.violations}")
    assert result.overconfidence_penalty == 0.0
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def test_overconfidence_penalty_scales_with_violation_count():
    """Several simultaneous violations at confidence 0.9 give a penalty <= -1.0."""
    judge = TrialJudge()
    # Intended to trigger budget, low-power and non-significant p-value at once.
    many_problems = _make_latent(
        budget_remaining=0.0,
        true_effect_size=0.0,
        patients_enrolled=1,
    )
    confident_action = _make_action(confidence=0.9)
    verdict = judge.verify(confident_action, _make_state(), many_problems)
    # At least 2 violations × -0.5 each.
    assert verdict.overconfidence_penalty <= -1.0
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ---------------------------------------------------------------------------
|
| 242 |
+
# Layer 2: persona in result
|
| 243 |
+
# ---------------------------------------------------------------------------
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def test_junior_persona_in_result():
    """At difficulty 0.2 the result reports the junior persona."""
    judge = TrialJudge()
    failing_latent = _make_latent(budget_remaining=0.0)
    easy_state = _make_state(difficulty=0.2)
    verdict = judge.verify(_make_action(), easy_state, failing_latent)
    assert verdict.persona == "junior"
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def test_senior_persona_in_result():
    """At difficulty 0.5 the result reports the senior persona."""
    judge = TrialJudge()
    failing_latent = _make_latent(budget_remaining=0.0)
    medium_state = _make_state(difficulty=0.5)
    verdict = judge.verify(_make_action(), medium_state, failing_latent)
    assert verdict.persona == "senior"
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def test_principal_persona_in_result():
    """At difficulty 0.9 the result reports the principal persona."""
    judge = TrialJudge()
    failing_latent = _make_latent(budget_remaining=0.0)
    hard_state = _make_state(difficulty=0.9)
    verdict = judge.verify(_make_action(), hard_state, failing_latent)
    assert verdict.persona == "principal"
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# ---------------------------------------------------------------------------
|
| 268 |
+
# Layer 2: hints
|
| 269 |
+
# ---------------------------------------------------------------------------
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def test_junior_gets_hint_on_failure():
    """A failed verification at difficulty 0.2 comes with a non-empty hint."""
    judge = TrialJudge()
    failing_latent = _make_latent(budget_remaining=0.0)
    verdict = judge.verify(_make_action(), _make_state(difficulty=0.2), failing_latent)
    assert not verdict.passed
    assert verdict.hint is not None
    assert len(verdict.hint) > 0
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def test_senior_no_hint_on_failure():
    """A failed verification at difficulty 0.5 carries no hint."""
    judge = TrialJudge()
    failing_latent = _make_latent(budget_remaining=0.0)
    medium_state = _make_state(difficulty=0.5)
    verdict = judge.verify(_make_action(), medium_state, failing_latent)
    assert verdict.hint is None
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def test_principal_no_hint_on_failure():
    """A failed verification at difficulty 0.9 carries no hint."""
    judge = TrialJudge()
    failing_latent = _make_latent(budget_remaining=0.0)
    hard_state = _make_state(difficulty=0.9)
    verdict = judge.verify(_make_action(), hard_state, failing_latent)
    assert verdict.hint is None
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def test_junior_gets_hint_on_pass():
    """A passing verification at difficulty 0.2 must still include a hint.

    If the scenario unexpectedly fails verification, the test is skipped with
    the violations shown rather than passing silently without having asserted
    anything.
    """
    judge = TrialJudge()
    latent = _make_latent(
        true_effect_size=2.0,
        patients_enrolled=1000,
        budget_remaining=500_000.0,
        episode_phase="analysis",
        interim_complete=True,
        trial_complete=True,
    )
    action = _make_action(action_type=ActionType.RUN_PRIMARY_ANALYSIS)
    result = judge.verify(action, _make_state(difficulty=0.2), latent)
    if not result.passed:
        # Surface the unmet precondition instead of reporting a vacuous pass.
        pytest.skip(f"scenario did not pass verification: {result.violations}")
    assert result.hint is not None
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
# JudgeResult model
|
| 312 |
+
# ---------------------------------------------------------------------------
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def test_judge_result_is_pydantic_model():
    """JudgeResult accepts all six fields as keywords and exposes them back."""
    field_values = {
        "passed": True,
        "violations": [],
        "feedback": "ok",
        "hint": None,
        "overconfidence_penalty": 0.0,
        "persona": "senior",
    }
    result = JudgeResult(**field_values)
    assert result.passed is True
    assert result.persona == "senior"
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
# ---------------------------------------------------------------------------
|
| 329 |
+
# Req 10.4: no unhandled exceptions
|
| 330 |
+
# ---------------------------------------------------------------------------
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@pytest.mark.parametrize("action_type", list(ActionType))
def test_no_exception_for_any_action_type(action_type):
    """TrialJudge.verify must never raise for any valid action type (req 10.4)."""
    judge = TrialJudge()
    latent = _make_latent()
    state = _make_state()
    # With no extra kwargs the helper builds empty parameters and "test" justification.
    action = _make_action(action_type=action_type, confidence=0.5)
    # An exception here fails the test; otherwise a JudgeResult must come back.
    outcome = judge.verify(action, state, latent)
    assert isinstance(outcome, JudgeResult)
|
tests/test_noise_model.py
CHANGED
|
@@ -44,16 +44,12 @@ class TestNoiseModelIdempotence:
|
|
| 44 |
r2 = NoiseModel(seed=42).randomize(base_scenario)
|
| 45 |
assert r1.time_budget_days == r2.time_budget_days
|
| 46 |
|
| 47 |
-
def test_same_seed_same_dropout_range(
|
| 48 |
-
self, base_scenario: ScenarioConfig
|
| 49 |
-
) -> None:
|
| 50 |
r1 = NoiseModel(seed=42).randomize(base_scenario)
|
| 51 |
r2 = NoiseModel(seed=42).randomize(base_scenario)
|
| 52 |
assert r1.dropout_rate_range == r2.dropout_rate_range
|
| 53 |
|
| 54 |
-
def test_same_seed_same_placebo_range(
|
| 55 |
-
self, base_scenario: ScenarioConfig
|
| 56 |
-
) -> None:
|
| 57 |
r1 = NoiseModel(seed=42).randomize(base_scenario)
|
| 58 |
r2 = NoiseModel(seed=42).randomize(base_scenario)
|
| 59 |
assert r1.placebo_response_range == r2.placebo_response_range
|
|
@@ -114,9 +110,7 @@ class TestNoiseModelRanges:
|
|
| 114 |
assert result.side_effect_rate_range == base_scenario.side_effect_rate_range
|
| 115 |
assert result.min_sample_size == base_scenario.min_sample_size
|
| 116 |
|
| 117 |
-
def test_time_budget_at_least_one_day(
|
| 118 |
-
self, base_scenario: ScenarioConfig
|
| 119 |
-
) -> None:
|
| 120 |
for seed in range(50):
|
| 121 |
result = NoiseModel(seed=seed).randomize(base_scenario)
|
| 122 |
assert result.time_budget_days >= 1
|
|
|
|
| 44 |
r2 = NoiseModel(seed=42).randomize(base_scenario)
|
| 45 |
assert r1.time_budget_days == r2.time_budget_days
|
| 46 |
|
| 47 |
+
def test_same_seed_same_dropout_range(self, base_scenario: ScenarioConfig) -> None:
|
|
|
|
|
|
|
| 48 |
r1 = NoiseModel(seed=42).randomize(base_scenario)
|
| 49 |
r2 = NoiseModel(seed=42).randomize(base_scenario)
|
| 50 |
assert r1.dropout_rate_range == r2.dropout_rate_range
|
| 51 |
|
| 52 |
+
def test_same_seed_same_placebo_range(self, base_scenario: ScenarioConfig) -> None:
|
|
|
|
|
|
|
| 53 |
r1 = NoiseModel(seed=42).randomize(base_scenario)
|
| 54 |
r2 = NoiseModel(seed=42).randomize(base_scenario)
|
| 55 |
assert r1.placebo_response_range == r2.placebo_response_range
|
|
|
|
| 110 |
assert result.side_effect_rate_range == base_scenario.side_effect_rate_range
|
| 111 |
assert result.min_sample_size == base_scenario.min_sample_size
|
| 112 |
|
| 113 |
+
def test_time_budget_at_least_one_day(self, base_scenario: ScenarioConfig) -> None:
|
|
|
|
|
|
|
| 114 |
for seed in range(50):
|
| 115 |
result = NoiseModel(seed=seed).randomize(base_scenario)
|
| 116 |
assert result.time_budget_days >= 1
|
tests/test_output_generator.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for OutputGenerator — noisy TrialObservation generation (Task 15).
|
| 3 |
+
|
| 4 |
+
Requirements 9.1, 9.2, 9.3, 9.4:
|
| 5 |
+
- OutputGenerator produces a TrialObservation from a TrialLatentState
|
| 6 |
+
- Agent never sees raw hidden values (noise is always injected)
|
| 7 |
+
- phase_data, resource_status, available_actions are correctly populated
|
| 8 |
+
- Measurement noise and site variability are applied via NoiseModel
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import pytest
|
| 14 |
+
|
| 15 |
+
from models import ActionType, TrialLatentState, TrialState
|
| 16 |
+
from server.noise_model import NoiseModel
|
| 17 |
+
from server.simulator.output_generator import OutputGenerator
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Fixtures
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture()
|
| 25 |
+
def base_latent() -> TrialLatentState:
|
| 26 |
+
"""A minimal TrialLatentState for testing."""
|
| 27 |
+
return TrialLatentState(
|
| 28 |
+
true_effect_size=0.5,
|
| 29 |
+
true_side_effect_rate=0.10,
|
| 30 |
+
true_responder_population="BRCA1+",
|
| 31 |
+
true_responder_criteria=["BRCA1+", "age < 65"],
|
| 32 |
+
true_dose_response={10.0: 0.2, 20.0: 0.4, 40.0: 0.7},
|
| 33 |
+
true_mechanism="PARP inhibition",
|
| 34 |
+
placebo_response_rate=0.15,
|
| 35 |
+
dropout_rate=0.08,
|
| 36 |
+
site_variability=0.05,
|
| 37 |
+
measurement_noise=0.05,
|
| 38 |
+
budget_remaining=500_000.0,
|
| 39 |
+
time_remaining_days=200,
|
| 40 |
+
patients_enrolled=0,
|
| 41 |
+
phase_i_complete=False,
|
| 42 |
+
mtd_identified=False,
|
| 43 |
+
effect_estimated=False,
|
| 44 |
+
protocol_submitted=False,
|
| 45 |
+
interim_complete=False,
|
| 46 |
+
trial_complete=False,
|
| 47 |
+
adverse_events=0,
|
| 48 |
+
episode_phase="design",
|
| 49 |
+
action_history=[],
|
| 50 |
+
seed=42,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@pytest.fixture()
|
| 55 |
+
def trial_state() -> TrialState:
|
| 56 |
+
return TrialState(
|
| 57 |
+
episode_id="ep-001",
|
| 58 |
+
step_count=1,
|
| 59 |
+
difficulty=0.5,
|
| 60 |
+
scenario_id="solid_tumor_chemo",
|
| 61 |
+
curriculum_tier="tier_0",
|
| 62 |
+
curriculum_stats={},
|
| 63 |
+
action_diversity=0.0,
|
| 64 |
+
phase_compliance_rate=1.0,
|
| 65 |
+
is_resolved=False,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@pytest.fixture()
|
| 70 |
+
def generator() -> OutputGenerator:
|
| 71 |
+
return OutputGenerator(noise_model=NoiseModel(seed=42))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _make_obs(generator, latent, trial_state, **kwargs):
|
| 75 |
+
defaults = dict(
|
| 76 |
+
steps_taken=1,
|
| 77 |
+
max_steps=20,
|
| 78 |
+
rule_violations=[],
|
| 79 |
+
done=False,
|
| 80 |
+
reward=0.0,
|
| 81 |
+
scenario_description="Test scenario",
|
| 82 |
+
hint="",
|
| 83 |
+
)
|
| 84 |
+
defaults.update(kwargs)
|
| 85 |
+
return generator.generate(latent, trial_state, **defaults)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
# Basic structure tests
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TestObservationStructure:
|
| 94 |
+
"""TrialObservation has all required fields populated."""
|
| 95 |
+
|
| 96 |
+
def test_returns_trial_observation(self, generator, base_latent, trial_state):
|
| 97 |
+
from models import TrialObservation
|
| 98 |
+
|
| 99 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 100 |
+
assert isinstance(obs, TrialObservation)
|
| 101 |
+
|
| 102 |
+
def test_scenario_description_passed_through(
|
| 103 |
+
self, generator, base_latent, trial_state
|
| 104 |
+
):
|
| 105 |
+
obs = _make_obs(
|
| 106 |
+
generator, base_latent, trial_state, scenario_description="My scenario"
|
| 107 |
+
)
|
| 108 |
+
assert obs.scenario_description == "My scenario"
|
| 109 |
+
|
| 110 |
+
def test_steps_taken_and_max_steps(self, generator, base_latent, trial_state):
|
| 111 |
+
obs = _make_obs(
|
| 112 |
+
generator, base_latent, trial_state, steps_taken=5, max_steps=30
|
| 113 |
+
)
|
| 114 |
+
assert obs.steps_taken == 5
|
| 115 |
+
assert obs.max_steps == 30
|
| 116 |
+
|
| 117 |
+
def test_done_and_reward_passed_through(self, generator, base_latent, trial_state):
|
| 118 |
+
obs = _make_obs(generator, base_latent, trial_state, done=True, reward=1.5)
|
| 119 |
+
assert obs.done is True
|
| 120 |
+
assert obs.reward == 1.5
|
| 121 |
+
|
| 122 |
+
def test_rule_violations_passed_through(self, generator, base_latent, trial_state):
|
| 123 |
+
violations = ["violation A", "violation B"]
|
| 124 |
+
obs = _make_obs(generator, base_latent, trial_state, rule_violations=violations)
|
| 125 |
+
assert obs.rule_violations == violations
|
| 126 |
+
|
| 127 |
+
def test_hint_passed_through(self, generator, base_latent, trial_state):
|
| 128 |
+
obs = _make_obs(generator, base_latent, trial_state, hint="Try Phase I first")
|
| 129 |
+
assert obs.hint == "Try Phase I first"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# ---------------------------------------------------------------------------
|
| 133 |
+
# resource_status tests
|
| 134 |
+
# ---------------------------------------------------------------------------
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class TestResourceStatus:
|
| 138 |
+
"""resource_status reflects latent state resource fields."""
|
| 139 |
+
|
| 140 |
+
def test_budget_remaining(self, generator, base_latent, trial_state):
|
| 141 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 142 |
+
assert obs.resource_status["budget_remaining"] == base_latent.budget_remaining
|
| 143 |
+
|
| 144 |
+
def test_time_remaining_days(self, generator, base_latent, trial_state):
|
| 145 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 146 |
+
assert (
|
| 147 |
+
obs.resource_status["time_remaining_days"]
|
| 148 |
+
== base_latent.time_remaining_days
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def test_patients_enrolled(self, generator, base_latent, trial_state):
|
| 152 |
+
latent = base_latent.model_copy(update={"patients_enrolled": 50})
|
| 153 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 154 |
+
assert obs.resource_status["patients_enrolled"] == 50
|
| 155 |
+
|
| 156 |
+
def test_resource_status_has_three_keys(self, generator, base_latent, trial_state):
|
| 157 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 158 |
+
assert set(obs.resource_status.keys()) == {
|
| 159 |
+
"budget_remaining",
|
| 160 |
+
"time_remaining_days",
|
| 161 |
+
"patients_enrolled",
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# phase_data tests — noise injection
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class TestPhaseDataNoiseInjection:
|
| 171 |
+
"""Agent never sees raw hidden values — noise is always injected."""
|
| 172 |
+
|
| 173 |
+
def test_true_effect_size_not_in_phase_data(
|
| 174 |
+
self, generator, base_latent, trial_state
|
| 175 |
+
):
|
| 176 |
+
"""Raw true_effect_size must never appear directly in phase_data."""
|
| 177 |
+
latent = base_latent.model_copy(update={"effect_estimated": True})
|
| 178 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 179 |
+
# observed_effect_size should differ from true value (noise injected)
|
| 180 |
+
# We can't assert inequality (the noise draw could coincidentally be ~0), so only check that the noisy key is present
|
| 181 |
+
assert "observed_effect_size" in obs.phase_data
|
| 182 |
+
|
| 183 |
+
def test_effect_size_not_exposed_before_estimation(
|
| 184 |
+
self, generator, base_latent, trial_state
|
| 185 |
+
):
|
| 186 |
+
"""observed_effect_size should not appear before ESTIMATE_EFFECT_SIZE."""
|
| 187 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 188 |
+
assert "observed_effect_size" not in obs.phase_data
|
| 189 |
+
|
| 190 |
+
def test_effect_size_exposed_after_estimation(
|
| 191 |
+
self, generator, base_latent, trial_state
|
| 192 |
+
):
|
| 193 |
+
latent = base_latent.model_copy(update={"effect_estimated": True})
|
| 194 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 195 |
+
assert "observed_effect_size" in obs.phase_data
|
| 196 |
+
assert "effect_size_ci" in obs.phase_data
|
| 197 |
+
|
| 198 |
+
def test_ae_rate_not_exposed_before_phase_i(
|
| 199 |
+
self, generator, base_latent, trial_state
|
| 200 |
+
):
|
| 201 |
+
"""Adverse event rate should not appear before Phase I or safety signal."""
|
| 202 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 203 |
+
assert "observed_adverse_event_rate" not in obs.phase_data
|
| 204 |
+
|
| 205 |
+
def test_ae_rate_exposed_after_phase_i(self, generator, base_latent, trial_state):
|
| 206 |
+
latent = base_latent.model_copy(update={"phase_i_complete": True})
|
| 207 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 208 |
+
assert "observed_adverse_event_rate" in obs.phase_data
|
| 209 |
+
|
| 210 |
+
def test_ae_rate_exposed_after_safety_signal(
|
| 211 |
+
self, generator, base_latent, trial_state
|
| 212 |
+
):
|
| 213 |
+
latent = base_latent.model_copy(
|
| 214 |
+
update={"action_history": [ActionType.OBSERVE_SAFETY_SIGNAL.value]}
|
| 215 |
+
)
|
| 216 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 217 |
+
assert "observed_adverse_event_rate" in obs.phase_data
|
| 218 |
+
|
| 219 |
+
def test_ae_rate_is_clipped_to_0_1(self, generator, base_latent, trial_state):
|
| 220 |
+
latent = base_latent.model_copy(
|
| 221 |
+
update={"phase_i_complete": True, "true_side_effect_rate": 0.99}
|
| 222 |
+
)
|
| 223 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 224 |
+
rate = obs.phase_data["observed_adverse_event_rate"]
|
| 225 |
+
assert 0.0 <= rate <= 1.0
|
| 226 |
+
|
| 227 |
+
def test_placebo_response_not_exposed_before_interim(
|
| 228 |
+
self, generator, base_latent, trial_state
|
| 229 |
+
):
|
| 230 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 231 |
+
assert "observed_placebo_response" not in obs.phase_data
|
| 232 |
+
|
| 233 |
+
def test_placebo_response_exposed_after_interim(
|
| 234 |
+
self, generator, base_latent, trial_state
|
| 235 |
+
):
|
| 236 |
+
latent = base_latent.model_copy(update={"interim_complete": True})
|
| 237 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 238 |
+
assert "observed_placebo_response" in obs.phase_data
|
| 239 |
+
|
| 240 |
+
def test_dose_response_not_exposed_before_phase_i(
|
| 241 |
+
self, generator, base_latent, trial_state
|
| 242 |
+
):
|
| 243 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 244 |
+
assert "observed_dose_response" not in obs.phase_data
|
| 245 |
+
|
| 246 |
+
def test_dose_response_exposed_after_phase_i(
|
| 247 |
+
self, generator, base_latent, trial_state
|
| 248 |
+
):
|
| 249 |
+
latent = base_latent.model_copy(update={"phase_i_complete": True})
|
| 250 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 251 |
+
assert "observed_dose_response" in obs.phase_data
|
| 252 |
+
# All dose-response values should be clipped to [0, 1]
|
| 253 |
+
for v in obs.phase_data["observed_dose_response"].values():
|
| 254 |
+
assert 0.0 <= v <= 1.0
|
| 255 |
+
|
| 256 |
+
def test_dropout_rate_not_exposed_before_enrollment(
|
| 257 |
+
self, generator, base_latent, trial_state
|
| 258 |
+
):
|
| 259 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 260 |
+
assert "observed_dropout_rate" not in obs.phase_data
|
| 261 |
+
|
| 262 |
+
def test_dropout_rate_exposed_after_enrollment(
|
| 263 |
+
self, generator, base_latent, trial_state
|
| 264 |
+
):
|
| 265 |
+
latent = base_latent.model_copy(update={"patients_enrolled": 10})
|
| 266 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 267 |
+
assert "observed_dropout_rate" in obs.phase_data
|
| 268 |
+
|
| 269 |
+
def test_responder_population_hint_not_exposed_without_biomarker(
|
| 270 |
+
self, generator, base_latent, trial_state
|
| 271 |
+
):
|
| 272 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 273 |
+
assert "responder_population_hint" not in obs.phase_data
|
| 274 |
+
|
| 275 |
+
def test_responder_population_hint_exposed_after_biomarker(
|
| 276 |
+
self, generator, base_latent, trial_state
|
| 277 |
+
):
|
| 278 |
+
latent = base_latent.model_copy(
|
| 279 |
+
update={"action_history": [ActionType.ADD_BIOMARKER_STRATIFICATION.value]}
|
| 280 |
+
)
|
| 281 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 282 |
+
assert "responder_population_hint" in obs.phase_data
|
| 283 |
+
# Population label is revealed but NOT the true criteria
|
| 284 |
+
assert obs.phase_data["responder_population_hint"] == "BRCA1+"
|
| 285 |
+
assert "true_responder_criteria" not in obs.phase_data
|
| 286 |
+
|
| 287 |
+
def test_milestone_flags_in_phase_data(self, generator, base_latent, trial_state):
|
| 288 |
+
"""Milestone flags are observable (not hidden values)."""
|
| 289 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 290 |
+
assert "phase_i_complete" in obs.phase_data
|
| 291 |
+
assert "mtd_identified" in obs.phase_data
|
| 292 |
+
assert "effect_estimated" in obs.phase_data
|
| 293 |
+
assert "protocol_submitted" in obs.phase_data
|
| 294 |
+
assert "interim_complete" in obs.phase_data
|
| 295 |
+
assert "trial_complete" in obs.phase_data
|
| 296 |
+
|
| 297 |
+
def test_true_mechanism_not_in_phase_data(
|
| 298 |
+
self, generator, base_latent, trial_state
|
| 299 |
+
):
|
| 300 |
+
"""true_mechanism is a hidden value and must never appear in phase_data."""
|
| 301 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 302 |
+
assert "true_mechanism" not in obs.phase_data
|
| 303 |
+
|
| 304 |
+
def test_true_responder_criteria_not_in_phase_data(
|
| 305 |
+
self, generator, base_latent, trial_state
|
| 306 |
+
):
|
| 307 |
+
"""true_responder_criteria is hidden and must never appear in phase_data."""
|
| 308 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 309 |
+
assert "true_responder_criteria" not in obs.phase_data
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# ---------------------------------------------------------------------------
|
| 313 |
+
# available_actions tests
|
| 314 |
+
# ---------------------------------------------------------------------------
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class TestAvailableActions:
|
| 318 |
+
"""available_actions reflects phase-permitted actions filtered by prerequisites."""
|
| 319 |
+
|
| 320 |
+
def test_available_actions_is_list_of_strings(
|
| 321 |
+
self, generator, base_latent, trial_state
|
| 322 |
+
):
|
| 323 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 324 |
+
assert isinstance(obs.available_actions, list)
|
| 325 |
+
assert all(isinstance(a, str) for a in obs.available_actions)
|
| 326 |
+
|
| 327 |
+
def test_design_phase_actions(self, generator, base_latent, trial_state):
|
| 328 |
+
"""In design phase with empty history, basic design actions are available."""
|
| 329 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 330 |
+
# SET_SAMPLE_SIZE, SET_INCLUSION_CRITERIA, SET_EXCLUSION_CRITERIA should be available
|
| 331 |
+
assert ActionType.SET_SAMPLE_SIZE.value in obs.available_actions
|
| 332 |
+
assert ActionType.SET_INCLUSION_CRITERIA.value in obs.available_actions
|
| 333 |
+
assert ActionType.SET_EXCLUSION_CRITERIA.value in obs.available_actions
|
| 334 |
+
|
| 335 |
+
def test_dosing_schedule_requires_primary_endpoint(
|
| 336 |
+
self, generator, base_latent, trial_state
|
| 337 |
+
):
|
| 338 |
+
"""SET_DOSING_SCHEDULE requires SET_PRIMARY_ENDPOINT in history."""
|
| 339 |
+
obs = _make_obs(generator, base_latent, trial_state)
|
| 340 |
+
# Without SET_PRIMARY_ENDPOINT in history, SET_DOSING_SCHEDULE should not be available
|
| 341 |
+
assert ActionType.SET_DOSING_SCHEDULE.value not in obs.available_actions
|
| 342 |
+
|
| 343 |
+
def test_dosing_schedule_available_after_primary_endpoint(
|
| 344 |
+
self, generator, base_latent, trial_state
|
| 345 |
+
):
|
| 346 |
+
latent = base_latent.model_copy(
|
| 347 |
+
update={"action_history": [ActionType.SET_PRIMARY_ENDPOINT.value]}
|
| 348 |
+
)
|
| 349 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 350 |
+
assert ActionType.SET_DOSING_SCHEDULE.value in obs.available_actions
|
| 351 |
+
|
| 352 |
+
def test_synthesize_conclusion_requires_trial_complete(
|
| 353 |
+
self, generator, base_latent, trial_state
|
| 354 |
+
):
|
| 355 |
+
latent = base_latent.model_copy(
|
| 356 |
+
update={"episode_phase": "submission", "trial_complete": False}
|
| 357 |
+
)
|
| 358 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 359 |
+
assert ActionType.SYNTHESIZE_CONCLUSION.value not in obs.available_actions
|
| 360 |
+
|
| 361 |
+
def test_synthesize_conclusion_available_when_trial_complete(
|
| 362 |
+
self, generator, base_latent, trial_state
|
| 363 |
+
):
|
| 364 |
+
latent = base_latent.model_copy(
|
| 365 |
+
update={"episode_phase": "submission", "trial_complete": True}
|
| 366 |
+
)
|
| 367 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 368 |
+
assert ActionType.SYNTHESIZE_CONCLUSION.value in obs.available_actions
|
| 369 |
+
|
| 370 |
+
def test_run_interim_analysis_requires_patients(
|
| 371 |
+
self, generator, base_latent, trial_state
|
| 372 |
+
):
|
| 373 |
+
latent = base_latent.model_copy(
|
| 374 |
+
update={"episode_phase": "monitoring", "patients_enrolled": 0}
|
| 375 |
+
)
|
| 376 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 377 |
+
assert ActionType.RUN_INTERIM_ANALYSIS.value not in obs.available_actions
|
| 378 |
+
|
| 379 |
+
def test_run_interim_analysis_available_with_patients(
|
| 380 |
+
self, generator, base_latent, trial_state
|
| 381 |
+
):
|
| 382 |
+
latent = base_latent.model_copy(
|
| 383 |
+
update={"episode_phase": "monitoring", "patients_enrolled": 50}
|
| 384 |
+
)
|
| 385 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 386 |
+
assert ActionType.RUN_INTERIM_ANALYSIS.value in obs.available_actions
|
| 387 |
+
|
| 388 |
+
def test_run_primary_analysis_requires_interim_complete(
|
| 389 |
+
self, generator, base_latent, trial_state
|
| 390 |
+
):
|
| 391 |
+
latent = base_latent.model_copy(
|
| 392 |
+
update={"episode_phase": "analysis", "interim_complete": False}
|
| 393 |
+
)
|
| 394 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 395 |
+
assert ActionType.RUN_PRIMARY_ANALYSIS.value not in obs.available_actions
|
| 396 |
+
|
| 397 |
+
def test_run_primary_analysis_available_after_interim(
|
| 398 |
+
self, generator, base_latent, trial_state
|
| 399 |
+
):
|
| 400 |
+
latent = base_latent.model_copy(
|
| 401 |
+
update={"episode_phase": "analysis", "interim_complete": True}
|
| 402 |
+
)
|
| 403 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 404 |
+
assert ActionType.RUN_PRIMARY_ANALYSIS.value in obs.available_actions
|
| 405 |
+
|
| 406 |
+
def test_unknown_phase_returns_empty_actions(
|
| 407 |
+
self, generator, base_latent, trial_state
|
| 408 |
+
):
|
| 409 |
+
latent = base_latent.model_copy(update={"episode_phase": "unknown_phase"})
|
| 410 |
+
obs = _make_obs(generator, latent, trial_state)
|
| 411 |
+
assert obs.available_actions == []
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# ---------------------------------------------------------------------------
|
| 415 |
+
# Determinism tests
|
| 416 |
+
# ---------------------------------------------------------------------------
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
class TestDeterminism:
|
| 420 |
+
"""Same seed + same latent state → same observation (requirement 9.2)."""
|
| 421 |
+
|
| 422 |
+
def test_same_seed_same_observed_effect(self, base_latent, trial_state):
|
| 423 |
+
latent = base_latent.model_copy(update={"effect_estimated": True})
|
| 424 |
+
obs1 = OutputGenerator(NoiseModel(seed=99)).generate(
|
| 425 |
+
latent,
|
| 426 |
+
trial_state,
|
| 427 |
+
steps_taken=1,
|
| 428 |
+
max_steps=20,
|
| 429 |
+
rule_violations=[],
|
| 430 |
+
done=False,
|
| 431 |
+
reward=0.0,
|
| 432 |
+
scenario_description="S",
|
| 433 |
+
hint="",
|
| 434 |
+
)
|
| 435 |
+
obs2 = OutputGenerator(NoiseModel(seed=99)).generate(
|
| 436 |
+
latent,
|
| 437 |
+
trial_state,
|
| 438 |
+
steps_taken=1,
|
| 439 |
+
max_steps=20,
|
| 440 |
+
rule_violations=[],
|
| 441 |
+
done=False,
|
| 442 |
+
reward=0.0,
|
| 443 |
+
scenario_description="S",
|
| 444 |
+
hint="",
|
| 445 |
+
)
|
| 446 |
+
assert (
|
| 447 |
+
obs1.phase_data["observed_effect_size"]
|
| 448 |
+
== obs2.phase_data["observed_effect_size"]
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
def test_different_seeds_different_observed_effect(self, base_latent, trial_state):
|
| 452 |
+
latent = base_latent.model_copy(update={"effect_estimated": True})
|
| 453 |
+
obs1 = OutputGenerator(NoiseModel(seed=1)).generate(
|
| 454 |
+
latent,
|
| 455 |
+
trial_state,
|
| 456 |
+
steps_taken=1,
|
| 457 |
+
max_steps=20,
|
| 458 |
+
rule_violations=[],
|
| 459 |
+
done=False,
|
| 460 |
+
reward=0.0,
|
| 461 |
+
scenario_description="S",
|
| 462 |
+
hint="",
|
| 463 |
+
)
|
| 464 |
+
obs2 = OutputGenerator(NoiseModel(seed=2)).generate(
|
| 465 |
+
latent,
|
| 466 |
+
trial_state,
|
| 467 |
+
steps_taken=1,
|
| 468 |
+
max_steps=20,
|
| 469 |
+
rule_violations=[],
|
| 470 |
+
done=False,
|
| 471 |
+
reward=0.0,
|
| 472 |
+
scenario_description="S",
|
| 473 |
+
hint="",
|
| 474 |
+
)
|
| 475 |
+
# Different seeds should (almost certainly) produce different noisy values
|
| 476 |
+
assert (
|
| 477 |
+
obs1.phase_data["observed_effect_size"]
|
| 478 |
+
!= obs2.phase_data["observed_effect_size"]
|
| 479 |
+
)
|
tests/test_phase_detector.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for server/phase_detector.py
|
| 3 |
+
|
| 4 |
+
Validates Requirements 8.5 and 9.4:
|
| 5 |
+
- detect_phase classifies actions into correct clinical workflow phases
|
| 6 |
+
- phase_order_correct is True for valid transitions, False for regressions/skips
|
| 7 |
+
- compute_phase_ordering_reward returns correct bonus/penalty values
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import pytest
|
| 13 |
+
|
| 14 |
+
from models import ActionType, TrialAction
|
| 15 |
+
from server.phase_detector import (
|
| 16 |
+
PHASE_BONUS,
|
| 17 |
+
PHASE_ORDER,
|
| 18 |
+
PHASE_SKIP_PENALTY,
|
| 19 |
+
compute_phase_ordering_reward,
|
| 20 |
+
detect_phase,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _action(action_type: ActionType) -> TrialAction:
|
| 25 |
+
return TrialAction(
|
| 26 |
+
action_type=action_type,
|
| 27 |
+
parameters={},
|
| 28 |
+
justification="test",
|
| 29 |
+
confidence=0.5,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Phase mapping tests
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TestPhaseMapping:
|
| 39 |
+
def test_hypothesis_actions(self):
|
| 40 |
+
for at in [
|
| 41 |
+
ActionType.ESTIMATE_EFFECT_SIZE,
|
| 42 |
+
ActionType.ADD_BIOMARKER_STRATIFICATION,
|
| 43 |
+
]:
|
| 44 |
+
phase, _ = detect_phase(_action(at), [])
|
| 45 |
+
assert phase == "hypothesis", f"{at} should map to hypothesis"
|
| 46 |
+
|
| 47 |
+
def test_design_actions(self):
|
| 48 |
+
design_actions = [
|
| 49 |
+
ActionType.SET_PRIMARY_ENDPOINT,
|
| 50 |
+
ActionType.SET_SAMPLE_SIZE,
|
| 51 |
+
ActionType.SET_INCLUSION_CRITERIA,
|
| 52 |
+
ActionType.SET_EXCLUSION_CRITERIA,
|
| 53 |
+
ActionType.SET_DOSING_SCHEDULE,
|
| 54 |
+
ActionType.SET_CONTROL_ARM,
|
| 55 |
+
ActionType.SET_RANDOMIZATION_RATIO,
|
| 56 |
+
ActionType.SET_BLINDING,
|
| 57 |
+
ActionType.REQUEST_PROTOCOL_AMENDMENT,
|
| 58 |
+
]
|
| 59 |
+
for at in design_actions:
|
| 60 |
+
phase, _ = detect_phase(_action(at), [])
|
| 61 |
+
assert phase == "design", f"{at} should map to design"
|
| 62 |
+
|
| 63 |
+
def test_enrollment_action(self):
|
| 64 |
+
phase, _ = detect_phase(_action(ActionType.ENROLL_PATIENTS), [])
|
| 65 |
+
assert phase == "enrollment"
|
| 66 |
+
|
| 67 |
+
def test_monitoring_actions(self):
|
| 68 |
+
monitoring_actions = [
|
| 69 |
+
ActionType.RUN_DOSE_ESCALATION,
|
| 70 |
+
ActionType.OBSERVE_SAFETY_SIGNAL,
|
| 71 |
+
ActionType.RUN_INTERIM_ANALYSIS,
|
| 72 |
+
ActionType.MODIFY_SAMPLE_SIZE,
|
| 73 |
+
]
|
| 74 |
+
for at in monitoring_actions:
|
| 75 |
+
phase, _ = detect_phase(_action(at), [])
|
| 76 |
+
assert phase == "monitoring", f"{at} should map to monitoring"
|
| 77 |
+
|
| 78 |
+
def test_analysis_actions(self):
|
| 79 |
+
for at in [ActionType.RUN_PRIMARY_ANALYSIS, ActionType.SYNTHESIZE_CONCLUSION]:
|
| 80 |
+
phase, _ = detect_phase(_action(at), [])
|
| 81 |
+
assert phase == "analysis", f"{at} should map to analysis"
|
| 82 |
+
|
| 83 |
+
def test_submission_action(self):
|
| 84 |
+
phase, _ = detect_phase(_action(ActionType.SUBMIT_TO_FDA_REVIEW), [])
|
| 85 |
+
assert phase == "submission"
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
# Phase order correctness tests
|
| 90 |
+
# ---------------------------------------------------------------------------
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TestPhaseOrderCorrectness:
|
| 94 |
+
def test_empty_history_always_correct(self):
|
| 95 |
+
for at in ActionType:
|
| 96 |
+
_, correct = detect_phase(_action(at), [])
|
| 97 |
+
assert correct is True, f"Empty history should always be correct for {at}"
|
| 98 |
+
|
| 99 |
+
def test_same_phase_is_correct(self):
|
| 100 |
+
_, correct = detect_phase(_action(ActionType.SET_SAMPLE_SIZE), ["design"])
|
| 101 |
+
assert correct is True
|
| 102 |
+
|
| 103 |
+
def test_advance_one_phase_is_correct(self):
|
| 104 |
+
_, correct = detect_phase(_action(ActionType.ENROLL_PATIENTS), ["design"])
|
| 105 |
+
assert correct is True
|
| 106 |
+
|
| 107 |
+
def test_regression_is_incorrect(self):
|
| 108 |
+
# Going from enrollment back to design
|
| 109 |
+
_, correct = detect_phase(_action(ActionType.SET_SAMPLE_SIZE), ["enrollment"])
|
| 110 |
+
assert correct is False
|
| 111 |
+
|
| 112 |
+
def test_skip_one_phase_is_incorrect(self):
|
| 113 |
+
# Jumping from hypothesis to enrollment (skipping design)
|
| 114 |
+
_, correct = detect_phase(_action(ActionType.ENROLL_PATIENTS), ["hypothesis"])
|
| 115 |
+
assert correct is False
|
| 116 |
+
|
| 117 |
+
def test_skip_multiple_phases_is_incorrect(self):
|
| 118 |
+
# Jumping from design to analysis (skipping enrollment + monitoring)
|
| 119 |
+
_, correct = detect_phase(_action(ActionType.RUN_PRIMARY_ANALYSIS), ["design"])
|
| 120 |
+
assert correct is False
|
| 121 |
+
|
| 122 |
+
def test_valid_full_sequence(self):
|
| 123 |
+
"""Walk through the full phase sequence and verify all transitions are correct."""
|
| 124 |
+
history: list[str] = []
|
| 125 |
+
sequence = [
|
| 126 |
+
ActionType.ESTIMATE_EFFECT_SIZE, # hypothesis
|
| 127 |
+
ActionType.SET_PRIMARY_ENDPOINT, # design
|
| 128 |
+
ActionType.ENROLL_PATIENTS, # enrollment
|
| 129 |
+
ActionType.RUN_DOSE_ESCALATION, # monitoring
|
| 130 |
+
ActionType.RUN_PRIMARY_ANALYSIS, # analysis
|
| 131 |
+
ActionType.SUBMIT_TO_FDA_REVIEW, # submission
|
| 132 |
+
]
|
| 133 |
+
for at in sequence:
|
| 134 |
+
phase, correct = detect_phase(_action(at), history)
|
| 135 |
+
assert correct is True, (
|
| 136 |
+
f"Expected correct order for {at} with history {history}"
|
| 137 |
+
)
|
| 138 |
+
history.append(phase)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
# PHASE_ORDER constant
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class TestPhaseOrderConstant:
|
| 147 |
+
def test_phase_order_has_seven_phases(self):
|
| 148 |
+
assert len(PHASE_ORDER) == 7
|
| 149 |
+
|
| 150 |
+
def test_phase_order_sequence(self):
|
| 151 |
+
assert PHASE_ORDER == [
|
| 152 |
+
"literature_review",
|
| 153 |
+
"hypothesis",
|
| 154 |
+
"design",
|
| 155 |
+
"enrollment",
|
| 156 |
+
"monitoring",
|
| 157 |
+
"analysis",
|
| 158 |
+
"submission",
|
| 159 |
+
]
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
# compute_phase_ordering_reward tests
|
| 164 |
+
# ---------------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class TestComputePhaseOrderingReward:
|
| 168 |
+
def test_empty_history_returns_bonus(self):
|
| 169 |
+
reward = compute_phase_ordering_reward(_action(ActionType.SET_SAMPLE_SIZE), [])
|
| 170 |
+
assert reward == PHASE_BONUS
|
| 171 |
+
|
| 172 |
+
def test_correct_advance_returns_bonus(self):
|
| 173 |
+
reward = compute_phase_ordering_reward(
|
| 174 |
+
_action(ActionType.ENROLL_PATIENTS), ["design"]
|
| 175 |
+
)
|
| 176 |
+
assert reward == PHASE_BONUS
|
| 177 |
+
|
| 178 |
+
def test_same_phase_returns_bonus(self):
|
| 179 |
+
reward = compute_phase_ordering_reward(
|
| 180 |
+
_action(ActionType.SET_SAMPLE_SIZE), ["design"]
|
| 181 |
+
)
|
| 182 |
+
assert reward == PHASE_BONUS
|
| 183 |
+
|
| 184 |
+
def test_regression_returns_zero(self):
|
| 185 |
+
reward = compute_phase_ordering_reward(
|
| 186 |
+
_action(ActionType.SET_SAMPLE_SIZE), ["enrollment"]
|
| 187 |
+
)
|
| 188 |
+
assert reward == 0.0
|
| 189 |
+
|
| 190 |
+
def test_skip_one_phase_returns_single_penalty(self):
|
| 191 |
+
# hypothesis → enrollment skips design (1 skip)
|
| 192 |
+
reward = compute_phase_ordering_reward(
|
| 193 |
+
_action(ActionType.ENROLL_PATIENTS), ["hypothesis"]
|
| 194 |
+
)
|
| 195 |
+
assert reward == pytest.approx(PHASE_SKIP_PENALTY * 1)
|
| 196 |
+
|
| 197 |
+
def test_skip_two_phases_returns_double_penalty(self):
|
| 198 |
+
# For comparison: design → monitoring skips only enrollment (1 skip)
|
| 199 |
+
# design → analysis skips enrollment + monitoring (2 skips)
|
| 200 |
+
reward = compute_phase_ordering_reward(
|
| 201 |
+
_action(ActionType.RUN_PRIMARY_ANALYSIS), ["design"]
|
| 202 |
+
)
|
| 203 |
+
assert reward == pytest.approx(PHASE_SKIP_PENALTY * 2)
|
| 204 |
+
|
| 205 |
+
def test_constants_values(self):
|
| 206 |
+
assert PHASE_BONUS == 0.2
|
| 207 |
+
assert PHASE_SKIP_PENALTY == -0.3
|