Spaces:
Sleeping
Sleeping
| """ | |
| TransitionEngine — mutates TrialLatentState per action. | |
| Follows the Bio Experiment pattern: TransitionEngine updates hidden state, | |
| OutputGenerator produces noisy observations from it. Agent never sees clean | |
| hidden values. | |
| Key responsibilities: | |
| - Enroll patients (ENROLL_PATIENTS) | |
| - Spend budget and advance time | |
| - Record adverse events | |
| - Set milestone flags (phase_i_complete, mtd_identified, effect_estimated, | |
| protocol_submitted, interim_complete, primary_analysis_complete, | |
| trial_complete) | |
| - Degrade data quality on soft violations | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from models import ActionType, TrialAction, TrialLatentState | |
class TransitionEngine:
    """Computes the next TrialLatentState in response to agent actions.

    Each call to :meth:`apply_transition` works on a deep copy of the
    incoming state, so the caller's object is never modified in place.

    All state transitions are deterministic given the same seed and action
    sequence (reproducibility requirement 9.2).
    """

    # Fixed budget charged per action type. ENROLL_PATIENTS is listed with a
    # zero placeholder because its cost is computed per patient.
    _ACTION_COSTS: dict[ActionType, float] = {
        ActionType.SET_PRIMARY_ENDPOINT: 5_000.0,
        ActionType.SET_SAMPLE_SIZE: 2_000.0,
        ActionType.SET_INCLUSION_CRITERIA: 3_000.0,
        ActionType.SET_EXCLUSION_CRITERIA: 3_000.0,
        ActionType.SET_DOSING_SCHEDULE: 10_000.0,
        ActionType.SET_CONTROL_ARM: 5_000.0,
        ActionType.SET_RANDOMIZATION_RATIO: 2_000.0,
        ActionType.SET_BLINDING: 4_000.0,
        ActionType.RUN_DOSE_ESCALATION: 50_000.0,
        ActionType.OBSERVE_SAFETY_SIGNAL: 15_000.0,
        ActionType.ESTIMATE_EFFECT_SIZE: 20_000.0,
        ActionType.RUN_INTERIM_ANALYSIS: 30_000.0,
        ActionType.MODIFY_SAMPLE_SIZE: 5_000.0,
        ActionType.ADD_BIOMARKER_STRATIFICATION: 25_000.0,
        ActionType.SUBMIT_TO_FDA_REVIEW: 100_000.0,
        ActionType.REQUEST_PROTOCOL_AMENDMENT: 15_000.0,
        ActionType.RUN_PRIMARY_ANALYSIS: 50_000.0,
        ActionType.SYNTHESIZE_CONCLUSION: 10_000.0,
        ActionType.ENROLL_PATIENTS: 0.0,  # cost computed per patient
    }

    # Fixed duration (in days) charged per action type. ENROLL_PATIENTS is
    # listed with a zero placeholder because its time is computed per patient.
    _ACTION_TIME_DAYS: dict[ActionType, int] = {
        ActionType.SET_PRIMARY_ENDPOINT: 7,
        ActionType.SET_SAMPLE_SIZE: 3,
        ActionType.SET_INCLUSION_CRITERIA: 5,
        ActionType.SET_EXCLUSION_CRITERIA: 5,
        ActionType.SET_DOSING_SCHEDULE: 14,
        ActionType.SET_CONTROL_ARM: 7,
        ActionType.SET_RANDOMIZATION_RATIO: 3,
        ActionType.SET_BLINDING: 5,
        ActionType.RUN_DOSE_ESCALATION: 90,
        ActionType.OBSERVE_SAFETY_SIGNAL: 30,
        ActionType.ESTIMATE_EFFECT_SIZE: 45,
        ActionType.RUN_INTERIM_ANALYSIS: 60,
        ActionType.MODIFY_SAMPLE_SIZE: 7,
        ActionType.ADD_BIOMARKER_STRATIFICATION: 30,
        ActionType.SUBMIT_TO_FDA_REVIEW: 180,
        ActionType.REQUEST_PROTOCOL_AMENDMENT: 30,
        ActionType.RUN_PRIMARY_ANALYSIS: 90,
        ActionType.SYNTHESIZE_CONCLUSION: 14,
        ActionType.ENROLL_PATIENTS: 0,  # time computed per patient
    }

    # Per-patient enrollment charges; ENROLL_PATIENTS uses these instead of
    # the fixed tables above.
    _COST_PER_PATIENT: float = 10_000.0
    _DAYS_PER_PATIENT: float = 2.0

    # Milestone attributes switched to True by each action type.
    # trial_complete is only set by SYNTHESIZE_CONCLUSION so the agent is
    # forced through the full conclusion/submission workflow rather than
    # ending the episode the moment primary analysis runs.
    _MILESTONE_FLAGS: dict[ActionType, tuple[str, ...]] = {
        ActionType.RUN_DOSE_ESCALATION: ("phase_i_complete", "mtd_identified"),
        ActionType.ESTIMATE_EFFECT_SIZE: ("effect_estimated",),
        ActionType.SUBMIT_TO_FDA_REVIEW: ("protocol_submitted",),
        ActionType.RUN_INTERIM_ANALYSIS: ("interim_complete",),
        ActionType.RUN_PRIMARY_ANALYSIS: ("primary_analysis_complete",),
        ActionType.SYNTHESIZE_CONCLUSION: ("trial_complete",),
    }

    # --- Phase progression (G23) ---
    # Episode phase each action moves the trial towards, so the phase detector
    # and rule engine see a moving phase rather than a stuck
    # "literature_review". Phase names must match TRANSITION_TABLE keys in
    # fda_rules.py.
    _PHASE_TRANSITIONS: dict[ActionType, str] = {
        ActionType.SET_PRIMARY_ENDPOINT: "hypothesis",
        ActionType.ESTIMATE_EFFECT_SIZE: "hypothesis",
        ActionType.SET_SAMPLE_SIZE: "design",
        ActionType.SET_INCLUSION_CRITERIA: "design",
        ActionType.SET_EXCLUSION_CRITERIA: "design",
        ActionType.SET_DOSING_SCHEDULE: "design",
        ActionType.SET_CONTROL_ARM: "design",
        ActionType.SET_RANDOMIZATION_RATIO: "design",
        ActionType.SET_BLINDING: "design",
        ActionType.ADD_BIOMARKER_STRATIFICATION: "design",
        ActionType.REQUEST_PROTOCOL_AMENDMENT: "design",
        ActionType.ENROLL_PATIENTS: "enrollment",
        ActionType.RUN_DOSE_ESCALATION: "enrollment",
        ActionType.OBSERVE_SAFETY_SIGNAL: "enrollment",
        ActionType.MODIFY_SAMPLE_SIZE: "enrollment",
        ActionType.RUN_INTERIM_ANALYSIS: "monitoring",
        ActionType.RUN_PRIMARY_ANALYSIS: "analysis",
        ActionType.SYNTHESIZE_CONCLUSION: "analysis",
        ActionType.SUBMIT_TO_FDA_REVIEW: "submission",
    }

    # Canonical phase ordering; the phase only ever advances along it.
    _PHASE_ORDER: tuple[str, ...] = (
        "literature_review",
        "hypothesis",
        "design",
        "enrollment",
        "monitoring",
        "analysis",
        "submission",
    )

    # Actions on which adverse events may be recorded.
    _AE_ACTIONS: frozenset[ActionType] = frozenset(
        {
            ActionType.ENROLL_PATIENTS,
            ActionType.OBSERVE_SAFETY_SIGNAL,
            ActionType.RUN_DOSE_ESCALATION,
        }
    )

    def __init__(self) -> None:
        """Initialize the TransitionEngine (it holds no per-instance state)."""

    def apply_transition(
        self, latent: TrialLatentState, action: TrialAction
    ) -> TrialLatentState:
        """Apply *action* to *latent* and return the updated state.

        Does NOT mutate the input latent state — returns a new deep copy with
        updated fields. The steps run in a fixed order: action history,
        budget/time consumption, milestone flags, soft-violation penalties,
        phase progression, and finally stochastic adverse events.

        Args:
            latent: Current hidden state.
            action: Agent action to apply.

        Returns:
            Updated TrialLatentState with mutated fields.
        """
        updated = latent.model_copy(deep=True)
        updated.action_history.append(action.action_type.value)

        # Step-specific RNG derived from the episode seed and the 1-based
        # step index, so replaying the same actions gives the same draws.
        # NOTE(review): XOR mixing means distinct (seed, step) pairs can map
        # to the same RNG seed (e.g. 2 ^ 3 == 3 ^ 2); kept as-is so existing
        # seeded episodes still replay identically.
        step_index = len(updated.action_history)
        rng = random.Random(latent.seed ^ step_index)

        # Parse the enrollment size once so that budget, time, enrollment
        # count and adverse-event exposure all agree on the same number.
        n_patients = 0
        if action.action_type == ActionType.ENROLL_PATIENTS:
            n_patients = self._requested_patients(action)

        self._consume_resources(updated, action, n_patients)
        self._update_milestones(updated, action)
        self._apply_soft_violations(updated, action)
        self._advance_phase(updated, action)
        self._record_adverse_events(updated, action, n_patients, rng)
        return updated

    @staticmethod
    def _requested_patients(action: TrialAction) -> int:
        """Return the non-negative integer ``n_patients`` from the action."""
        return max(int(action.parameters.get("n_patients", 0)), 0)

    def _consume_resources(
        self, updated: TrialLatentState, action: TrialAction, n_patients: int
    ) -> None:
        """Charge the action's budget and time, and enroll patients if asked."""
        action_type = action.action_type
        if action_type == ActionType.ENROLL_PATIENTS:
            cost = n_patients * self._COST_PER_PATIENT
            days = int(n_patients * self._DAYS_PER_PATIENT)
            updated.patients_enrolled += n_patients
        else:
            # Action types missing from the tables cost nothing.
            cost = self._ACTION_COSTS.get(action_type, 0.0)
            days = self._ACTION_TIME_DAYS.get(action_type, 0)
        updated.budget_remaining -= cost
        updated.time_remaining_days -= days

    def _update_milestones(
        self, updated: TrialLatentState, action: TrialAction
    ) -> None:
        """Set the milestone flags associated with the action type."""
        for flag_name in self._MILESTONE_FLAGS.get(action.action_type, ()):
            setattr(updated, flag_name, True)

    @staticmethod
    def _apply_soft_violations(
        updated: TrialLatentState, action: TrialAction
    ) -> None:
        """Degrade data-quality fields when a soft constraint is violated."""
        # Low action confidence (< 0.5) increases measurement noise, capped
        # at 0.5.
        if action.confidence < 0.5:
            degradation = 0.05 * (0.5 - action.confidence)
            updated.measurement_noise = min(
                updated.measurement_noise + degradation, 0.5
            )
        # Overspent budget raises site variability, capped at 0.5.
        if updated.budget_remaining < 0:
            updated.site_variability = min(
                updated.site_variability + 0.03, 0.5
            )
        # Overrun schedule scales the dropout rate by 15%, capped at 0.8.
        if updated.time_remaining_days < 0:
            updated.dropout_rate = min(updated.dropout_rate * 1.15, 0.8)

    def _advance_phase(
        self, updated: TrialLatentState, action: TrialAction
    ) -> None:
        """Move ``episode_phase`` forward if the action implies a later phase."""
        target_phase = self._PHASE_TRANSITIONS.get(action.action_type)
        if target_phase is None:
            return
        try:
            current_idx = self._PHASE_ORDER.index(updated.episode_phase)
            target_idx = self._PHASE_ORDER.index(target_phase)
        except ValueError:
            # Current phase is not in the known ordering: adopt the target.
            updated.episode_phase = target_phase
            return
        # Only advance — never go backwards.
        if target_idx > current_idx:
            updated.episode_phase = target_phase

    def _record_adverse_events(
        self,
        updated: TrialLatentState,
        action: TrialAction,
        n_patients: int,
        rng: random.Random,
    ) -> None:
        """Stochastically record adverse events for exposure-related actions."""
        if action.action_type not in self._AE_ACTIONS:
            return

        # ENROLL_PATIENTS exposes exactly the patients that were enrolled
        # (possibly none); the other actions count as a single exposure.
        if action.action_type == ActionType.ENROLL_PATIENTS:
            n_exposed = n_patients
        else:
            n_exposed = 1

        # Each exposed patient has an independent chance of an adverse event,
        # governed by true_side_effect_rate.
        ae_count = sum(
            1
            for _ in range(n_exposed)
            if rng.random() < updated.true_side_effect_rate
        )
        if ae_count > 0:
            updated.adverse_events += ae_count
            # Every adverse event also raises site variability, capped at 0.5.
            updated.site_variability = min(
                updated.site_variability + 0.02 * ae_count, 0.5
            )