"""
PRANA-Env agent with a full minimal RL loop.

The LLM agent must:
    1. query_db     — retrieve required fields from PatientDB
    2. record_value — write each field into the episode record
    3. file_report  — submit to the KARS validator → reward → done

Reward signal:
    +15  KARS PASSED on the first attempt
    +10  KARS PASSED after correction
     -1  redundant query (field already recorded)
     -5  filed with missing fields (recoverable)
    -10  unrecoverable failure
"""
|
|
| import json |
| import openai |
| from dataclasses import dataclass, field |
| from typing import Optional |
| from prana_env.client import PranaEnv |
| from prana_env.models import PranaAction |
|
|
| |
|
|
# Tool schemas passed to the chat-completions API (``tools=TOOLS`` in rollout).
# Each function name doubles as the PranaAction ``action_type`` and its
# arguments are forwarded to the environment as keyword arguments.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "query_db",
            "description": "Retrieve a specific field from a clinical datastore for a patient.",
            "parameters": {
                "type": "object",
                "properties": {
                    "target": {
                        "type": "string",
                        "description": "PatientDB | ClinicalNotesDB | PharmacyDB | WaitlistDB",
                        # The datastores are a closed set; an enum keeps the model
                        # from inventing names instead of relying on free text.
                        "enum": ["PatientDB", "ClinicalNotesDB", "PharmacyDB", "WaitlistDB"],
                    },
                    "field": {"type": "string", "description": "Field name (e.g. hba1c, gfr, creatinine, blood_type)"},
                    "patient_id": {"type": "string", "description": "Patient identifier (e.g. P001)"},
                },
                "required": ["target", "field", "patient_id"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "record_value",
            "description": "Write a retrieved field value into the episode patient record.",
            "parameters": {
                "type": "object",
                "properties": {
                    "field": {"type": "string", "description": "Field name to record"},
                    "value": {"type": "string", "description": "Value to record"},
                    "source": {"type": "string", "description": "Datastore the value came from"},
                },
                "required": ["field", "value"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "file_report",
            "description": (
                "Submit the compiled patient record to the KARS validator. "
                "Returns PASSED (done) or FAILED with missing fields. "
                "Call only after recording all required fields: hba1c, gfr, creatinine, blood_type."
            ),
            # file_report takes no arguments.
            "parameters": {"type": "object", "properties": {}, "required": []},
        },
    },
]
|
|
# System message sent at the start of every rollout. It states the KARS
# recency policy and the query -> record -> file workflow the agent must follow.
SYSTEM_PROMPT = """You are a kidney transplant administrative agent.

Your goal is to compile a complete patient record and file a KARS-compliant SRTR report.

Required fields: hba1c, gfr, creatinine, blood_type (all from PatientDB).

KARS Recency Policy:
- Time-sensitive fields (hba1c, gfr, creatinine) must be recorded within 90 days of the filing date.
- Stable fields (blood_type) have no recency requirement.
- The episode starts with a pre-existing T1 record (~4 months old). These values are STALE.
- You must re-query and re-record hba1c, gfr, and creatinine before filing.
- Do NOT re-query blood_type — it is stable and already valid.

Workflow:
1. Check recorded_fields in the observation — identify stale time-sensitive fields.
2. Use query_db to retrieve fresh values for stale fields only.
3. Use record_value to write each fresh value into the patient record.
4. Use file_report to submit. If it fails due to stale or missing fields, fix and retry.

Do not guess values. Always query before recording."""
|
|
| |
|
|
@dataclass
class Step:
    """A single transition captured while an episode is rolled out.

    The rollout loop creates one per tool call and a final one when the agent
    stops calling tools.
    """

    action: dict  # {"type": <tool name or "end_turn">, **tool arguments}
    observation: str  # environment text, or the agent's closing message
    reward: float  # reward reported for this transition
    done: bool  # True when this transition terminated the episode
|
|
@dataclass
class Trajectory:
    """Every Step recorded for one episode, kept in the order it happened.

    ``episode_id`` is a free-form label (run_episodes uses "ep_1", "ep_2", ...).
    """

    episode_id: str
    steps: list[Step] = field(default_factory=list)

    @property
    def total_reward(self) -> float:
        """Return the plain (undiscounted) sum of all step rewards; 0 if empty."""
        running_total = 0
        for recorded in self.steps:
            running_total += recorded.reward
        return running_total

    def __repr__(self):
        # The outcome is the observation of the most recent step flagged as
        # done. An episode without such a step is reported as "incomplete".
        outcome = "incomplete"
        for recorded in reversed(self.steps):
            if recorded.done:
                outcome = recorded.observation
                break
        return (
            f"Trajectory(episode={self.episode_id}, "
            f"steps={len(self.steps)}, "
            f"total_reward={self.total_reward}, "
            f"outcome={outcome!r})"
        )
|
|
| |
|
|
def reset(env: PranaEnv, patient_id: str) -> str:
    """Reset ``env`` for ``patient_id`` and hand back the first observation's ``query_result``."""
    reset_result = env.reset(patient_id=patient_id)
    first_observation = reset_result.observation
    return first_observation.query_result
|
|
|
|
def step(env: PranaEnv, action_type: str, **kwargs) -> tuple[str, float, bool, list]:
    """Apply one action to ``env`` and unpack the observation it produces.

    ``kwargs`` are forwarded unchanged to ``PranaAction`` next to ``action_type``.

    Returns:
        ``(query_result, reward, done, missing_fields)``. A falsy reward is
        replaced by ``0.0``, a falsy ``done`` by ``False``, and falsy
        ``missing_fields`` by a new empty list.
    """
    action = PranaAction(action_type=action_type, **kwargs)
    observation = env.step(action).observation

    query_result = observation.query_result
    raw_reward = observation.reward
    raw_done = observation.done
    raw_missing = observation.missing_fields

    return (
        query_result,
        raw_reward if raw_reward else 0.0,
        raw_done if raw_done else False,
        raw_missing if raw_missing else [],
    )
|
|
|
|
def _parse_tool_arguments(raw: Optional[str]) -> Optional[dict]:
    """Decode a tool call's JSON argument string; return None unless it is a JSON object."""
    try:
        # Argument-less tools (file_report) may arrive as "" rather than "{}".
        parsed = json.loads(raw or "{}")
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def _record_transition(
    trajectory: Trajectory,
    messages: list,
    tool_call_id: str,
    action: dict,
    observation: str,
    reward: float,
    done: bool,
) -> None:
    """Append one Step to ``trajectory`` and answer the matching tool call in ``messages``."""
    trajectory.steps.append(Step(action=action, observation=observation, reward=reward, done=done))
    # Every tool_call_id in an assistant message needs a matching "tool"
    # message, otherwise the next chat-completions request is rejected.
    messages.append({"role": "tool", "tool_call_id": tool_call_id, "content": observation})


def rollout(
    env: PranaEnv,
    task: str,
    patient_id: str,
    episode_id: str,
    max_turns: int = 20,
    model: str = "gpt-4o",
) -> Trajectory:
    """Run one full episode. The LLM drives the action loop until done=True.

    The environment is reset for ``patient_id``, and its initial observation is
    shown to the model together with ``task``. On each turn the model either
    issues tool calls, which are forwarded to the environment via ``step``, or
    replies with plain text. A plain-text reply ends the episode with a
    zero-reward terminal Step.

    Args:
        env: Open PRANA-Env client used for ``reset`` / ``step``.
        task: Task description sent as the first user message.
        patient_id: Patient the environment is reset for.
        episode_id: Label stored on the returned Trajectory.
        max_turns: Maximum number of LLM calls before giving up.
        model: Chat-completions model name.

    Returns:
        The Trajectory of all recorded Steps. When the environment reports
        done=True, any remaining tool calls in that same assistant message
        are not executed.
    """
    llm = openai.OpenAI()
    trajectory = Trajectory(episode_id=episode_id)

    print(f"\n── Episode {episode_id} {'─' * 30}")
    print(f"Task: {task}")

    initial_obs = reset(env, patient_id)
    print(f"[reset] {initial_obs}")

    # SYSTEM_PROMPT tells the agent to inspect the starting observation
    # (recorded_fields / staleness), so the model must actually be shown it.
    user_content = (
        f"{task}\n\nInitial observation:\n{initial_obs}" if initial_obs else task
    )
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]

    for turn in range(1, max_turns + 1):
        response = llm.chat.completions.create(
            model=model,
            tools=TOOLS,
            messages=messages,
        )
        msg = response.choices[0].message
        messages.append(msg)

        # No tool calls (None or an empty list): the agent has stopped acting.
        if not msg.tool_calls:
            print(f"[turn {turn}] Agent: {msg.content}")
            trajectory.steps.append(Step(
                action={"type": "end_turn"},
                observation=msg.content or "",
                reward=0.0,
                done=True,
            ))
            break

        for tool_call in msg.tool_calls:
            action_type = tool_call.function.name
            inp = _parse_tool_arguments(tool_call.function.arguments)

            if inp is None:
                # Malformed arguments: report the problem back to the model so
                # it can retry, instead of crashing the whole rollout.
                obs_str = (
                    f"ERROR: arguments for {action_type} must be a JSON object, "
                    f"got {tool_call.function.arguments!r}"
                )
                print(f"[turn {turn}] {obs_str}")
                _record_transition(
                    trajectory, messages, tool_call.id,
                    action={"type": action_type},
                    observation=obs_str, reward=0.0, done=False,
                )
                continue

            print(f"[turn {turn}] {action_type}({json.dumps(inp)})")
            obs_str, reward, done, _missing = step(env, action_type, **inp)
            # Tool message content must be a string. step() passes
            # query_result through unchecked, so guard against None.
            obs_str = "" if obs_str is None else str(obs_str)
            print(f"[turn {turn}] obs={obs_str!r} reward={reward} done={done}")

            _record_transition(
                trajectory, messages, tool_call.id,
                action={"type": action_type, **inp},
                observation=obs_str, reward=reward, done=done,
            )
            if done:
                return trajectory
    else:
        # Only reached when the turn budget ran out (no break / return).
        print(f"[max_turns] episode {episode_id} stopped after {max_turns} turns without done=True")

    return trajectory
