Spaces:
Sleeping
Sleeping
"""
Clinical Trial Triage — Typed Models
=====================================
Pydantic models for Actions, Observations, Rewards, and State.
All models are fully typed and OpenEnv-spec compliant.
"""
from __future__ import annotations

from enum import Enum, unique
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field
# ─────────────────────────────────────────
# ENUMS
# ─────────────────────────────────────────
@unique
class AESeverity(str, Enum):
    """Severity grades an adverse event can be classified as.

    Members subclass ``str``, so each one compares equal to its lowercase
    wire value (for example ``AESeverity.MILD == "mild"``). ``@unique``
    rejects accidental duplicate values at class-definition time.
    """

    MILD = "mild"
    MODERATE = "moderate"
    SEVERE = "severe"
    LIFE_THREATENING = "life_threatening"
    FATAL = "fatal"
@unique
class ReportingTimeline(str, Enum):
    """Regulatory reporting timelines an agent can assign to an event.

    Members subclass ``str`` and compare equal to their wire value
    (``"7-day"``, ``"15-day"``, ``"routine"``). ``@unique`` rejects
    accidental duplicate values at class-definition time.
    """

    SEVEN_DAY = "7-day"  # SAE unexpected fatal/life-threatening
    FIFTEEN_DAY = "15-day"  # SUSAR (Suspected Unexpected Serious Adverse Reaction)
    ROUTINE = "routine"  # Annual safety report
@unique
class DeviationType(str, Enum):
    """Categories a protocol deviation can be classified into.

    Members subclass ``str`` and compare equal to their wire value.
    ``@unique`` rejects accidental duplicate values at class-definition time.
    """

    MAJOR = "major"  # Affects subject safety or data integrity
    MINOR = "minor"  # Administrative, no subject safety impact
    PROTOCOL_AMENDMENT = "protocol_amendment"
@unique
class CausalityAssessment(str, Enum):
    """Causality categories relating an adverse event to a suspect drug.

    Members subclass ``str`` and compare equal to their wire value.
    ``@unique`` rejects accidental duplicate values at class-definition time.
    """

    DEFINITELY_RELATED = "definitely_related"
    PROBABLY_RELATED = "probably_related"
    POSSIBLY_RELATED = "possibly_related"
    UNLIKELY_RELATED = "unlikely_related"
    NOT_RELATED = "not_related"
    UNASSESSABLE = "unassessable"
@unique
class TaskID(str, Enum):
    """Identifiers of the three tasks this module defines models for.

    Members subclass ``str`` and compare equal to their wire value.
    ``@unique`` rejects accidental duplicate values at class-definition time.
    """

    ADVERSE_EVENT_TRIAGE = "adverse_event_triage"
    PROTOCOL_DEVIATION_AUDIT = "protocol_deviation_audit"
    SAFETY_NARRATIVE_GENERATION = "safety_narrative_generation"
# ─────────────────────────────────────────
# ACTIONS
# ─────────────────────────────────────────
class AdverseEventTriageAction(BaseModel):
    """Action for Task 1: Adverse Event Triage.

    Every field is required. The free-text fields carry ``max_length``
    limits (120 characters for the two MedDRA fields, 500 for the
    rationale), so longer strings fail validation.
    """

    severity_classification: AESeverity = Field(
        ...,
        description="Agent's severity classification of the adverse event.",
    )
    reporting_timeline: ReportingTimeline = Field(
        ...,
        description="Required regulatory reporting timeline.",
    )
    meddra_soc: str = Field(
        ...,
        description="MedDRA System Organ Class (e.g., 'Cardiac disorders').",
        max_length=120,
    )
    meddra_preferred_term: str = Field(
        ...,
        description="MedDRA Preferred Term (e.g., 'Myocardial infarction').",
        max_length=120,
    )
    is_serious: bool = Field(
        ...,
        description="Whether this qualifies as a Serious Adverse Event (SAE).",
    )
    rationale: str = Field(
        ...,
        description="Agent's reasoning (max 500 chars).",
        max_length=500,
    )
