""" LLM Baseline Inference Script (GroqCloud) ========================================= Runs a GroqCloud-backed LLM baseline against all 3 tasks and produces reproducible scores. If GROQ_API_KEY is not set, it falls back to the deterministic heuristic baseline so the script always completes. Usage: GROQ_API_KEY=... python scripts/baseline_inference.py Optional environment variables: GROQ_BASE_URL=https://api.groq.com GROQ_API_KEYS=gsk_key_1,gsk_key_2,gsk_key_3 BASELINE_MODEL=llama-3.3-70b-versatile GROQ_KEY_STATE_FILE=outputs/groq_key_usage.json Output: outputs/baseline_results.json """ from __future__ import annotations import json import os import re import sys import time from pathlib import Path from typing import Any, Dict, List, Optional # Ensure project root is in path sys.path.insert(0, str(Path(__file__).parent.parent)) from models import ( AdverseEventTriageAction, ProtocolDeviationAction, SafetyNarrativeAction, TaskID, TriageAction, ) from scripts.groq_key_pool import GroqKeyPool, parse_groq_keys from server.environment import ClinicalTrialEnvironment # ----------------------------------------------------------------------------- # CONFIG # ----------------------------------------------------------------------------- API_KEY = os.environ.get("GROQ_API_KEY", "") API_KEYS_CSV = os.environ.get("GROQ_API_KEYS", "") BASE_URL = os.environ.get("GROQ_BASE_URL", "https://api.groq.com") MODEL = os.environ.get("BASELINE_MODEL", "llama-3.3-70b-versatile") TEMPERATURE = 0.0 OUTPUT_DIR = Path(__file__).parent.parent / "outputs" KEY_STATE_FILE = Path( os.environ.get( "GROQ_KEY_STATE_FILE", str(OUTPUT_DIR / "groq_key_usage.json"), ) ) # ----------------------------------------------------------------------------- # SYSTEM PROMPTS # ----------------------------------------------------------------------------- AE_SYSTEM_PROMPT = """You are an expert clinical research pharmacovigilance specialist. Return only a valid JSON object with these exact fields: { "severity_classification": "mild|moderate|severe|life_threatening|fatal", "reporting_timeline": "7-day|15-day|routine", "meddra_soc": "string", "meddra_preferred_term": "string", "is_serious": true, "rationale": "string up to 500 chars" } """ DEVIATION_SYSTEM_PROMPT = """You are a senior GCP auditor. Return only a valid JSON object with these exact fields: { "deviation_type": "major|minor|protocol_amendment", "capa_required": true, "site_risk_score": 0.0, "flagged_finding_ids": ["F001"], "recommended_action": "string up to 300 chars" } """ NARRATIVE_SYSTEM_PROMPT = """You are a regulatory medical writer. Return only a valid JSON object with these exact fields: { "narrative_text": "100-4000 chars", "causality_assessment": "definitely_related|probably_related|possibly_related|unlikely_related|not_related|unassessable", "key_temporal_flags": ["string"], "dechallenge_positive": true, "rechallenge_positive": null } """ def _extract_json_object(raw_text: str) -> Optional[Dict[str, Any]]: """Best-effort extraction when model returns extra text around JSON.""" text = raw_text.strip() try: return json.loads(text) except json.JSONDecodeError: pass match = re.search(r"\{[\s\S]*\}", text) if not match: return None try: return json.loads(match.group(0)) except json.JSONDecodeError: return None class LLMAgent: """GroqCloud-backed agent with strict JSON parsing and retries.""" def __init__(self, key_pool: GroqKeyPool, model: str = MODEL): self.key_pool = key_pool self.model = model def _call(self, system_prompt: str, user_content: str, retries: int = 3) -> Optional[Dict[str, Any]]: for attempt in range(retries): key_id = self.key_pool.acquire_key() if key_id is None: print(f" [Attempt {attempt + 1}] No Groq API key available") time.sleep(2**attempt) continue client = self.key_pool.get_client(key_id) self.key_pool.mark_request(key_id) try: response = client.chat.completions.create( model=self.model, temperature=TEMPERATURE, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}, ], ) raw = response.choices[0].message.content or "" parsed = _extract_json_object(raw) if parsed is not None: self.key_pool.mark_success(key_id) return parsed self.key_pool.mark_failure(key_id, "invalid json response") print(f" [Attempt {attempt + 1}] Could not parse JSON response") except Exception as exc: self.key_pool.mark_failure(key_id, str(exc)) print(f" [Attempt {attempt + 1}] API error: {exc}") time.sleep(2**attempt) return None def act_ae_triage(self, obs_dict: Dict[str, Any]) -> Optional[TriageAction]: ae_obs = obs_dict.get("ae_observation", {}) user_content = f""" ADVERSE EVENT CASE ================== Case ID: {ae_obs.get('case_id')} Patient: {ae_obs.get('patient_age')}y {ae_obs.get('patient_sex')} Study Drug: {ae_obs.get('study_drug')} {ae_obs.get('dose_mg')}mg Days on Drug: {ae_obs.get('days_on_drug')} Narrative: {ae_obs.get('narrative')} AE Description: {ae_obs.get('ae_description')} Outcome: {ae_obs.get('outcome')} Medical History: {', '.join(ae_obs.get('relevant_medical_history', []))} Concomitant Medications: {', '.join(ae_obs.get('concomitant_medications', []))} Lab Values: {json.dumps(ae_obs.get('lab_values', {}), indent=2)} """ result = self._call(AE_SYSTEM_PROMPT, user_content) if not result: return None try: return TriageAction( task_id=TaskID.ADVERSE_EVENT_TRIAGE, ae_triage=AdverseEventTriageAction(**result), ) except Exception as exc: print(f" Action parse error: {exc}") return None def act_deviation_audit(self, obs_dict: Dict[str, Any]) -> Optional[TriageAction]: dev_obs = obs_dict.get("deviation_observation", {}) findings_str = "\n".join( f" [{f['id']}] {f['category']}: {f['description']}" for f in dev_obs.get("findings", []) ) user_content = f""" SITE AUDIT FINDINGS =================== Site: {dev_obs.get('site_id')} - {dev_obs.get('site_name')} Visit Type: {dev_obs.get('visit_type')} Study Phase: {dev_obs.get('study_phase')} Active Subjects: {dev_obs.get('active_subjects')} Prior Deviations: {dev_obs.get('prior_deviations')} Last Monitoring: {dev_obs.get('last_monitoring_visit')} Findings: {findings_str} """ result = self._call(DEVIATION_SYSTEM_PROMPT, user_content) if not result: return None try: return TriageAction( task_id=TaskID.PROTOCOL_DEVIATION_AUDIT, deviation_audit=ProtocolDeviationAction(**result), ) except Exception as exc: print(f" Action parse error: {exc}") return None def act_safety_narrative(self, obs_dict: Dict[str, Any]) -> Optional[TriageAction]: nr_obs = obs_dict.get("narrative_observation", {}) user_content = f""" CASE FOR ICSR NARRATIVE ======================= Case ID: {nr_obs.get('case_id')} Patient: {json.dumps(nr_obs.get('patient_demographics', {}), indent=2)} Study Drug: {nr_obs.get('study_drug')} Suspect Drugs: {nr_obs.get('suspect_drugs')} Concomitant Medications: {json.dumps(nr_obs.get('concomitant_medications', []), indent=2)} Adverse Event: {json.dumps(nr_obs.get('adverse_event', {}), indent=2)} Lab Values Timeline: {json.dumps(nr_obs.get('lab_values_timeline', []), indent=2)} Medical History: {nr_obs.get('medical_history')} Action Taken: {nr_obs.get('action_taken')} Outcome: {nr_obs.get('outcome_at_last_followup')} Reference Documents: {nr_obs.get('reference_documents')} """ result = self._call(NARRATIVE_SYSTEM_PROMPT, user_content) if not result: return None try: return TriageAction( task_id=TaskID.SAFETY_NARRATIVE_GENERATION, safety_narrative=SafetyNarrativeAction(**result), ) except Exception as exc: print(f" Action parse error: {exc}") return None def _run_task(env: ClinicalTrialEnvironment, task_id: TaskID, agent: LLMAgent, max_steps: int) -> Dict[str, Any]: """Run one task and return per-step rewards and details.""" obs_dict = env.reset(task_id=task_id).model_dump() rewards: List[float] = [] details: List[Dict[str, Any]] = [] for _ in range(max_steps): if task_id == TaskID.ADVERSE_EVENT_TRIAGE: action = agent.act_ae_triage(obs_dict) elif task_id == TaskID.PROTOCOL_DEVIATION_AUDIT: action = agent.act_deviation_audit(obs_dict) else: action = agent.act_safety_narrative(obs_dict) if action is None: rewards.append(0.0) details.append({"error": "agent_failed_to_produce_valid_action"}) continue step_result = env.step(action) rewards.append(step_result.reward) details.append(step_result.reward_detail.model_dump()) obs_dict = step_result.observation.model_dump() if step_result.done: break return { "per_step_rewards": rewards, "mean_reward": round(sum(rewards) / max(len(rewards), 1), 4), "n_steps": len(rewards), "details": details, } def run_llm_baseline() -> Dict[str, Any]: """Run GroqCloud LLM baseline or deterministic fallback.""" api_keys = parse_groq_keys(api_key=API_KEY, api_keys_csv=API_KEYS_CSV) if not api_keys: from scripts.heuristic_baseline import run_heuristic_baseline fallback = run_heuristic_baseline() fallback["baseline_type"] = "heuristic_fallback" fallback["reason"] = "No Groq key found (GROQ_API_KEY or GROQ_API_KEYS); used deterministic heuristic baseline." OUTPUT_DIR.mkdir(parents=True, exist_ok=True) output_path = OUTPUT_DIR / "baseline_results.json" with open(output_path, "w", encoding="utf-8") as file: json.dump(fallback, file, indent=2) print(json.dumps(fallback, indent=2)) print(f"\nResults saved to: {output_path}") return fallback key_pool = GroqKeyPool( api_keys=api_keys, base_url=BASE_URL, state_file=KEY_STATE_FILE, ) agent = LLMAgent(key_pool=key_pool, model=MODEL) env = ClinicalTrialEnvironment() results: Dict[str, Any] = { "model": MODEL, "temperature": TEMPERATURE, "baseline_type": "groq_llm", "provider": "groqcloud", "key_pool": key_pool.snapshot(), "tasks": {}, } results["tasks"][TaskID.ADVERSE_EVENT_TRIAGE] = _run_task( env=env, task_id=TaskID.ADVERSE_EVENT_TRIAGE, agent=agent, max_steps=3, ) results["tasks"][TaskID.PROTOCOL_DEVIATION_AUDIT] = _run_task( env=env, task_id=TaskID.PROTOCOL_DEVIATION_AUDIT, agent=agent, max_steps=3, ) results["tasks"][TaskID.SAFETY_NARRATIVE_GENERATION] = _run_task( env=env, task_id=TaskID.SAFETY_NARRATIVE_GENERATION, agent=agent, max_steps=1, ) all_means = [task_result["mean_reward"] for task_result in results["tasks"].values()] results["overall_mean_reward"] = round(sum(all_means) / len(all_means), 4) OUTPUT_DIR.mkdir(parents=True, exist_ok=True) output_path = OUTPUT_DIR / "baseline_results.json" with open(output_path, "w", encoding="utf-8") as file: json.dump(results, file, indent=2) print(json.dumps(results, indent=2)) print(f"\nResults saved to: {output_path}") return results if __name__ == "__main__": run_llm_baseline()