|
|
|
|
def run_episodes(
    task: str,
    patient_id: str,
    n: int = 1,
    base_url: str = "http://localhost:8000",
    max_turns: int = 20,
) -> list[Trajectory]:
    """Run N independent episodes. Set n=8 for a GRPO rollout batch.

    All episodes share one PranaEnv client connection, which is closed on exit.

    Args:
        task: Task description given to the agent in every episode.
        patient_id: Patient each episode is reset for.
        n: Number of episodes to run.
        base_url: Address of the PRANA-Env server.
        max_turns: Per-episode LLM-call budget, forwarded to ``rollout``.

    Returns:
        One Trajectory per episode, in order, labelled "ep_1" ... "ep_<n>".
    """
    trajectories = []
    with PranaEnv(base_url=base_url) as env:
        for episode_number in range(1, n + 1):
            trajectories.append(
                rollout(env, task, patient_id, episode_id=f"ep_{episode_number}", max_turns=max_turns)
            )

    print(f"\n── Summary ({n} episode(s)) {'─' * 26}")
    for traj in trajectories:
        print(f" {traj}")
    return trajectories
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Demo run: a single episode for P001, whose stored T1 record the task
    # text describes as four months old.
    demo_task = (
        "File a KARS-compliant SRTR report for patient P001. "
        "A T1 record exists from 4 months ago. "
        "Check which fields are stale, re-query only what's needed, and file."
    )
    run_episodes(task=demo_task, patient_id="P001", n=1)
|
|