class ProtocolDeviationAction(BaseModel):
    """Action for Task 2: Protocol Deviation Audit.

    All fields are required except ``flagged_finding_ids``, which defaults
    to a fresh empty list. ``site_risk_score`` must lie in ``[0.0, 10.0]``
    and ``recommended_action`` is limited to 300 characters.
    """

    deviation_type: DeviationType = Field(
        ...,
        description="Classification of each deviation found.",
    )
    capa_required: bool = Field(
        ...,
        description="Whether a Corrective and Preventive Action plan is required.",
    )
    site_risk_score: float = Field(
        ...,
        ge=0.0,
        le=10.0,
        description="Risk score for the site (0=low, 10=critical).",
    )
    # default_factory gives each instance its own list, never a shared one.
    flagged_finding_ids: List[str] = Field(
        default_factory=list,
        description="List of finding IDs the agent considers GCP violations.",
    )
    recommended_action: str = Field(
        ...,
        description="Agent's recommended next step (e.g., 'Immediate re-monitoring').",
        max_length=300,
    )
class SafetyNarrativeAction(BaseModel):
    """Action for Task 3: Safety Narrative Generation.

    ``narrative_text`` and ``causality_assessment`` are required; the
    narrative must be between 100 and 4000 characters. The remaining fields
    are optional: ``key_temporal_flags`` defaults to a fresh empty list and
    the two challenge flags default to ``None``.
    """

    narrative_text: str = Field(
        ...,
        description="Full ICH E2B-compliant ICSR safety narrative.",
        min_length=100,
        max_length=4000,
    )
    causality_assessment: CausalityAssessment = Field(
        ...,
        description="Causality assessment for the primary suspect drug.",
    )
    key_temporal_flags: List[str] = Field(
        default_factory=list,
        description="Temporal markers identified (e.g., 'onset 3 days after dose increase').",
    )
    # Tri-state flags: True / False / None (unknown or not performed).
    dechallenge_positive: Optional[bool] = Field(
        None,
        description="Whether the AE resolved on drug discontinuation (None if unknown).",
    )
    rechallenge_positive: Optional[bool] = Field(
        None,
        description="Whether the AE recurred on re-administration (None if not done).",
    )
# Union action type — the agent sends one of these per step
class TriageAction(BaseModel):
    """Top-level Action model wrapping task-specific actions.

    ``task_id`` is required and names the targeted task. Each of the three
    payload fields is optional and defaults to ``None``.

    NOTE(review): no validator in this class checks that the populated
    payload matches ``task_id``; confirm whether the caller enforces that.
    """

    task_id: TaskID = Field(..., description="Which task this action targets.")
    ae_triage: Optional[AdverseEventTriageAction] = Field(
        None, description="Populated for adverse_event_triage task."
    )
    deviation_audit: Optional[ProtocolDeviationAction] = Field(
        None, description="Populated for protocol_deviation_audit task."
    )
    safety_narrative: Optional[SafetyNarrativeAction] = Field(
        None, description="Populated for safety_narrative_generation task."
    )

    # Populate enum-typed fields with their plain values rather than Enum members.
    model_config = ConfigDict(use_enum_values=True)
# ─────────────────────────────────────────
# OBSERVATIONS
# ─────────────────────────────────────────
class AdverseEventObservation(BaseModel):
    """Observation returned for AE Triage task.

    Every field is required except ``scoring_hints``, which defaults to
    ``None``.
    """

    case_id: str
    narrative: str = Field(..., description="Raw AE narrative from site.")
    patient_age: int
    patient_sex: str
    study_drug: str
    dose_mg: float
    days_on_drug: int
    relevant_medical_history: List[str]
    concomitant_medications: List[str]
    # Free-form mapping; no key or value schema is enforced by this model.
    lab_values: Dict[str, Any]
    # Typed as a plain string; no date format is validated here.
    ae_onset_date: str
    ae_description: str
    outcome: str
    step_count: int
    max_steps: int
    scoring_hints: Optional[Dict[str, Any]] = None
