openenv-clinical-trial / server /simulator /transition_engine.py
Coding Ninja
fixes : rewards and training
d68729f
"""
TransitionEngine — mutates TrialLatentState per action.

Follows the Bio Experiment pattern: TransitionEngine updates the hidden state,
and OutputGenerator produces noisy observations from it. The agent never sees
the clean hidden values.

Key responsibilities:
- Enroll patients (ENROLL_PATIENTS)
- Spend budget and advance time
- Record adverse events
- Set milestone flags (phase_i_complete, mtd_identified, effect_estimated,
  protocol_submitted, interim_complete, primary_analysis_complete,
  trial_complete)
- Advance the episode phase (forward only, never backwards)
- Degrade data quality on soft violations
"""
from __future__ import annotations
import random
from models import ActionType, TrialAction, TrialLatentState
class TransitionEngine:
    """Compute the next TrialLatentState in response to an agent action.

    All state transitions are deterministic given the same seed and action
    sequence (reproducibility requirement 9.2): the only randomness comes
    from a ``random.Random`` instance re-derived on every step from the
    latent seed and the step index.
    """

    # Cost and time constants (per action type)
    _ACTION_COSTS: dict[ActionType, float] = {
        ActionType.SET_PRIMARY_ENDPOINT: 5_000.0,
        ActionType.SET_SAMPLE_SIZE: 2_000.0,
        ActionType.SET_INCLUSION_CRITERIA: 3_000.0,
        ActionType.SET_EXCLUSION_CRITERIA: 3_000.0,
        ActionType.SET_DOSING_SCHEDULE: 10_000.0,
        ActionType.SET_CONTROL_ARM: 5_000.0,
        ActionType.SET_RANDOMIZATION_RATIO: 2_000.0,
        ActionType.SET_BLINDING: 4_000.0,
        ActionType.RUN_DOSE_ESCALATION: 50_000.0,
        ActionType.OBSERVE_SAFETY_SIGNAL: 15_000.0,
        ActionType.ESTIMATE_EFFECT_SIZE: 20_000.0,
        ActionType.RUN_INTERIM_ANALYSIS: 30_000.0,
        ActionType.MODIFY_SAMPLE_SIZE: 5_000.0,
        ActionType.ADD_BIOMARKER_STRATIFICATION: 25_000.0,
        ActionType.SUBMIT_TO_FDA_REVIEW: 100_000.0,
        ActionType.REQUEST_PROTOCOL_AMENDMENT: 15_000.0,
        ActionType.RUN_PRIMARY_ANALYSIS: 50_000.0,
        ActionType.SYNTHESIZE_CONCLUSION: 10_000.0,
        ActionType.ENROLL_PATIENTS: 0.0,  # cost computed per patient
    }
    _ACTION_TIME_DAYS: dict[ActionType, int] = {
        ActionType.SET_PRIMARY_ENDPOINT: 7,
        ActionType.SET_SAMPLE_SIZE: 3,
        ActionType.SET_INCLUSION_CRITERIA: 5,
        ActionType.SET_EXCLUSION_CRITERIA: 5,
        ActionType.SET_DOSING_SCHEDULE: 14,
        ActionType.SET_CONTROL_ARM: 7,
        ActionType.SET_RANDOMIZATION_RATIO: 3,
        ActionType.SET_BLINDING: 5,
        ActionType.RUN_DOSE_ESCALATION: 90,
        ActionType.OBSERVE_SAFETY_SIGNAL: 30,
        ActionType.ESTIMATE_EFFECT_SIZE: 45,
        ActionType.RUN_INTERIM_ANALYSIS: 60,
        ActionType.MODIFY_SAMPLE_SIZE: 7,
        ActionType.ADD_BIOMARKER_STRATIFICATION: 30,
        ActionType.SUBMIT_TO_FDA_REVIEW: 180,
        ActionType.REQUEST_PROTOCOL_AMENDMENT: 30,
        ActionType.RUN_PRIMARY_ANALYSIS: 90,
        ActionType.SYNTHESIZE_CONCLUSION: 14,
        ActionType.ENROLL_PATIENTS: 0,  # time computed per patient
    }

    # Cost per patient enrolled (varies by disease area complexity)
    _COST_PER_PATIENT: float = 10_000.0
    _DAYS_PER_PATIENT: float = 2.0

    # Milestone flags switched on by each action type.
    # trial_complete is only set by SYNTHESIZE_CONCLUSION so the agent
    # is forced through the full conclusion/submission workflow rather
    # than ending the episode the moment primary analysis runs.
    _MILESTONE_FLAGS: dict[ActionType, tuple[str, ...]] = {
        ActionType.RUN_DOSE_ESCALATION: ("phase_i_complete", "mtd_identified"),
        ActionType.ESTIMATE_EFFECT_SIZE: ("effect_estimated",),
        ActionType.SUBMIT_TO_FDA_REVIEW: ("protocol_submitted",),
        ActionType.RUN_INTERIM_ANALYSIS: ("interim_complete",),
        ActionType.RUN_PRIMARY_ANALYSIS: ("primary_analysis_complete",),
        ActionType.SYNTHESIZE_CONCLUSION: ("trial_complete",),
    }

    # --- Phase progression (G23) ---
    # Phase each action moves the episode towards, so the phase detector and
    # rule engine see a moving phase rather than a stuck "literature_review".
    # Phase names must match TRANSITION_TABLE keys in fda_rules.py.
    _PHASE_TRANSITIONS: dict[ActionType, str] = {
        ActionType.SET_PRIMARY_ENDPOINT: "hypothesis",
        ActionType.ESTIMATE_EFFECT_SIZE: "hypothesis",
        ActionType.SET_SAMPLE_SIZE: "design",
        ActionType.SET_INCLUSION_CRITERIA: "design",
        ActionType.SET_EXCLUSION_CRITERIA: "design",
        ActionType.SET_DOSING_SCHEDULE: "design",
        ActionType.SET_CONTROL_ARM: "design",
        ActionType.SET_RANDOMIZATION_RATIO: "design",
        ActionType.SET_BLINDING: "design",
        ActionType.ADD_BIOMARKER_STRATIFICATION: "design",
        ActionType.REQUEST_PROTOCOL_AMENDMENT: "design",
        ActionType.ENROLL_PATIENTS: "enrollment",
        ActionType.RUN_DOSE_ESCALATION: "enrollment",
        ActionType.OBSERVE_SAFETY_SIGNAL: "enrollment",
        ActionType.MODIFY_SAMPLE_SIZE: "enrollment",
        ActionType.RUN_INTERIM_ANALYSIS: "monitoring",
        ActionType.RUN_PRIMARY_ANALYSIS: "analysis",
        ActionType.SYNTHESIZE_CONCLUSION: "analysis",
        ActionType.SUBMIT_TO_FDA_REVIEW: "submission",
    }
    # Phases in forward order; the episode only advances, never goes backwards.
    _PHASE_ORDER: tuple[str, ...] = (
        "literature_review",
        "hypothesis",
        "design",
        "enrollment",
        "monitoring",
        "analysis",
        "submission",
    )

    # Actions on which adverse events may be recorded.
    _AE_ACTIONS: frozenset[ActionType] = frozenset(
        {
            ActionType.ENROLL_PATIENTS,
            ActionType.OBSERVE_SAFETY_SIGNAL,
            ActionType.RUN_DOSE_ESCALATION,
        }
    )

    def __init__(self) -> None:
        """Initialize the TransitionEngine (it holds no per-instance state)."""

    def apply_transition(
        self, latent: TrialLatentState, action: TrialAction
    ) -> TrialLatentState:
        """Apply *action* to *latent* and return the updated state.

        Does NOT mutate the input latent state — works on a deep copy and
        returns that copy.

        Args:
            latent: Current hidden state.
            action: Agent action to apply.

        Returns:
            A new TrialLatentState with the action appended to
            ``action_history`` and budget, time, enrollment, milestone flags,
            data-quality fields, episode phase and adverse events updated.
        """
        updated = latent.model_copy(deep=True)
        action_type = action.action_type
        updated.action_history.append(action_type.value)

        # Step-specific RNG: same seed + same action sequence => same draws.
        # NOTE(review): XOR-combining means distinct (seed, step) pairs can
        # share a stream (e.g. 2 ^ 1 == 1 ^ 2). Kept unchanged so existing
        # seeded episodes replay identically — confirm before changing.
        step_index = len(updated.action_history)
        rng = random.Random(latent.seed ^ step_index)

        # Parse the enrollment size once so budget, time, enrollment and
        # adverse-event exposure all agree on the same number.
        n_patients = (
            self._requested_patients(action)
            if action_type == ActionType.ENROLL_PATIENTS
            else 0
        )

        self._consume_resources(updated, action_type, n_patients)
        self._set_milestones(updated, action_type)
        self._apply_soft_violations(updated, action.confidence)
        self._advance_phase(updated, action_type)
        self._record_adverse_events(updated, action_type, n_patients, rng)
        return updated

    @staticmethod
    def _requested_patients(action: TrialAction) -> int:
        """Return the non-negative ``n_patients`` requested by *action*.

        A missing, negative, non-numeric or non-finite value counts as 0, so
        a malformed agent action becomes a no-op enrollment instead of
        raising in the middle of a step.
        """
        params = action.parameters or {}
        try:
            return max(int(params.get("n_patients", 0)), 0)
        except (TypeError, ValueError, OverflowError):
            return 0

    def _consume_resources(
        self, state: TrialLatentState, action_type: ActionType, n_patients: int
    ) -> None:
        """Deduct the action's budget and time from *state*; enroll patients."""
        if action_type == ActionType.ENROLL_PATIENTS:
            # Enrollment is priced per patient rather than per action.
            cost = n_patients * self._COST_PER_PATIENT
            days = int(n_patients * self._DAYS_PER_PATIENT)
            state.patients_enrolled += n_patients
        else:
            cost = self._ACTION_COSTS.get(action_type, 0.0)
            days = self._ACTION_TIME_DAYS.get(action_type, 0)
        state.budget_remaining -= cost
        state.time_remaining_days -= days

    def _set_milestones(
        self, state: TrialLatentState, action_type: ActionType
    ) -> None:
        """Switch on the milestone flags associated with *action_type*."""
        for flag in self._MILESTONE_FLAGS.get(action_type, ()):
            setattr(state, flag, True)

    @staticmethod
    def _apply_soft_violations(
        state: TrialLatentState, confidence: float
    ) -> None:
        """Degrade data-quality fields of *state* on soft violations."""
        # Low action confidence (< 0.5) increases measurement noise (cap 0.5).
        if confidence < 0.5:
            degradation = 0.05 * (0.5 - confidence)
            state.measurement_noise = min(
                state.measurement_noise + degradation, 0.5
            )
        # Overspent budget increases site variability (cap 0.5).
        if state.budget_remaining < 0:
            state.site_variability = min(state.site_variability + 0.03, 0.5)
        # Overrun timeline scales the dropout rate up by 15% (cap 0.8).
        if state.time_remaining_days < 0:
            state.dropout_rate = min(state.dropout_rate * 1.15, 0.8)

    def _advance_phase(
        self, state: TrialLatentState, action_type: ActionType
    ) -> None:
        """Move ``episode_phase`` forward to the action's phase, never back."""
        target_phase = self._PHASE_TRANSITIONS.get(action_type)
        if target_phase is None:
            return
        order = self._PHASE_ORDER
        if state.episode_phase not in order:
            # Unknown current phase: adopt the action's phase outright.
            state.episode_phase = target_phase
        elif order.index(target_phase) > order.index(state.episode_phase):
            state.episode_phase = target_phase

    def _record_adverse_events(
        self,
        state: TrialLatentState,
        action_type: ActionType,
        n_patients: int,
        rng: random.Random,
    ) -> None:
        """Stochastically record adverse events for AE-relevant actions.

        Each exposed patient independently has an adverse event with
        probability ``true_side_effect_rate``. ENROLL_PATIENTS exposes exactly
        the patients enrolled by this action (none if zero were enrolled);
        the other AE-relevant actions count as a single exposure.
        """
        if action_type not in self._AE_ACTIONS:
            return
        n_exposed = (
            n_patients if action_type == ActionType.ENROLL_PATIENTS else 1
        )
        rate = state.true_side_effect_rate
        ae_count = sum(rng.random() < rate for _ in range(n_exposed))
        if ae_count > 0:
            state.adverse_events += ae_count
            # Each adverse event also adds to site variability (cap 0.5).
            state.site_variability = min(
                state.site_variability + 0.02 * ae_count, 0.5
            )