Upload notebooks/ehrgym_grpo_training.ipynb with huggingface_hub
notebooks/ehrgym_grpo_training.ipynb (CHANGED)
@@ -7,10 +7,10 @@
       "source": [
       "# EHRGym GRPO Training with TRL + OpenEnv\n",
       "\n",
-      "[](https://colab.research.google.com/github/
       "[](https://huggingface.co/spaces/openenv-community/EHRGym)\n",
       "\n",
-      "Train a language model to operate an Epic-style Electronic Health Records (EHR) system using **GRPO** (Group Relative Policy Optimization) via [TRL](https://github.com/huggingface/trl) and the [OpenEnv](https://
       "\n",
       "The agent learns to:\n",
       "- Navigate an EHR interface (patient charts, labs, notes, orders)\n",

@@ -18,21 +18,18 @@
       "- Write SOAP-style clinical notes\n",
       "- Sign encounters to complete clinical workflows\n",
       "\n",
-      "**Architecture:**\n",
-      … (remaining lines of a deleted ASCII architecture diagram, truncated in this view) …
-      "│ FastAPI → Playwright → Next.js EHR │\n",
-      "└─────────────────────────────────────┘\n",
-      "```"
       ]
       },
       {

@@ -42,7 +39,7 @@
       "source": [
       "## Install dependencies\n",
       "\n",
-      "We install **TRL** with vLLM support, and **EHRGym** directly from
       ]
       },
       {

@@ -52,7 +49,7 @@
       "metadata": {},
       "outputs": [],
       "source": [
-      "!pip install -Uq \"trl[vllm]\" git+https://
       ]
       },
       {

@@ -156,8 +153,8 @@
       "source": [
       "## Init model and tokenizer\n",
       "\n",
-      "We use [Qwen/Qwen3
-      "For better clinical task performance, scale up to larger models (e.g., Qwen3-
       ]
       },
       {

@@ -169,7 +166,7 @@
       "source": [
       "from transformers import AutoTokenizer\n",
       "\n",
-      "model_name = \"Qwen/Qwen3
       "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
       "tokenizer.pad_token = tokenizer.eos_token"
       ]

@@ -585,7 +582,7 @@
       "source": [
       "from trl import GRPOConfig\n",
       "\n",
-      "output_dir = \"ehrgym-grpo-qwen3
       "\n",
       "grpo_config = GRPOConfig(\n",
       "    # Training schedule\n",

@@ -851,7 +848,7 @@
       "\n",
       "## What's next?\n",
       "\n",
-      "- **Scale up the model**: Try `Qwen/Qwen3-
       "- **More training steps**: Increase `dataset_size` and `num_train_epochs`\n",
       "- **Multi-GPU**: Use `vllm_mode=\"server\"` with `trl vllm-serve` for distributed training\n",
       "- **Local Docker**: Run EHRGym locally for faster episode throughput:\n",
@@ -7,10 +7,10 @@
       "source": [
       "# EHRGym GRPO Training with TRL + OpenEnv\n",
       "\n",
+      "[](https://colab.research.google.com/github/adtserapio/EHRGym/blob/main/notebooks/ehrgym_grpo_training.ipynb)\n",
       "[](https://huggingface.co/spaces/openenv-community/EHRGym)\n",
       "\n",
+      "Train a language model to operate an Epic-style Electronic Health Records (EHR) system using **GRPO** (Group Relative Policy Optimization) via [TRL](https://github.com/huggingface/trl) and the [OpenEnv](https://huggingface.co/docs/trl/openenv) framework.\n",
       "\n",
       "The agent learns to:\n",
       "- Navigate an EHR interface (patient charts, labs, notes, orders)\n",
@@ -18,21 +18,18 @@
       "- Write SOAP-style clinical notes\n",
       "- Sign encounters to complete clinical workflows\n",
       "\n",
+      "**Architecture:** `GRPOTrainer` → `rollout_func` → `EHRGymEnv` ⇄ EHRGym Server (FastAPI + Playwright + Next.js EHR)\n",
+      "\n",
+      "<table><tr><td>\n",
+      "\n",
+      "| Component | Role |\n",
+      "|-----------|------|\n",
+      "| **GRPOTrainer** | Generates actions via vLLM, computes policy gradients |\n",
+      "| **rollout_func** | Orchestrates multi-turn episodes, builds `env_mask` |\n",
+      "| **EHRGymEnv** | HTTP client that sends browser actions to the server |\n",
+      "| **EHRGym Server** | FastAPI + Playwright driving a Next.js EHR app |\n",
+      "\n",
+      "</td></tr></table>"
       ]
       },
       {
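The component table describes how `rollout_func` interleaves model actions with environment observations and builds an `env_mask` so the trainer knows which tokens the policy actually produced. A minimal self-contained sketch of that pattern — the toy environment, policy, and token representation below are illustrative stand-ins, not EHRGym's real API:

```python
class ToyEHREnv:
    """Illustrative stand-in for an EHR environment (not EHRGym's API).
    Emits observation tokens; the episode ends once the agent signs."""
    def reset(self):
        return ["<obs>", "labs: Na=130"]

    def step(self, action):
        done = "sign_encounter" in action
        return ["<obs>", "chart updated"], done


def toy_policy(obs):
    # Fixed stand-in policy: open the chart, then sign the encounter.
    return ["open_chart", "sign_encounter"]


def rollout_episode(policy, env, max_turns=4):
    """Interleave env observations and policy actions into one token
    stream, marking env-produced tokens (env_mask=1) so they can be
    excluded from the policy-gradient loss."""
    tokens, env_mask = [], []
    obs = env.reset()
    for _ in range(max_turns):
        tokens += obs
        env_mask += [1] * len(obs)       # environment tokens: masked out
        action = policy(obs)
        tokens += action
        env_mask += [0] * len(action)    # policy tokens: trained on
        obs, done = env.step(action)
        if done:
            break
    return tokens, env_mask
```

Only positions with `env_mask == 0` contribute to the loss, so the model is never penalized or rewarded for text the environment injected into the context.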
@@ -42,7 +39,7 @@
       "source": [
       "## Install dependencies\n",
       "\n",
+      "We install **TRL** with vLLM support, and **EHRGym** directly from GitHub."
       ]
       },
       {
@@ -52,7 +49,7 @@
       "metadata": {},
       "outputs": [],
       "source": [
+      "!pip install -Uq \"trl[vllm]\" git+https://github.com/adtserapio/EHRGym.git trackio"
       ]
       },
       {
@@ -156,8 +153,8 @@
       "source": [
       "## Init model and tokenizer\n",
       "\n",
+      "We use [Qwen/Qwen3.5-2B](https://huggingface.co/Qwen/Qwen3.5-2B), a lightweight model suitable for agent training.\n",
+      "For better clinical task performance, scale up to larger models (e.g., Qwen3.5-7B or Qwen3.5-32B)."
       ]
       },
       {
@@ -169,7 +166,7 @@
       "source": [
       "from transformers import AutoTokenizer\n",
       "\n",
+      "model_name = \"Qwen/Qwen3.5-2B\"\n",
       "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
       "tokenizer.pad_token = tokenizer.eos_token"
       ]
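The tokenizer cell in this hunk reuses EOS as the pad token because many chat-model tokenizers ship without one, and batched generation needs every sequence padded to a common length. What padding buys you, as a pure-Python toy sketch (the token ids are made up for illustration):

```python
def pad_batch(seqs, pad_id):
    """Right-pad variable-length id sequences to a common length and
    build the matching attention mask (1 = real token, 0 = padding)."""
    width = max(len(s) for s in seqs)
    ids = [s + [pad_id] * (width - len(s)) for s in seqs]
    mask = [[1] * len(s) + [0] * (width - len(s)) for s in seqs]
    return ids, mask
```

With `pad_token = eos_token`, padding positions carry the EOS id but are zeroed in the attention mask, so they do not influence the model's forward pass.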
@@ -585,7 +582,7 @@
       "source": [
       "from trl import GRPOConfig\n",
       "\n",
+      "output_dir = \"ehrgym-grpo-qwen3.5-2b\"\n",
       "\n",
       "grpo_config = GRPOConfig(\n",
       "    # Training schedule\n",
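The `GRPOConfig` cell configures the trainer, but the algorithmic core of GRPO itself is small: each prompt's group of sampled completions is scored relative to the group's own mean reward, with no learned value baseline. A toy sketch of that normalization, independent of TRL:

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize each completion's reward against its sampling group:
    A_i = (r_i - mean(r)) / (std(r) + eps)."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]
```

Completions that beat their group's mean get positive advantage and are reinforced; this group-relative baseline is what lets GRPO drop the critic network that PPO would need.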
@@ -851,7 +848,7 @@
       "\n",
       "## What's next?\n",
       "\n",
+      "- **Scale up the model**: Try `Qwen/Qwen3.5-7B` or larger for better clinical reasoning\n",
       "- **More training steps**: Increase `dataset_size` and `num_train_epochs`\n",
       "- **Multi-GPU**: Use `vllm_mode=\"server\"` with `trl vllm-serve` for distributed training\n",
       "- **Local Docker**: Run EHRGym locally for faster episode throughput:\n",