pbanavara committed on
Commit 75a4eab · verified · 1 Parent(s): fb9e30c

Upload folder using huggingface_hub

Files changed (7)
  1. README.md +124 -197
  2. client.py +4 -14
  3. data/patient_db.json +35 -11
  4. models.py +25 -15
  5. server/prana_env_environment.py +318 -34
  6. test_agent.py +227 -0
  7. test_client.py +12 -9
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
-title: Prana Env Environment Server
-emoji: 💍
 colorFrom: purple
 colorTo: indigo
 sdk: docker
@@ -9,247 +9,174 @@ app_port: 8000
 base_path: /web
 tags:
 - openenv
 ---

-# Prana Env Environment

-A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.

-## Quick Start

-The simplest way to use the Prana Env environment is through the `PranaEnv` class:

-```python
-from prana_env import PranaAction, PranaEnv

-try:
-    # Create environment from Docker image
-    prana_envenv = PranaEnv.from_docker_image("prana_env-env:latest")

-    # Reset
-    result = prana_envenv.reset()
-    print(f"Reset: {result.observation.echoed_message}")

-    # Send multiple messages
-    messages = ["Hello, World!", "Testing echo", "Final message"]

-    for msg in messages:
-        result = prana_envenv.step(PranaAction(message=msg))
-        print(f"Sent: '{msg}'")
-        print(f"  → Echoed: '{result.observation.echoed_message}'")
-        print(f"  → Length: {result.observation.message_length}")
-        print(f"  → Reward: {result.reward}")

-finally:
-    # Always clean up
-    prana_envenv.close()
 ```

-That's it! The `PranaEnv.from_docker_image()` method handles:
-- Starting the Docker container
-- Waiting for the server to be ready
-- Connecting to the environment
-- Container cleanup when you call `close()`

-## Building the Docker Image

-Before using the environment, you need to build the Docker image:

-```bash
-# From project root
-docker build -t prana_env-env:latest -f server/Dockerfile .
-```

-## Deploying to Hugging Face Spaces

-You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:

-```bash
-# From the environment directory (where openenv.yaml is located)
-openenv push

-# Or specify options
-openenv push --namespace my-org --private
 ```

-The `openenv push` command will:
-1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
-2. Prepare a custom build for Hugging Face Docker space (enables web interface)
-3. Upload to Hugging Face (ensuring you're logged in)
-
-### Prerequisites
-
-- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
-
-### Options
-
-- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
-- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
-- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
-- `--private`: Deploy the space as private (default: public)

-### Examples

 ```bash
-# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
-openenv push
-
-# Push to a specific repository
-openenv push --repo-id my-org/my-env
-
-# Push with a custom base image
-openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
-
-# Push as a private space
-openenv push --private
-
-# Combine options
-openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
 ```

-After deployment, your space will be available at:
-`https://huggingface.co/spaces/<repo-id>`
-
-The deployed space includes:
-- **Web Interface** at `/web` - Interactive UI for exploring the environment
-- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
-- **Health Check** at `/health` - Container health monitoring
-- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
-
-## Environment Details
-
-### Action
-**PranaAction**: Contains a single field
-- `message` (str) - The message to echo back
-
-### Observation
-**PranaObservation**: Contains the echo response and metadata
-- `echoed_message` (str) - The message echoed back
-- `message_length` (int) - Length of the message
-- `reward` (float) - Reward based on message length (length × 0.1)
-- `done` (bool) - Always False for echo environment
-- `metadata` (dict) - Additional info like step count
-
-### Reward
-The reward is calculated as: `message_length × 0.1`
-- "Hi" → reward: 0.2
-- "Hello, World!" → reward: 1.3
-- Empty message → reward: 0.0
-
-## Advanced Usage
-
-### Connecting to an Existing Server
-
-If you already have a Prana Env environment server running, you can connect directly:
-
 ```python
-from prana_env import PranaEnv
-
-# Connect to existing server
-prana_envenv = PranaEnv(base_url="<ENV_HTTP_URL_HERE>")
-
-# Use as normal
-result = prana_envenv.reset()
-result = prana_envenv.step(PranaAction(message="Hello!"))
 ```

-Note: When connecting to an existing server, `prana_envenv.close()` will NOT stop the server.
-
-### Using the Context Manager
-
-The client supports context manager usage for automatic connection management:
-
 ```python
-from prana_env import PranaAction, PranaEnv
-
-# Connect with context manager (auto-connects and closes)
-with PranaEnv(base_url="http://localhost:8000") as env:
-    result = env.reset()
-    print(f"Reset: {result.observation.echoed_message}")
-    # Multiple steps with low latency
-    for msg in ["Hello", "World", "!"]:
-        result = env.step(PranaAction(message=msg))
-        print(f"Echoed: {result.observation.echoed_message}")
 ```

-The client uses WebSocket connections for:
-- **Lower latency**: No HTTP connection overhead per request
-- **Persistent session**: Server maintains your environment state
-- **Efficient for episodes**: Better for many sequential steps

-### Concurrent WebSocket Sessions

-The server supports multiple concurrent WebSocket connections. To enable this,
-modify `server/app.py` to use factory mode:

-```python
-# In server/app.py - use factory mode for concurrent sessions
-app = create_app(
-    PranaEnvironment,  # Pass class, not instance
-    PranaAction,
-    PranaObservation,
-    max_concurrent_envs=4,  # Allow 4 concurrent sessions
-)
-```

-Then multiple clients can connect simultaneously:

-```python
-from prana_env import PranaAction, PranaEnv
-from concurrent.futures import ThreadPoolExecutor
-
-def run_episode(client_id: int):
-    with PranaEnv(base_url="http://localhost:8000") as env:
-        result = env.reset()
-        for i in range(10):
-            result = env.step(PranaAction(message=f"Client {client_id}, step {i}"))
-        return client_id, result.observation.message_length
-
-# Run 4 episodes concurrently
-with ThreadPoolExecutor(max_workers=4) as executor:
-    results = list(executor.map(run_episode, range(4)))
 ```

-## Development & Testing

-### Direct Environment Testing

-Test the environment logic directly without starting the HTTP server:

-```bash
-# From the server directory
-python3 server/prana_env_environment.py
 ```

-This verifies that:
-- Environment resets correctly
-- Step executes actions properly
-- State tracking works
-- Rewards are calculated correctly
-
-### Running Locally
-
-Run the server locally for development:

 ```bash
-uvicorn server.app:app --reload
 ```

-## Project Structure
-
-```
-prana_env/
-├── .dockerignore                # Docker build exclusions
-├── __init__.py                  # Module exports
-├── README.md                    # This file
-├── openenv.yaml                 # OpenEnv manifest
-├── pyproject.toml               # Project metadata and dependencies
-├── uv.lock                      # Locked dependencies (generated)
-├── client.py                    # PranaEnv client
-├── models.py                    # Action and Observation models
-└── server/
-    ├── __init__.py              # Server module exports
-    ├── prana_env_environment.py # Core environment logic
-    ├── app.py                   # FastAPI application (HTTP + WebSocket endpoints)
-    └── Dockerfile               # Container image definition
-```
 
 ---
+title: PRANA-Env Environment Server
+emoji: 🏥
 colorFrom: purple
 colorTo: indigo
 sdk: docker

 base_path: /web
 tags:
 - openenv
+- reinforcement-learning
+- clinical
 ---

+# PRANA-Env

+**Policy Reinforced Administrative Navigation Agent** — an OpenEnv RL environment for kidney transplant administration.

+PRANA-Env simulates the multi-step clinical workflow required to file a KARS-compliant SRTR report for a transplant candidate. The agent must query fragmented datastores, detect stale lab values, and file a complete report — earning rewards from a deterministic KARS validator.

+## Architecture

+```
+LLM Agent (GPT-4o / fine-tuned model)
+        │
+        │ query_db / record_value / file_report
+        ▼
+PranaEnv Client ──(WebSocket)── PranaEnvironment Server
+                                       │
+                                 KARS Validator
+                                 (reward signal)
+```

+## Action Space

+| Action | Required fields | Effect |
+|--------|----------------|--------|
+| `query_db` | `target`, `field`, `patient_id` | Returns current value from PatientDB |
+| `record_value` | `field`, `value` | Writes value into episode record with today's timestamp |
+| `file_report` | — | KARS validates record → reward → done |

+## Observation Space

+Every observation includes:

+```python
+PranaObservation(
+    query_result     # str: value, NOT_FOUND, RECORDED, KARS status
+    active_task      # str: current task context (t1–t5)
+    recorded_fields  # dict: {field: {value, recorded_at}} — full current record
+    missing_fields   # list[str]: KARS issues after file_report
+    kars_result      # str | None: "PASSED" | "FAILED"
+    reward           # float
+    done             # bool
+)
 ```

+`recorded_fields` shows the agent its full current state including timestamps — enabling staleness detection and selective re-querying.

+## Reward Signal

+| Event | Reward |
+|-------|--------|
+| KARS PASSED — first attempt | **+15** |
+| KARS PASSED — after correction | **+10** |
+| Re-query of already-fresh field | **−1** |
+| KARS FAILED — missing or stale fields | **−5** |
+| KARS FAILED — unrecoverable (3 attempts) | **−10** |

+## Temporal Model (T1 → T5)

+Episodes simulate a 4-month clinical timeline:

+- **T1 (2025-11-07)**: Initial labs recorded. Snapshot pre-loaded into episode record on `reset()`.
+- **T5 (2026-03-07)**: Filing date. KARS requires time-sensitive fields within **90 days**.

+On `reset()`, the agent sees a pre-populated record with stale T1 values. It must:
+1. Identify which fields are stale (`hba1c`, `gfr`, `creatinine` — time-sensitive)
+2. Re-query only those fields to get current T5 values
+3. Leave stable fields (`blood_type`) untouched — re-querying incurs a penalty
+4. File when the record is complete and fresh

+**Example trajectory:**

 ```
+reset() → record pre-loaded: {hba1c: {value: 7.2, recorded_at: 2025-11-07}, ...}

+query_db(hba1c) → 8.9 (T5 value — worsened since T1)
+query_db(gfr) → 12.1 (was 18.5 at T1)
+query_db(creatinine) → 4.7 (was 3.8 at T1)
+record_value × 3
+file_report() → KARS PASSED, reward=+15
+```

+## Quick Start

 ```bash
+# Start the server
+conda activate openenv
+uvicorn server.app:app --host 0.0.0.0 --port 8000
 ```

 ```bash
+# Run the LLM agent loop
+python test_agent.py
 ```

 ```python
+# Run N episodes for GRPO rollout batch
+from test_agent import run_episodes
+
+trajectories = run_episodes(
+    task="File a KARS-compliant SRTR report for patient P001. "
+         "A T1 record exists from 4 months ago. "
+         "Check which fields are stale, re-query only what's needed, and file.",
+    patient_id="P001",
+    n=8,  # GRPO batch size
+)
 ```

+## Patients

+| ID | Condition | T1 GFR | T5 GFR | HbA1c T1→T5 | Notes |
+|----|-----------|--------|--------|-------------|-------|
+| P001 | CKD Stage 4 | 18.5 | 12.1 | 7.2→8.9 | Complete record |
+| P002 | Diabetic nephropathy | 11.0 | 8.3 | 9.1→10.2 | Antihypertensives, insulin |
+| P003 | CKD Stage 3 | 22.3 | 19.8 | null | HbA1c never recorded, inactive waitlist |

+## KARS Required Fields

+| Field | Source | Time-sensitive |
+|-------|--------|----------------|
+| `hba1c` | PatientDB | Yes — 90-day window |
+| `gfr` | PatientDB | Yes — 90-day window |
+| `creatinine` | PatientDB | Yes — 90-day window |
+| `blood_type` | PatientDB | No — stable |

+## Project Structure

+```
+prana_env/
+├── client.py                    # PranaEnv WebSocket client
+├── models.py                    # PranaAction, PranaObservation
+├── test_agent.py                # LLM agent RL loop (GPT-4o)
+├── test_client.py               # Smoke test client
+├── data/
+│   └── patient_db.json          # Patient records with T1 snapshots and T5 values
+└── server/
+    ├── app.py                   # FastAPI + WebSocket server
+    ├── prana_env_environment.py # RL environment: actions, KARS validator, rewards
+    └── Dockerfile
 ```

+## Connecting to an Existing Server

+```python
+from prana_env.client import PranaEnv
+from prana_env.models import PranaAction

+with PranaEnv(base_url="http://localhost:8000") as env:
+    result = env.reset(patient_id="P001")
+    print(result.observation.query_result)

+    result = env.step(PranaAction(action_type="query_db", target="PatientDB",
+                                  field="hba1c", patient_id="P001"))
+    print(result.observation.query_result)     # "8.9"
+    print(result.observation.recorded_fields)  # current record state
 ```

+## Deploying to Hugging Face Spaces

 ```bash
+openenv push
+# or
+openenv push --repo-id my-org/prana-env --private
 ```

+After deployment:
+- **Web UI**: `/web`
+- **API docs**: `/docs`
+- **Health**: `/health`
+- **WebSocket**: `/ws`
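The selective re-query rule from the Temporal Model section can be sketched agent-side with nothing but the stdlib. This is an illustrative sketch, not code from the repo: the `recorded_fields` dict below is a hand-built stand-in for the observation field of the same name, and `to_requery` is a hypothetical variable; dates and field names come from the tables above.

```python
from datetime import date, timedelta

# Stand-in for observation.recorded_fields after reset() (T1 snapshot pre-loaded):
recorded_fields = {
    "hba1c":      {"value": "7.2",  "recorded_at": "2025-11-07"},
    "gfr":        {"value": "18.5", "recorded_at": "2025-11-07"},
    "creatinine": {"value": "3.8",  "recorded_at": "2025-11-07"},
    "blood_type": {"value": "A+",   "recorded_at": "2025-11-07"},
}
TIME_SENSITIVE = {"hba1c", "gfr", "creatinine"}   # from the KARS table above
cutoff = date(2026, 3, 7) - timedelta(days=90)    # T5 filing date minus the 90-day window

# Re-query only time-sensitive fields whose timestamp falls outside the window;
# touching fresh or stable fields costs -1 per the reward table.
to_requery = [
    f for f, entry in recorded_fields.items()
    if f in TIME_SENSITIVE and date.fromisoformat(entry["recorded_at"]) < cutoff
]
print(sorted(to_requery))  # ['creatinine', 'gfr', 'hba1c'] — all three labs are 120 days old
```

All three T1 labs predate the 2025-12-07 cutoff, so they all need re-querying, while `blood_type` is correctly left alone.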
 
 
 
 
 
 
 
 
 
 
 
 
 
client.py CHANGED
@@ -10,20 +10,7 @@ from .models import PranaAction, PranaObservation

 class PranaEnv(EnvClient[PranaAction, PranaObservation, State]):
-    """
-    Client for PRANA-Env.
-
-    Example:
-        >>> with PranaEnv(base_url="http://localhost:8000") as client:
-        ...     client.reset()
-        ...     result = client.step(PranaAction(
-        ...         action_type="query_db",
-        ...         target="PatientDB",
-        ...         field="hba1c",
-        ...         patient_id="P001",
-        ...     ))
-        ...     print(result.observation.query_result)  # "7.2"
-    """

     def _step_payload(self, action: PranaAction) -> Dict:
         return {k: v for k, v in action.model_dump().items() if v is not None}
@@ -34,6 +21,9 @@ class PranaEnv(EnvClient[PranaAction, PranaObservation, State]):
             query_result=obs_data.get("query_result", ""),
             active_task=obs_data.get("active_task", "t1"),
             policy_alerts=obs_data.get("policy_alerts", ""),
             done=payload.get("done", False),
             reward=payload.get("reward", 0.0),
             metadata=obs_data.get("metadata", {}),

 class PranaEnv(EnvClient[PranaAction, PranaObservation, State]):
+    """Client for PRANA-Env."""

     def _step_payload(self, action: PranaAction) -> Dict:
         return {k: v for k, v in action.model_dump().items() if v is not None}

             query_result=obs_data.get("query_result", ""),
             active_task=obs_data.get("active_task", "t1"),
             policy_alerts=obs_data.get("policy_alerts", ""),
+            kars_result=obs_data.get("kars_result"),
+            missing_fields=obs_data.get("missing_fields", []),
+            recorded_fields=obs_data.get("recorded_fields", {}),
             done=payload.get("done", False),
             reward=payload.get("reward", 0.0),
             metadata=obs_data.get("metadata", {}),
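The `_step_payload` filter in `client.py` drops `None`-valued fields so each request only carries the keys relevant to its `action_type`. A minimal sketch of that behavior, assuming a dataclass stand-in (`PranaActionSketch` and `step_payload` are hypothetical names; the real `PranaAction` is a pydantic model and uses `model_dump`):

```python
from dataclasses import dataclass, asdict
from typing import Optional

# Dataclass stand-in for PranaAction — enough to show the payload shape.
@dataclass
class PranaActionSketch:
    action_type: str
    target: Optional[str] = None
    field: Optional[str] = None
    patient_id: Optional[str] = None
    value: Optional[str] = None

def step_payload(action: PranaActionSketch) -> dict:
    # Mirrors PranaEnv._step_payload: drop None-valued fields entirely.
    return {k: v for k, v in asdict(action).items() if v is not None}

print(step_payload(PranaActionSketch(action_type="file_report")))
# {'action_type': 'file_report'} — optional fields vanish rather than arriving as null
print(step_payload(PranaActionSketch(action_type="query_db", target="PatientDB",
                                     field="hba1c", patient_id="P001")))
```

This keeps `file_report` requests from carrying spurious `target`/`field` keys the server would otherwise have to ignore.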
data/patient_db.json CHANGED
@@ -5,20 +5,36 @@
     "name": "Jane Doe",
     "age": 52,
     "blood_type": "A+",
-    "hba1c": 7.2,
-    "gfr": 18.5,
-    "creatinine": 3.8,
-    "pra": 12
   },
   "P002": {
     "patient_id": "P002",
     "name": "John Smith",
     "age": 61,
     "blood_type": "O-",
-    "hba1c": 9.1,
-    "gfr": 11.0,
-    "creatinine": 5.2,
-    "pra": 45
   },
   "P003": {
     "patient_id": "P003",
@@ -26,9 +42,17 @@
     "age": 47,
     "blood_type": "B+",
     "hba1c": null,
-    "gfr": 22.3,
-    "creatinine": 3.1,
-    "pra": 8
   }
 }
}

     "name": "Jane Doe",
     "age": 52,
     "blood_type": "A+",
+    "hba1c": 8.9,
+    "gfr": 12.1,
+    "creatinine": 4.7,
+    "pra": 12,
+    "t1_snapshot": {
+      "hba1c": 7.2,
+      "gfr": 18.5,
+      "creatinine": 3.8,
+      "blood_type": "A+",
+      "pra": 12,
+      "recorded_at": "2025-11-07"
+    }
   },
   "P002": {
     "patient_id": "P002",
     "name": "John Smith",
     "age": 61,
     "blood_type": "O-",
+    "hba1c": 10.2,
+    "gfr": 8.3,
+    "creatinine": 6.1,
+    "pra": 45,
+    "t1_snapshot": {
+      "hba1c": 9.1,
+      "gfr": 11.0,
+      "creatinine": 5.2,
+      "blood_type": "O-",
+      "pra": 45,
+      "recorded_at": "2025-11-07"
+    }
   },
   "P003": {
     "patient_id": "P003",
     "age": 47,
     "blood_type": "B+",
     "hba1c": null,
+    "gfr": 19.8,
+    "creatinine": 3.4,
+    "pra": 8,
+    "t1_snapshot": {
+      "hba1c": null,
+      "gfr": 22.3,
+      "creatinine": 3.1,
+      "blood_type": "B+",
+      "pra": 8,
+      "recorded_at": "2025-11-07"
+    }
   }
 }
}
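The new `t1_snapshot` blocks make T1 → T5 drift directly computable: top-level values are the current (T5) readings, and the snapshot preserves what was recorded at T1. A small sketch using a trimmed P001 entry copied from the diff (`deltas` is an illustrative name):

```python
# Trimmed P001 entry mirroring the new data/patient_db.json schema:
p001 = {
    "hba1c": 8.9, "gfr": 12.1, "creatinine": 4.7,   # current T5 values
    "t1_snapshot": {"hba1c": 7.2, "gfr": 18.5, "creatinine": 3.8,
                    "recorded_at": "2025-11-07"},
}

# T1 → T5 drift per lab: positive hba1c/creatinine and negative gfr
# all indicate deterioration over the 4-month timeline.
deltas = {f: round(p001[f] - p001["t1_snapshot"][f], 2)
          for f in ("hba1c", "gfr", "creatinine")}
print(deltas)  # {'hba1c': 1.7, 'gfr': -6.4, 'creatinine': 0.9}
```

These are exactly the gaps an agent that files the stale T1 snapshot without re-querying would mis-report.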
models.py CHANGED
@@ -5,7 +5,7 @@ Action space and observation space for the kidney transplant
 administration environment.
 """

-from typing import Optional

 from pydantic import Field

@@ -16,24 +16,19 @@ class PranaAction(Action):
     """
     Action for PRANA-Env.

-    For the smoke test, only action_type='query_db' is supported.
-
-    Example:
-        >>> action = PranaAction(
-        ...     action_type="query_db",
-        ...     target="PatientDB",
-        ...     field="hba1c",
-        ...     patient_id="P001",
-        ... )
     """

     action_type: str = Field(
         ...,
         description=(
-            "Type of action: query_db | record_value | update_past_record "
-            "| search_policy | infer_from_evidence | file_report | advance_task"
         ),
     )
     target: Optional[str] = Field(
         default=None,
         description="Datastore name for query_db (PatientDB, ClinicalNotesDB, PharmacyDB, WaitlistDB)",
@@ -44,8 +39,12 @@ class PranaAction(Action):
     patient_id: Optional[str] = Field(
         default=None, description="Patient identifier"
     )
     value: Optional[str] = Field(
-        default=None, description="Value to record (for record_value / update_past_record)"
     )
     task_ref: Optional[str] = Field(
         default=None, description="Task reference for retroactive updates (e.g. 't1')"
@@ -58,8 +57,6 @@ class PranaAction(Action):
 class PranaObservation(Observation):
     """
     Observation from PRANA-Env.
-
-    Contains the result of the last action plus episode context.
     """

     query_result: str = Field(
@@ -74,3 +71,16 @@ class PranaObservation(Observation):
         default="",
         description="Any OPTN policy rules triggered by this observation",
     )

 administration environment.
 """

+from typing import List, Optional

 from pydantic import Field

     """
     Action for PRANA-Env.

+    Supported action_types:
+        query_db     — retrieve a field from a datastore
+        record_value — write a field into the episode patient record
+        file_report  — submit compiled record to KARS validator
     """

     action_type: str = Field(
         ...,
         description=(
+            "Type of action: query_db | record_value | file_report"
         ),
     )
+    # query_db / record_value
     target: Optional[str] = Field(
         default=None,
         description="Datastore name for query_db (PatientDB, ClinicalNotesDB, PharmacyDB, WaitlistDB)",

     patient_id: Optional[str] = Field(
         default=None, description="Patient identifier"
     )
+    # record_value / update_past_record
     value: Optional[str] = Field(
+        default=None, description="Value to record"
+    )
+    source: Optional[str] = Field(
+        default=None, description="Source datastore the value was retrieved from"
     )
     task_ref: Optional[str] = Field(
         default=None, description="Task reference for retroactive updates (e.g. 't1')"

 class PranaObservation(Observation):
     """
     Observation from PRANA-Env.
     """

     query_result: str = Field(

         default="",
         description="Any OPTN policy rules triggered by this observation",
     )
+    # Populated after file_report
+    kars_result: Optional[str] = Field(
+        default=None,
+        description="KARS validation result: PASSED or FAILED",
+    )
+    missing_fields: List[str] = Field(
+        default_factory=list,
+        description="Fields missing from the report per KARS requirements",
+    )
+    recorded_fields: dict = Field(
+        default_factory=dict,
+        description="Current patient record — fields recorded so far this episode",
+    )
server/prana_env_environment.py CHANGED
@@ -1,14 +1,34 @@
1
  """
2
  PRANA-Env Environment Implementation.
3
 
4
- Kidney transplant administration RL environment built on OpenEnv.
5
- Phase 1 smoke test: supports query_db action against PatientDB.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
 
8
- import json
9
  import logging
 
 
10
  from pathlib import Path
11
  from uuid import uuid4
 
12
 
13
  from openenv.core.env_server.interfaces import Environment
14
  from openenv.core.env_server.types import State
@@ -18,22 +38,63 @@ from models import PranaAction, PranaObservation
18
  tag = "[prana_env/environment]"
19
  logger = logging.getLogger(__name__)
20
 
21
- # Path to data directory β€” resolved relative to this file
22
  DATA_DIR = Path(__file__).parent.parent / "data"
23
 
 
 
 
 
24
 
25
- class PranaEnvironment(Environment):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
- PRANA-Env: kidney transplant administration environment.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- Episode structure (5 tasks):
30
- t1: Initial Labs β€” query PatientDB (HbA1c, GFR, creatinine)
31
- t2: Waitlist Update β€” query/update WaitlistDB
32
- t3: Medication Review β€” query PharmacyDB
33
- t4: Physician Notes β€” query ClinicalNotesDB
34
- t5: SRTR Report Filing β€” file_report β†’ KARS validator
35
 
36
- Phase 1 smoke test: query_db against PatientDB only.
 
 
 
 
 
 
 
 
37
  """
38
 
39
  SUPPORTS_CONCURRENT_SESSIONS: bool = True
@@ -43,26 +104,88 @@ class PranaEnvironment(Environment):
43
  self._state = State(episode_id=str(uuid4()), step_count=0)
44
  self._active_task = "t1"
45
  self._patient_id: str | None = None
 
 
 
 
 
46
  self._patient_db = self._load_db("patient_db.json")
47
  logger.info(f"{tag} Loaded PatientDB with {len(self._patient_db.get('patients', {}))} patients")
48
 
49
  def _load_db(self, filename: str) -> dict:
50
  path = DATA_DIR / filename
51
- logger.info(f"{tag} Loading datastore from {path}")
52
  with open(path) as f:
53
  return json.load(f)
54
 
 
 
 
55
  def reset(self, seed: int | None = None, episode_id: str | None = None, **kwargs) -> PranaObservation:
56
  patient_id: str | None = kwargs.get("patient_id")
 
 
 
 
 
 
 
57
  self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
58
  self._active_task = "t1"
59
  self._patient_id = patient_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- logger.info(f"{tag} reset β€” episode={self._state.episode_id} patient_id={patient_id}")
 
 
 
 
 
 
 
 
62
 
63
  return PranaObservation(
64
- query_result="Episode reset. Ready for task t1: Initial Labs.",
 
 
 
 
 
 
65
  active_task=self._active_task,
 
66
  done=False,
67
  reward=0.0,
68
  )
@@ -71,83 +194,244 @@ class PranaEnvironment(Environment):
71
  self._state.step_count += 1
72
  logger.info(
73
  f"{tag} step={self._state.step_count} action_type={action.action_type} "
74
- f"target={action.target} field={action.field} patient_id={action.patient_id}"
75
  )
76
 
77
  if action.action_type == "query_db":
78
  return self._handle_query_db(action)
 
 
 
 
79
 
80
  logger.warning(f"{tag} Unsupported action_type={action.action_type}")
81
  return PranaObservation(
82
- query_result=f"NOT_SUPPORTED: action_type '{action.action_type}' not implemented yet.",
83
  active_task=self._active_task,
 
84
  done=False,
85
  reward=0.0,
86
  )
87
 
 
 
88
  def _handle_query_db(self, action: PranaAction) -> PranaObservation:
89
  db_name = (action.target or "").lower()
90
  field = (action.field or "").lower()
91
  patient_id = action.patient_id or self._patient_id
92
 
93
- logger.info(f"{tag} query_db db={db_name} field={field} patient_id={patient_id}")
94
-
95
  if db_name != "patientdb":
96
- logger.warning(f"{tag} Datastore '{db_name}' not available in Phase 1")
97
  return PranaObservation(
98
- query_result=f"NOT_AVAILABLE: datastore '{action.target}' not loaded in Phase 1.",
99
  active_task=self._active_task,
 
100
  done=False,
101
  reward=0.0,
102
  )
103
 
104
  if not patient_id:
105
  return PranaObservation(
106
- query_result="ERROR: patient_id is required for query_db.",
107
  active_task=self._active_task,
 
108
  done=False,
109
  reward=0.0,
110
  )
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  patients = self._patient_db.get("patients", {})
113
  patient = patients.get(patient_id)
114
-
115
  if not patient:
116
- logger.info(f"{tag} patient_id={patient_id} NOT_FOUND in PatientDB")
117
  return PranaObservation(
118
  query_result=f"NOT_FOUND: patient '{patient_id}' not in PatientDB.",
119
  active_task=self._active_task,
 
120
  done=False,
121
  reward=0.0,
122
  )
123
 
124
- if field not in patient:
125
- logger.info(f"{tag} field={field} NOT_FOUND for patient={patient_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  return PranaObservation(
127
- query_result=f"NOT_FOUND: field '{field}' not in PatientDB for patient '{patient_id}'.",
128
  active_task=self._active_task,
 
129
  done=False,
130
  reward=0.0,
131
  )
132
 
133
- value = patient[field]
134
- if value is None:
135
- logger.info(f"{tag} field={field} is NULL for patient={patient_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return PranaObservation(
137
- query_result=f"NOT_FOUND: field '{field}' has no recorded value for patient '{patient_id}'.",
138
  active_task=self._active_task,
 
139
  done=False,
140
  reward=0.0,
141
  )
142
 
143
- logger.info(f"{tag} query_db success field={field} value={value} patient={patient_id}")
 
 
 
 
 
 
144
  return PranaObservation(
145
- query_result=str(value),
 
 
 
146
  active_task=self._active_task,
 
147
  done=False,
148
  reward=0.0,
149
  )
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  @property
152
  def state(self) -> State:
153
  return self._state
 
1
  """
2
  PRANA-Env Environment Implementation.
3
 
4
+ Minimal RL loop:
5
+ 1. query_db β€” retrieve field from PatientDB
6
+ 2. record_value β€” write field into episode patient record
7
+ 3. file_report β€” KARS validator β†’ reward signal β†’ episode done
8
+
9
+ Reward:
10
+ +15 KARS PASSED on first attempt
11
+ +10 KARS PASSED after prior failed attempt
12
+ -1 query_db for a field already fresh in the record (inefficiency penalty)
13
+ -5 file_report with missing or stale required fields
14
+ -10 unrecoverable KARS failure (max filing attempts exceeded)
15
+
16
+ Stochasticity (4 sources):
17
+ 1. T1 date randomization β€” T1 age sampled Uniform(T1_AGE_MIN, T1_AGE_MAX) days
18
+ Agent must calculate staleness dynamically, not memorize
19
+ 2. Random patient selection β€” if no patient_id given, pick randomly from pool
20
+ 3. Anomaly injection β€” with ANOMALY_PROB, inject a spurious reading for one
21
+ time-sensitive field; agent must detect and escalate
22
+ 4. Field availability noise β€” with PENDING_PROB, a field returns PENDING on first
23
+ query; resolved on retry (simulates data entry lag)
24
  """
25
 
 
26
  import logging
27
+ import random
28
+ from datetime import date, timedelta
29
  from pathlib import Path
30
  from uuid import uuid4
31
+ import json
32
 
33
  from openenv.core.env_server.interfaces import Environment
34
  from openenv.core.env_server.types import State

  tag = "[prana_env/environment]"
  logger = logging.getLogger(__name__)

  DATA_DIR = Path(__file__).parent.parent / "data"

+ # KARS required fields
+ KARS_REQUIRED_FIELDS = ["hba1c", "gfr", "creatinine", "blood_type"]
+ TIME_SENSITIVE_FIELDS = {"hba1c", "gfr", "creatinine"}
+ STABLE_FIELDS = {"blood_type", "pra"}
+
+ MAX_FILE_ATTEMPTS = 3
+
+ # Temporal constants
+ EPISODE_DATE = date(2026, 3, 7)
+ RECENCY_DAYS = 90
+
+ # ── Stochasticity parameters ──────────────────────────────────────────────────
+ T1_AGE_MIN_DAYS = 60      # shortest possible T1 record age (fresh — no re-query needed)
+ T1_AGE_MAX_DAYS = 150     # longest possible T1 record age (stale — must re-query)
+ ANOMALY_PROB = 0.30       # probability of injecting an anomalous reading per episode
+ ANOMALY_DELTA = 0.40      # anomalous value deviates by this fraction from the true T5 value
+ ANOMALY_WINDOW_DAYS = 14  # anomaly detection window (matches OPTN Clinical Integrity Policy)
+ ANOMALY_THRESHOLD = 0.25  # flag if delta > 25% within window
+ PENDING_PROB = 0.15       # probability of a PENDING response on first query of a field
+
+
+ def kars_validate(record: dict) -> tuple[bool, list[str]]:
      """
+     Deterministic KARS validator with recency checks.
+     record values: {field: {"value": ..., "recorded_at": "YYYY-MM-DD"}}
+     Returns (passed, issues).
+     """
+     cutoff = EPISODE_DATE - timedelta(days=RECENCY_DAYS)
+     issues = []
+
+     for f in KARS_REQUIRED_FIELDS:
+         entry = record.get(f)
+         if entry is None or entry.get("value") is None:
+             issues.append(f"{f} (missing)")
+             continue
+         if f in TIME_SENSITIVE_FIELDS:
+             try:
+                 recorded_at = date.fromisoformat(entry.get("recorded_at", ""))
+                 if recorded_at < cutoff:
+                     issues.append(f"{f} (stale: recorded {recorded_at}, must be after {cutoff})")
+             except ValueError:
+                 issues.append(f"{f} (invalid date)")
+
+     return (len(issues) == 0, issues)

+ class PranaEnvironment(Environment):
+     """
+     PRANA-Env: kidney transplant administration RL environment.
+
+     Stochastic per-episode:
+     - T1 record age varies (60–150 days) — agent must calculate recency dynamically
+     - Patient selected randomly if not specified
+     - One time-sensitive field may have an injected anomalous reading (30% of episodes)
+     - Some fields return PENDING on first query (15% per field) — a retry resolves it
      """

      SUPPORTS_CONCURRENT_SESSIONS: bool = True

          self._state = State(episode_id=str(uuid4()), step_count=0)
          self._active_task = "t1"
          self._patient_id: str | None = None
+         self._patient_record: dict = {}
+         self._file_attempts: int = 0
+         self._t1_date: date = EPISODE_DATE - timedelta(days=120)
+         self._pending_fields: set = set()
+         self._injected_anomaly: dict | None = None
          self._patient_db = self._load_db("patient_db.json")
          logger.info(f"{tag} Loaded PatientDB with {len(self._patient_db.get('patients', {}))} patients")

      def _load_db(self, filename: str) -> dict:
          path = DATA_DIR / filename
          with open(path) as f:
              return json.load(f)

+     def _make_entry(self, value, recorded_at: date) -> dict:
+         return {"value": str(value), "recorded_at": recorded_at.isoformat()}
+
      def reset(self, seed: int | None = None, episode_id: str | None = None, **kwargs) -> PranaObservation:
          patient_id: str | None = kwargs.get("patient_id")
+         patients = self._patient_db.get("patients", {})
+
+         # ── Stochasticity 2: random patient selection ─────────────────────────
+         if not patient_id:
+             patient_id = random.choice(list(patients.keys()))
+             logger.info(f"{tag} No patient_id specified — randomly selected {patient_id}")
+
          self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
          self._active_task = "t1"
          self._patient_id = patient_id
+         self._patient_record = {}
+         self._file_attempts = 0
+         self._pending_fields = set()
+         self._injected_anomaly = None
+
+         # ── Stochasticity 1: randomize T1 record age ──────────────────────────
+         t1_days_ago = random.randint(T1_AGE_MIN_DAYS, T1_AGE_MAX_DAYS)
+         self._t1_date = EPISODE_DATE - timedelta(days=t1_days_ago)
+         cutoff = EPISODE_DATE - timedelta(days=RECENCY_DAYS)
+         t1_is_stale = self._t1_date < cutoff
+
+         # Pre-populate record with T1 snapshot at randomized date
+         patient = patients.get(patient_id, {})
+         snapshot = patient.get("t1_snapshot", {})
+         for field in KARS_REQUIRED_FIELDS:
+             val = snapshot.get(field)
+             if val is not None:
+                 self._patient_record[field] = self._make_entry(val, self._t1_date)
+
+         # ── Stochasticity 3: anomaly injection ────────────────────────────────
+         if random.random() < ANOMALY_PROB:
+             field = random.choice(sorted(TIME_SENSITIVE_FIELDS))
+             t5_value = patient.get(field)
+             if t5_value is not None:
+                 direction = random.choice([-1, 1])
+                 anomaly_value = round(t5_value * (1 + direction * ANOMALY_DELTA), 1)
+                 anomaly_days = random.randint(1, 6)
+                 self._injected_anomaly = {
+                     "field": field,
+                     "value": anomaly_value,
+                     "recorded_at": (EPISODE_DATE - timedelta(days=anomaly_days)).isoformat(),
+                 }
+                 logger.info(f"{tag} Injected anomaly: {self._injected_anomaly}")

+         logger.info(
+             f"{tag} reset episode={self._state.episode_id} patient={patient_id} "
+             f"t1_date={self._t1_date} t1_stale={t1_is_stale} "
+             f"anomaly={self._injected_anomaly}"
+         )
+
+         stale_note = (
+             f"T1 record is {'STALE (>90 days)' if t1_is_stale else 'FRESH (≤90 days)'}."
+         )

          return PranaObservation(
+             query_result=(
+                 f"Episode reset. Patient: {patient_id}. "
+                 f"Filing date: {EPISODE_DATE}. "
+                 f"T1 record date: {self._t1_date} ({t1_days_ago} days ago). {stale_note} "
+                 f"Required fields: {KARS_REQUIRED_FIELDS}. "
+                 f"Time-sensitive {sorted(TIME_SENSITIVE_FIELDS)} must be recorded after {cutoff}."
+             ),
              active_task=self._active_task,
+             recorded_fields=self._patient_record.copy(),
              done=False,
              reward=0.0,
          )
 
194
  self._state.step_count += 1
195
  logger.info(
196
  f"{tag} step={self._state.step_count} action_type={action.action_type} "
197
+ f"field={action.field} value={action.value}"
198
  )
199
 
200
  if action.action_type == "query_db":
201
  return self._handle_query_db(action)
202
+ if action.action_type == "record_value":
203
+ return self._handle_record_value(action)
204
+ if action.action_type == "file_report":
205
+ return self._handle_file_report(action)
206
 
207
  logger.warning(f"{tag} Unsupported action_type={action.action_type}")
208
  return PranaObservation(
209
+ query_result=f"NOT_SUPPORTED: action_type '{action.action_type}'.",
210
  active_task=self._active_task,
211
+ recorded_fields=self._patient_record.copy(),
212
  done=False,
213
  reward=0.0,
214
  )
215
 
216
+ # ── Action handlers ───────────────────────────────────────────────────────
217
+
218
  def _handle_query_db(self, action: PranaAction) -> PranaObservation:
219
  db_name = (action.target or "").lower()
220
  field = (action.field or "").lower()
221
  patient_id = action.patient_id or self._patient_id
222
 
 
 
223
  if db_name != "patientdb":
 
224
  return PranaObservation(
225
+ query_result=f"NOT_AVAILABLE: datastore '{action.target}' not in Phase 1.",
226
  active_task=self._active_task,
227
+ recorded_fields=self._patient_record.copy(),
228
  done=False,
229
  reward=0.0,
230
  )
231
 
232
  if not patient_id:
233
  return PranaObservation(
234
+ query_result="ERROR: patient_id required.",
235
  active_task=self._active_task,
236
+ recorded_fields=self._patient_record.copy(),
237
  done=False,
238
  reward=0.0,
239
  )
240
 
241
+ # Inefficiency penalty β€” field already fresh in record
242
+ cutoff = EPISODE_DATE - timedelta(days=RECENCY_DAYS)
243
+ if field in self._patient_record:
244
+ entry = self._patient_record[field]
245
+ try:
246
+ recorded_at = date.fromisoformat(entry.get("recorded_at", ""))
247
+ if field in STABLE_FIELDS or recorded_at >= cutoff:
248
+ logger.info(f"{tag} field={field} already fresh β€” inefficiency penalty")
249
+ return PranaObservation(
250
+ query_result=f"ALREADY_RECORDED: '{field}' = {entry['value']} (recorded {entry['recorded_at']})",
251
+ active_task=self._active_task,
252
+ recorded_fields=self._patient_record.copy(),
253
+ done=False,
254
+ reward=-1.0,
255
+ )
256
+ except ValueError:
257
+ pass
258
+
259
  patients = self._patient_db.get("patients", {})
260
  patient = patients.get(patient_id)
 
261
  if not patient:
 
262
  return PranaObservation(
263
  query_result=f"NOT_FOUND: patient '{patient_id}' not in PatientDB.",
264
  active_task=self._active_task,
265
+ recorded_fields=self._patient_record.copy(),
266
  done=False,
267
  reward=0.0,
268
  )
269
 
270
+ # ── Stochasticity 4: field availability noise (PENDING) ───────────────
271
+ if field in TIME_SENSITIVE_FIELDS and field not in self._pending_fields:
272
+ if random.random() < PENDING_PROB:
273
+ self._pending_fields.add(field)
274
+ logger.info(f"{tag} field={field} returned PENDING (will resolve on retry)")
275
+ return PranaObservation(
276
+ query_result=(
277
+ f"PENDING: '{field}' not yet entered for patient '{patient_id}'. "
278
+ f"Data entry in progress β€” retry."
279
+ ),
280
+ active_task=self._active_task,
281
+ recorded_fields=self._patient_record.copy(),
282
+ done=False,
283
+ reward=0.0,
284
+ )
285
+
286
+ value = patient.get(field)
287
+ if value is None:
288
  return PranaObservation(
289
+ query_result=f"NOT_FOUND: '{field}' has no value for patient '{patient_id}'.",
290
  active_task=self._active_task,
291
+ recorded_fields=self._patient_record.copy(),
292
  done=False,
293
  reward=0.0,
294
  )
295
 
296
+ # ── Stochasticity 3: include anomaly in history if injected ───────────
297
+ if field in TIME_SENSITIVE_FIELDS:
298
+ query_result = self._format_lab_history(field, patient_id, value)
299
+ else:
300
+ query_result = str(value)
301
+
302
+ logger.info(f"{tag} query_db OK field={field} value={value}")
303
+ return PranaObservation(
304
+ query_result=query_result,
305
+ active_task=self._active_task,
306
+ recorded_fields=self._patient_record.copy(),
307
+ done=False,
308
+ reward=0.0,
309
+ )
310
+
+     def _format_lab_history(self, field: str, patient_id: str, t5_value) -> str:
+         """
+         Format a time-sensitive field as a timestamped history.
+         Includes the T1 snapshot entry, the T5 current entry, and the injected anomaly if present.
+         Flags anomalies per OPTN Clinical Integrity Policy.
+         """
+         snapshot = self._patient_db["patients"][patient_id].get("t1_snapshot", {})
+         t1_val = snapshot.get(field)
+
+         history: list[tuple[date, float]] = []
+         if t1_val is not None:
+             history.append((self._t1_date, float(t1_val)))
+
+         # Inject anomalous reading if this is the affected field
+         if self._injected_anomaly and self._injected_anomaly["field"] == field:
+             anom_date = date.fromisoformat(self._injected_anomaly["recorded_at"])
+             history.append((anom_date, self._injected_anomaly["value"]))
+
+         history.append((EPISODE_DATE, float(t5_value)))
+         history.sort(key=lambda x: x[0])
+
+         lines = []
+         for i, (d, v) in enumerate(history):
+             suffix = " ← latest" if i == len(history) - 1 else ""
+             lines.append(f"  {v} (recorded: {d}){suffix}")
+
+         result = (
+             f"{field} measurement history for {patient_id} "
+             f"(filing date: {EPISODE_DATE}):\n" + "\n".join(lines)
+         )
+
+         # Check for anomaly between consecutive entries within window
+         for i in range(len(history) - 1):
+             d1, v1 = history[i]
+             d2, v2 = history[i + 1]
+             days_apart = (d2 - d1).days
+             if days_apart <= ANOMALY_WINDOW_DAYS and v1 > 0:
+                 change = abs(v2 - v1) / v1
+                 if change >= ANOMALY_THRESHOLD:
+                     pct = round(change * 100, 1)
+                     result += (
+                         f"\n⚠️ ANOMALY DETECTED: {v1} ({d1}) → {v2} ({d2}), "
+                         f"{days_apart} days apart, {pct}% delta. "
+                         f"Recommend confirmatory test before filing."
+                     )
+
+         return result
+
+     def _handle_record_value(self, action: PranaAction) -> PranaObservation:
+         field = (action.field or "").lower()
+         value = action.value
+
+         if not field or value is None:
              return PranaObservation(
+                 query_result="ERROR: field and value are required for record_value.",
                  active_task=self._active_task,
+                 recorded_fields=self._patient_record.copy(),
                  done=False,
                  reward=0.0,
              )

+         self._patient_record[field] = self._make_entry(value, EPISODE_DATE)
+         logger.info(f"{tag} record_value field={field} value={value}")
+
+         required_fresh = sum(
+             1 for f in KARS_REQUIRED_FIELDS
+             if f in self._patient_record and self._patient_record[f].get("value") is not None
+         )
          return PranaObservation(
+             query_result=(
+                 f"RECORDED: {field} = {value} (as of {EPISODE_DATE}). "
+                 f"Record has {required_fresh}/{len(KARS_REQUIRED_FIELDS)} required fields."
+             ),
              active_task=self._active_task,
+             recorded_fields=self._patient_record.copy(),
              done=False,
              reward=0.0,
          )

+     def _handle_file_report(self, action: PranaAction) -> PranaObservation:
+         self._file_attempts += 1
+         passed, issues = kars_validate(self._patient_record)
+
+         logger.info(
+             f"{tag} file_report attempt={self._file_attempts} "
+             f"passed={passed} issues={issues}"
+         )
+
+         if passed:
+             reward = 15.0 if self._file_attempts == 1 else 10.0
+             logger.info(f"{tag} KARS PASSED reward={reward}")
+             return PranaObservation(
+                 query_result="KARS PASSED. SRTR report accepted.",
+                 active_task=self._active_task,
+                 kars_result="PASSED",
+                 missing_fields=[],
+                 recorded_fields=self._patient_record.copy(),
+                 done=True,
+                 reward=reward,
+             )
+
+         if self._file_attempts >= MAX_FILE_ATTEMPTS:
+             logger.warning(f"{tag} KARS FAILED unrecoverable after {self._file_attempts} attempts")
+             return PranaObservation(
+                 query_result=f"KARS FAILED (unrecoverable). Issues: {issues}",
+                 active_task=self._active_task,
+                 kars_result="FAILED",
+                 missing_fields=issues,
+                 recorded_fields=self._patient_record.copy(),
+                 done=True,
+                 reward=-10.0,
+             )
+
+         logger.info(f"{tag} KARS FAILED recoverable issues={issues}")
+         return PranaObservation(
+             query_result=f"KARS FAILED. Issues: {issues}. Fix and file again.",
+             active_task=self._active_task,
+             kars_result="FAILED",
+             missing_fields=issues,
+             recorded_fields=self._patient_record.copy(),
+             done=False,
+             reward=-5.0,
+         )
+
      @property
      def state(self) -> State:
          return self._state
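
To make the recency semantics concrete, here is a standalone sketch of the `kars_validate` logic with the diff's constants inlined (an illustrative copy, not importable from the package):

```python
from datetime import date, timedelta

# Constants copied from the environment diff above
EPISODE_DATE = date(2026, 3, 7)
RECENCY_DAYS = 90
KARS_REQUIRED_FIELDS = ["hba1c", "gfr", "creatinine", "blood_type"]
TIME_SENSITIVE_FIELDS = {"hba1c", "gfr", "creatinine"}

def kars_validate(record: dict) -> tuple[bool, list[str]]:
    # Same checks as the environment's validator: presence, then recency
    cutoff = EPISODE_DATE - timedelta(days=RECENCY_DAYS)
    issues = []
    for f in KARS_REQUIRED_FIELDS:
        entry = record.get(f)
        if entry is None or entry.get("value") is None:
            issues.append(f"{f} (missing)")
            continue
        if f in TIME_SENSITIVE_FIELDS:
            try:
                recorded_at = date.fromisoformat(entry.get("recorded_at", ""))
                if recorded_at < cutoff:
                    issues.append(f"{f} (stale: recorded {recorded_at}, must be after {cutoff})")
            except ValueError:
                issues.append(f"{f} (invalid date)")
    return (len(issues) == 0, issues)

# Fresh labs pass; blood_type is stable, so an old date is fine
fresh = {
    "hba1c": {"value": "6.1", "recorded_at": "2026-03-01"},
    "gfr": {"value": "58", "recorded_at": "2026-03-01"},
    "creatinine": {"value": "1.4", "recorded_at": "2026-03-01"},
    "blood_type": {"value": "O+", "recorded_at": "2025-01-01"},
}
passed, issues = kars_validate(fresh)
print(passed, issues)  # True []

# A lab recorded before the 90-day cutoff (2025-12-07) is flagged stale
stale = dict(fresh, creatinine={"value": "1.4", "recorded_at": "2025-11-01"})
print(kars_validate(stale)[1])  # ['creatinine (stale: recorded 2025-11-01, must be after 2025-12-07)']
```

Note that only the three time-sensitive fields get the recency check; a decades-old `blood_type` never fails validation.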
test_agent.py ADDED
@@ -0,0 +1,227 @@
+ """
+ PRANA-Env agent with full minimal RL loop.
+
+ The LLM agent must:
+     1. query_db     — retrieve required fields from PatientDB
+     2. record_value — write each field into the episode record
+     3. file_report  — submit to KARS validator → reward → done
+
+ Reward signal:
+     +15  KARS PASSED first attempt
+     +10  KARS PASSED after correction
+      -1  redundant query (field already recorded)
+      -5  filed with missing fields (recoverable)
+     -10  unrecoverable failure
+ """
+
+ import json
+ import openai
+ from dataclasses import dataclass, field
+ from typing import Optional
+ from prana_env.client import PranaEnv
+ from prana_env.models import PranaAction
+
+ # ── Tool definitions ──────────────────────────────────────────────────────────
+
+ TOOLS = [
+     {
+         "type": "function",
+         "function": {
+             "name": "query_db",
+             "description": "Retrieve a specific field from a clinical datastore for a patient.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "target": {"type": "string", "description": "PatientDB | ClinicalNotesDB | PharmacyDB | WaitlistDB"},
+                     "field": {"type": "string", "description": "Field name (e.g. hba1c, gfr, creatinine, blood_type)"},
+                     "patient_id": {"type": "string", "description": "Patient identifier (e.g. P001)"},
+                 },
+                 "required": ["target", "field", "patient_id"],
+             },
+         },
+     },
+     {
+         "type": "function",
+         "function": {
+             "name": "record_value",
+             "description": "Write a retrieved field value into the episode patient record.",
+             "parameters": {
+                 "type": "object",
+                 "properties": {
+                     "field": {"type": "string", "description": "Field name to record"},
+                     "value": {"type": "string", "description": "Value to record"},
+                     "source": {"type": "string", "description": "Datastore the value came from"},
+                 },
+                 "required": ["field", "value"],
+             },
+         },
+     },
+     {
+         "type": "function",
+         "function": {
+             "name": "file_report",
+             "description": (
+                 "Submit the compiled patient record to the KARS validator. "
+                 "Returns PASSED (done) or FAILED with missing fields. "
+                 "Call only after recording all required fields: hba1c, gfr, creatinine, blood_type."
+             ),
+             "parameters": {"type": "object", "properties": {}, "required": []},
+         },
+     },
+ ]
+
+ SYSTEM_PROMPT = """You are a kidney transplant administrative agent.
+
+ Your goal is to compile a complete patient record and file a KARS-compliant SRTR report.
+
+ Required fields: hba1c, gfr, creatinine, blood_type (all from PatientDB).
+
+ KARS Recency Policy:
+ - Time-sensitive fields (hba1c, gfr, creatinine) must be recorded within 90 days of the filing date.
+ - Stable fields (blood_type) have no recency requirement.
+ - The episode starts with a pre-existing T1 record (60–150 days old); its time-sensitive values may be STALE. The reset observation reports the T1 record date.
+ - Re-query and re-record any time-sensitive field that is stale before filing.
+ - Do NOT re-query blood_type — it is stable and already valid.
+
+ Workflow:
+ 1. Check recorded_fields in the observation — identify stale time-sensitive fields.
+ 2. Use query_db to retrieve fresh values for stale fields only.
+ 3. Use record_value to write each fresh value into the patient record.
+ 4. Use file_report to submit. If it fails due to stale or missing fields, fix and retry.
+
+ Do not guess values. Always query before recording."""
+
+ # ── Trajectory dataclass ──────────────────────────────────────────────────────
+
+ @dataclass
+ class Step:
+     action: dict
+     observation: str
+     reward: float
+     done: bool
+
+ @dataclass
+ class Trajectory:
+     episode_id: str
+     steps: list[Step] = field(default_factory=list)
+
+     @property
+     def total_reward(self) -> float:
+         return sum(s.reward for s in self.steps)
+
+     def __repr__(self):
+         terminal = next((s for s in reversed(self.steps) if s.done), None)
+         kars = terminal.observation if terminal else "incomplete"
+         return (
+             f"Trajectory(episode={self.episode_id}, "
+             f"steps={len(self.steps)}, "
+             f"total_reward={self.total_reward}, "
+             f"outcome={kars!r})"
+         )
+
+ # ── RL primitives ─────────────────────────────────────────────────────────────
+
+ def reset(env: PranaEnv, patient_id: str) -> str:
+     result = env.reset(patient_id=patient_id)
+     return result.observation.query_result
+
+
+ def step(env: PranaEnv, action_type: str, **kwargs) -> tuple[str, float, bool, list]:
+     result = env.step(PranaAction(action_type=action_type, **kwargs))
+     obs = result.observation
+     return (
+         obs.query_result,
+         obs.reward or 0.0,
+         obs.done or False,
+         obs.missing_fields or [],
+     )
+
+
+ def rollout(env: PranaEnv, task: str, patient_id: str, episode_id: str, max_turns: int = 20) -> Trajectory:
+     """Run one full episode. The LLM drives the action loop until done=True."""
+     llm = openai.OpenAI()
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT},
+         {"role": "user", "content": task},
+     ]
+     trajectory = Trajectory(episode_id=episode_id)
+
+     print(f"\n── Episode {episode_id} ──────────────────────────────")
+     print(f"Task: {task}")
+
+     initial_obs = reset(env, patient_id)
+     print(f"[reset] {initial_obs}")
+
+     for turn in range(max_turns):
+         response = llm.chat.completions.create(
+             model="gpt-4o",
+             tools=TOOLS,
+             messages=messages,
+         )
+         msg = response.choices[0].message
+         messages.append(msg)
+
+         # No tool calls → the LLM finished without filing (shouldn't happen with a good prompt)
+         if msg.tool_calls is None:
+             print(f"[turn {turn+1}] Agent: {msg.content}")
+             trajectory.steps.append(Step(
+                 action={"type": "end_turn"},
+                 observation=msg.content or "",
+                 reward=0.0,
+                 done=True,
+             ))
+             break
+
+         for tool_call in msg.tool_calls:
+             action_type = tool_call.function.name
+             inp = json.loads(tool_call.function.arguments)
+             print(f"[turn {turn+1}] {action_type}({json.dumps(inp)})")
+
+             obs_str, reward, done, missing = step(env, action_type, **inp)
+             print(f"[turn {turn+1}] obs={obs_str!r} reward={reward} done={done}")
+
+             trajectory.steps.append(Step(
+                 action={"type": action_type, **inp},
+                 observation=obs_str,
+                 reward=reward,
+                 done=done,
+             ))
+
+             messages.append({
+                 "role": "tool",
+                 "tool_call_id": tool_call.id,
+                 "content": obs_str,
+             })
+
+             if done:
+                 return trajectory
+
+     return trajectory
+
+
+ def run_episodes(task: str, patient_id: str, n: int = 1) -> list[Trajectory]:
+     """Run N independent episodes. Set n=8 for a GRPO rollout batch."""
+     trajectories = []
+     with PranaEnv(base_url="http://localhost:8000") as env:
+         for i in range(n):
+             traj = rollout(env, task, patient_id, episode_id=f"ep_{i+1}")
+             trajectories.append(traj)
+
+     print(f"\n── Summary ({n} episode(s)) ──────────────────────────")
+     for t in trajectories:
+         print(f"  {t}")
+     return trajectories
+
+
+ # ── Entry point ───────────────────────────────────────────────────────────────
+
+ if __name__ == "__main__":
+     run_episodes(
+         task=(
+             "File a KARS-compliant SRTR report for patient P001. "
+             "An older T1 record already exists. "
+             "Check which fields are stale, re-query only what's needed, and file."
+         ),
+         patient_id="P001",
+         n=1,  # set n=8 for a GRPO rollout batch
+     )
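
The `n=8` comment refers to GRPO-style rollout batches. One common post-processing step (sketched here for illustration; `group_relative_advantages` is not part of this repo) is to turn a group of episode returns, e.g. `[t.total_reward for t in run_episodes(..., n=8)]`, into group-relative advantages:

```python
def group_relative_advantages(returns: list[float]) -> list[float]:
    # Normalize each return against the group's mean and std (GRPO-style)
    mean = sum(returns) / len(returns)
    var = sum((r - mean) ** 2 for r in returns) / len(returns)
    std = var ** 0.5 or 1.0  # avoid division by zero when all returns are equal
    return [(r - mean) / std for r in returns]

# Episodes that passed KARS on the first attempt (+15) stand out positively;
# an unrecoverable failure (-10) gets a strongly negative advantage
advantages = group_relative_advantages([15.0, 10.0, -10.0, 15.0])
print([round(a, 2) for a in advantages])  # [0.73, 0.24, -1.7, 0.73]
```

Because the advantages are mean-centered within the group, a batch where every episode fails identically yields zero learning signal, which is exactly the motivation for the stochasticity the environment injects.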
test_client.py CHANGED
@@ -1,13 +1,16 @@
 
+ import asyncio
  from prana_env.client import PranaEnv
  from prana_env.models import PranaAction

- with PranaEnv(base_url="http://localhost:8000") as client:
-     client.reset()
-     result = client.step(PranaAction(action_type="query_db",
-                                      target="PatientDB",
-                                      field="hba1c",
-                                      patient_id="P001",
-                                      ))
-     print(result.observation.query_result)
+ async def main():
+     async with PranaEnv(base_url="http://localhost:8000") as client:
+         await client.reset()
+         result = await client.step(PranaAction(
+             action_type="query_db",
+             target="PatientDB",
+             field="hba1c",
+             patient_id="P001",
+         ))
+         print(result.observation.query_result)
+
+ asyncio.run(main())
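
For reference, the 25%-in-14-days rule that `_format_lab_history` applies can be exercised in isolation. This is a standalone re-implementation of just the flagging pass (constants copied from the environment diff; the helper name is illustrative):

```python
from datetime import date

ANOMALY_WINDOW_DAYS = 14   # same window as the environment
ANOMALY_THRESHOLD = 0.25   # flag if delta >= 25% within the window

def flag_anomalies(history: list[tuple[date, float]]) -> list[tuple[date, date, float]]:
    # history must be sorted by date; only consecutive readings are compared
    flags = []
    for (d1, v1), (d2, v2) in zip(history, history[1:]):
        days_apart = (d2 - d1).days
        if days_apart <= ANOMALY_WINDOW_DAYS and v1 > 0:
            change = abs(v2 - v1) / v1
            if change >= ANOMALY_THRESHOLD:
                flags.append((d1, d2, round(change * 100, 1)))
    return flags

# An injected reading 40% off the previous value, 10 days apart, is flagged;
# the large jump from the months-old T1 entry is ignored (outside the window)
history = [(date(2025, 11, 7), 6.2), (date(2026, 2, 25), 6.0), (date(2026, 3, 7), 8.4)]
print(flag_anomalies(history))  # [(datetime.date(2026, 2, 25), datetime.date(2026, 3, 7), 40.0)]
```

This mirrors why the environment only warns about closely spaced readings: slow drift over months is clinically plausible, while a 40% swing inside two weeks warrants a confirmatory test before filing.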