class ProtocolDeviationObservation(BaseModel):
    """Observation returned for Protocol Deviation Audit task.

    Every field is required; none declares a default.
    """

    site_id: str
    site_name: str
    visit_type: str
    # Each finding is a free-form dict; its keys are not constrained here.
    findings: List[Dict[str, Any]]
    prior_deviations: int
    active_subjects: int
    study_phase: str
    # Typed as a plain string; no date format is validated here.
    last_monitoring_visit: str
    step_count: int
    max_steps: int
class SafetyNarrativeObservation(BaseModel):
    """Observation returned for Safety Narrative Generation task.

    Every field is required; none declares a default. The dict-typed fields
    are free-form: this model does not constrain their keys or values.
    """

    case_id: str
    patient_demographics: Dict[str, Any]
    study_drug: str
    suspect_drugs: List[str]
    # Unlike AdverseEventObservation, medications here are dicts, not strings.
    concomitant_medications: List[Dict[str, Any]]
    adverse_event: Dict[str, Any]
    lab_values_timeline: List[Dict[str, Any]]
    medical_history: List[str]
    action_taken: str
    outcome_at_last_followup: str
    reference_documents: List[str]
    step_count: int
    max_steps: int
class TriageObservation(BaseModel):
    """Top-level Observation returned from step() / reset().

    ``task_id`` is required. The three task-specific observation fields
    default to ``None`` and ``message`` defaults to the empty string.
    """

    task_id: TaskID
    ae_observation: Optional[AdverseEventObservation] = None
    deviation_observation: Optional[ProtocolDeviationObservation] = None
    narrative_observation: Optional[SafetyNarrativeObservation] = None
    message: str = ""

    # Populate enum-typed fields with their plain values rather than Enum members.
    model_config = ConfigDict(use_enum_values=True)
# ─────────────────────────────────────────
# REWARD
# ─────────────────────────────────────────
class TriageReward(BaseModel):
    """
    Structured reward with partial credit signals.
    All sub-scores normalized to [0, 1].

    ``total`` is required. Every sub-score is optional and defaults to
    ``None``; when a value is supplied it is validated to lie in [0, 1].
    """

    total: float = Field(..., ge=0.0, le=1.0, description="Weighted total reward.")

    # Task-1 sub-scores
    severity_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
    timeline_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
    soc_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
    pt_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)

    # Task-2 sub-scores
    deviation_type_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
    capa_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
    risk_score_proximity: Optional[float] = Field(None, ge=0.0, le=1.0)
    violation_recall: Optional[float] = Field(None, ge=0.0, le=1.0)
    violation_precision: Optional[float] = Field(None, ge=0.0, le=1.0)

    # Task-3 sub-scores
    temporal_coverage: Optional[float] = Field(None, ge=0.0, le=1.0)
    causality_accuracy: Optional[float] = Field(None, ge=0.0, le=1.0)
    narrative_completeness: Optional[float] = Field(None, ge=0.0, le=1.0)
    regulatory_compliance: Optional[float] = Field(None, ge=0.0, le=1.0)

    # Penalty flags
    penalty_applied: bool = False
    penalty_reason: Optional[str] = None
# ─────────────────────────────────────────
# STATE
# ─────────────────────────────────────────
class TriageState(BaseModel):
    """Episode state metadata returned from state().

    ``actions_taken`` defaults to a fresh empty list; ``current_case_id``
    and ``completed_at`` default to ``None``. All other fields are required.
    """

    episode_id: str
    task_id: TaskID
    step_count: int
    max_steps: int
    done: bool
    cumulative_reward: float
    # default_factory gives each instance its own list, never a shared one.
    actions_taken: List[Dict[str, Any]] = Field(default_factory=list)
    current_case_id: Optional[str] = None
    # Timestamps are typed as plain strings; no format is validated here.
    started_at: str
    completed_at: Optional[str] = None

    # Populate enum-typed fields with their plain values rather than Enum members.
    model_config = ConfigDict(use_enum_values=True)
class StepResult(BaseModel):
    """Result returned from step().

    All fields are required except ``info``, which defaults to a fresh
    empty dict.
    """

    observation: TriageObservation
    # NOTE(review): presumably mirrors reward_detail.total; not enforced here.
    reward: float
    reward_detail: TriageReward
    done: bool
    info: Dict[str, Any] = Field(default_factory=dict)