adtserapio committed · Commit b75748e · verified · 1 Parent(s): 81c9ec1

Upload notebooks/ehrgym_grpo_training.ipynb with huggingface_hub

Files changed (1)
  1. notebooks/ehrgym_grpo_training.ipynb +21 -24
notebooks/ehrgym_grpo_training.ipynb CHANGED
@@ -7,10 +7,10 @@
    "source": [
     "# EHRGym GRPO Training with TRL + OpenEnv\n",
     "\n",
-    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/openenv-community/EHRGym/blob/main/notebooks/ehrgym_grpo_training.ipynb)\n",
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adtserapio/EHRGym/blob/main/notebooks/ehrgym_grpo_training.ipynb)\n",
     "[![HF Space](https://img.shields.io/badge/%F0%9F%8F%A5-EHRGym%20Space-blue)](https://huggingface.co/spaces/openenv-community/EHRGym)\n",
     "\n",
-    "Train a language model to operate an Epic-style Electronic Health Records (EHR) system using **GRPO** (Group Relative Policy Optimization) via [TRL](https://github.com/huggingface/trl) and the [OpenEnv](https://github.com/meta-pytorch/OpenEnv) framework.\n",
+    "Train a language model to operate an Epic-style Electronic Health Records (EHR) system using **GRPO** (Group Relative Policy Optimization) via [TRL](https://github.com/huggingface/trl) and the [OpenEnv](https://huggingface.co/docs/trl/openenv) framework.\n",
     "\n",
     "The agent learns to:\n",
     "- Navigate an EHR interface (patient charts, labs, notes, orders)\n",
@@ -18,21 +18,18 @@
     "- Write SOAP-style clinical notes\n",
     "- Sign encounters to complete clinical workflows\n",
     "\n",
-    "**Architecture:**\n",
-    "```\n",
-    "┌────────────────────────────────────────┐\n",
-    "│ GRPOTrainer (GPU)                      │\n",
-    "│ ┌────────┐  ┌──────────┐  ┌────────┐   │\n",
-    "│ │ Model  │→ │ rollout  │→ │EHRGym  │   │\n",
-    "│ │(Qwen3) │← │ func     │← │Env     │   │\n",
-    "│ └────────┘  └──────────┘  └───┬────┘   │\n",
-    "└───────────────────────────────┼────────┘\n",
-    "                                │ HTTP\n",
-    "┌───────────────────────────────┼────────┐\n",
-    "│ EHRGym Server (Docker/Space)  ▼        │\n",
-    "│ FastAPI → Playwright → Next.js EHR     │\n",
-    "└────────────────────────────────────────┘\n",
-    "```"
+    "**Architecture:** `GRPOTrainer` ➜ `rollout_func` ➜ `EHRGymEnv` ⟷ EHRGym Server (FastAPI + Playwright + Next.js EHR)\n",
+    "\n",
+    "<table><tr><td>\n",
+    "\n",
+    "| Component | Role |\n",
+    "|-----------|------|\n",
+    "| **GRPOTrainer** | Generates actions via vLLM, computes policy gradients |\n",
+    "| **rollout_func** | Orchestrates multi-turn episodes, builds `env_mask` |\n",
+    "| **EHRGymEnv** | HTTP client — sends browser actions to the server |\n",
+    "| **EHRGym Server** | FastAPI + Playwright driving a Next.js EHR app |\n",
+    "\n",
+    "</td></tr></table>"
    ]
   },
   {
@@ -42,7 +39,7 @@
    "source": [
     "## Install dependencies\n",
     "\n",
-    "We install **TRL** with vLLM support, and **EHRGym** directly from the Hugging Face Space."
+    "We install **TRL** with vLLM support, and **EHRGym** directly from GitHub."
    ]
   },
   {
@@ -52,7 +49,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install -Uq \"trl[vllm]\" git+https://huggingface.co/spaces/openenv-community/EHRGym trackio"
+    "!pip install -Uq \"trl[vllm]\" git+https://github.com/adtserapio/EHRGym.git trackio"
    ]
   },
   {
@@ -156,8 +153,8 @@
    "source": [
     "## Init model and tokenizer\n",
     "\n",
-    "We use [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B), a lightweight model suitable for agent training. \n",
-    "For better clinical task performance, scale up to larger models (e.g., Qwen3-8B or Qwen3-32B)."
+    "We use [Qwen/Qwen3.5-2B](https://huggingface.co/Qwen/Qwen3.5-2B), a lightweight model suitable for agent training. \n",
+    "For better clinical task performance, scale up to larger models (e.g., Qwen3.5-7B or Qwen3.5-32B)."
    ]
   },
   {
@@ -169,7 +166,7 @@
    "source": [
     "from transformers import AutoTokenizer\n",
     "\n",
-    "model_name = \"Qwen/Qwen3-1.7B\"\n",
+    "model_name = \"Qwen/Qwen3.5-2B\"\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
     "tokenizer.pad_token = tokenizer.eos_token"
    ]
@@ -585,7 +582,7 @@
    "source": [
     "from trl import GRPOConfig\n",
     "\n",
-    "output_dir = \"ehrgym-grpo-qwen3-1.7b\"\n",
+    "output_dir = \"ehrgym-grpo-qwen3.5-2b\"\n",
     "\n",
     "grpo_config = GRPOConfig(\n",
     "    # Training schedule\n",
@@ -851,7 +848,7 @@
     "\n",
     "## What's next?\n",
     "\n",
-    "- **Scale up the model**: Try `Qwen/Qwen3-8B` or larger for better clinical reasoning\n",
+    "- **Scale up the model**: Try `Qwen/Qwen3.5-7B` or larger for better clinical reasoning\n",
     "- **More training steps**: Increase `dataset_size` and `num_train_epochs`\n",
     "- **Multi-GPU**: Use `vllm_mode=\"server\"` with `trl vllm-serve` for distributed training\n",
     "- **Local Docker**: Run EHRGym locally for faster episode throughput:\n",
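The component table this commit adds describes a multi-turn rollout loop: the trainer samples actions, `rollout_func` steps the environment over HTTP, and per-episode rewards flow back. A minimal, self-contained sketch of that loop follows; `ToyEHREnv`, `StepResult`, and `rollout` are hypothetical stand-ins for illustration, not EHRGym's actual client API, and the sketch omits the `env_mask` bookkeeping the real `rollout_func` performs:

```python
from dataclasses import dataclass

@dataclass
class StepResult:
    observation: str
    reward: float
    done: bool

class ToyEHREnv:
    """Toy stand-in for the EHR environment (hypothetical, not EHRGym's API):
    the episode ends once the agent issues a 'sign' action, or after 5 steps."""
    def __init__(self):
        self.steps = 0

    def reset(self) -> str:
        self.steps = 0
        return "patient chart opened"

    def step(self, action: str) -> StepResult:
        self.steps += 1
        signed = action == "sign"
        return StepResult(
            observation=f"screen after {action}",
            reward=1.0 if signed else 0.0,  # reward only for completing the workflow
            done=signed or self.steps >= 5,
        )

def rollout(env, policy):
    """One multi-turn episode: query the policy, step the env, repeat until done."""
    obs = env.reset()
    turns, total_reward, done = [], 0.0, False
    while not done:
        action = policy(obs)
        result = env.step(action)
        turns.append((action, result.observation))
        total_reward += result.reward
        obs, done = result.observation, result.done
    return turns, total_reward

# Scripted "policy" mimicking the workflow from the notebook intro:
# review the labs, write a note, then sign the encounter.
script = iter(["open_labs", "write_note", "sign"])
turns, reward = rollout(ToyEHREnv(), lambda obs: next(script))
print(len(turns), reward)  # → 3 1.0
```

In the real setup, the policy call is a vLLM generation and `step` is an HTTP POST to the FastAPI server driving Playwright; the loop shape stays the same.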