---
license: apache-2.0
language:
- en
base_model:
- Qwen/Qwen2.5-7B-Instruct
tags:
- medical
- diagnosis
- RL
---
# DiagAgent-7B: RL-Optimized Diagnostic Agent
<div align="center">
<img src="https://raw.githubusercontent.com/MAGIC-AI4Med/DiagGym/main/assets/logo.png" width="150"/>
<div align="center"></div>
</div>
DiagAgent‑7B is a reinforcement learning‑optimized large language model for interactive, multi‑turn diagnostic reasoning. Unlike one‑shot medical LLMs, it can:
- recommend the most informative examinations,
- update the working diagnosis as new evidence arrives,
- decide when to stop and commit to a final diagnosis.
DiagAgent‑7B is trained end‑to‑end inside the `DiagGym` virtual clinical environment with multi‑turn RL (GRPO), enabling safe, closed‑loop learning without real‑world risk.
Details can be found in our paper, [Evolving Diagnostic Agents in a Virtual Clinical Environment](https://arxiv.org/abs/2510.24654).
## Quickstart
### Run locally with `transformers`
```py
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "Henrychur/DiagAgent-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
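model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",  # requires flash-attn; use "sdpa" if it is not installed
)

def chat(messages, max_new_tokens=1024, temperature=0.0):
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    gen_kwargs = dict(max_new_tokens=max_new_tokens, eos_token_id=tokenizer.eos_token_id)
    if temperature > 0:
        # Sample only when a positive temperature is requested.
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy decoding; temperature is only used when sampling, so it is not passed here.
        gen_kwargs.update(do_sample=False)
    outputs = model.generate(**inputs, **gen_kwargs)
    # Keep only the newly generated tokens (drop the prompt).
    gen_ids = outputs[0][inputs.input_ids.shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()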

SYSTEM_PROMPT = (
    "You are a medical AI assistant. Analyze patient information, suggest relevant tests, "
    "and provide a final diagnosis when sufficient information is available.\n\n"
    "RESPONSE FORMAT:\n"
    "If more information is needed:\n"
    "```\n"
    "Current diagnosis: <your current best diagnosis>\n"
    "Based on the patient's initial presentation, the following investigation(s) should be performed: <one additional test>\n"
    "Reason: <reason for the test>\n"
    "```\n"
    "If sufficient information exists for diagnosis:\n"
    "```\n"
    "The available information is sufficient to make a diagnosis.\n"
    "Diagnosis: <final diagnosis>\n"
    "Reason: <brief justification>\n"
    "```"
)

initial_inquiry = (
    "- Patient Information: ___ y/o F\n"
    "- Chief Complaint: Early satiety, weight loss, abdominal pain\n"
    "- HPI: 1-month weight loss (10 lbs), early satiety, fatigue; prior emesis; reduced intake; denies fever/chills.\n"
    "- PMH: Asthma, hyperlipidemia, HTN, osteoarthritis, polymyalgia rheumatica, CAD (NSTEMI), osteoporosis, H. pylori, s/p TAH/USO.\n"
    "- Allergy: Lisinopril."
)

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": initial_inquiry},
]
print(chat(messages))
```
Tips:
- `device_map="auto"` (requires `accelerate`) spreads the weights across the available GPUs and CPU; with limited VRAM, consider 4‑bit loading via `bitsandbytes`.
- Keep `temperature` low (0–0.3) for more deterministic decision‑making.
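Because `SYSTEM_PROMPT` fixes two reply formats (request another test, or commit to a diagnosis), the multi-turn loop can be driven programmatically by parsing each reply. The following is a minimal sketch, not the DiagGym implementation: `Reply`, `parse_reply`, `run_episode`, and the `get_exam_result` callback are hypothetical names, and sending each exam result back as a plain `user` message is an assumption about the conversation format (check the GitHub repository for the format used in training).

```python
import re
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

@dataclass
class Reply:
    final: bool               # True when the agent commits to a final diagnosis
    diagnosis: Optional[str]  # working diagnosis, or the final one
    exam: Optional[str]       # recommended examination, if any
    reason: Optional[str]

def _grab(pattern: str, text: str) -> Optional[str]:
    """Return the first capture group of `pattern` in `text`, stripped, or None."""
    m = re.search(pattern, text, re.MULTILINE)
    return m.group(1).strip() if m else None

def parse_reply(text: str) -> Reply:
    """Parse one reply written in either format requested by SYSTEM_PROMPT."""
    if "sufficient to make a diagnosis" in text:
        return Reply(True, _grab(r"^Diagnosis:\s*(.+)", text), None, _grab(r"^Reason:\s*(.+)", text))
    return Reply(
        False,
        _grab(r"^Current diagnosis:\s*(.+)", text),
        _grab(r"should be performed:\s*(.+)", text),
        _grab(r"^Reason:\s*(.+)", text),
    )

def run_episode(
    chat: Callable[[List[Dict[str, str]]], str],
    messages: List[Dict[str, str]],
    get_exam_result: Callable[[str], str],
    max_turns: int = 12,
) -> Reply:
    """Alternate agent replies and exam results until a final diagnosis or max_turns."""
    reply = Reply(False, None, None, None)
    for _ in range(max_turns):
        text = chat(messages)
        reply = parse_reply(text)
        if reply.final or reply.exam is None:
            return reply
        # Build a new list instead of mutating the caller's; the exam result is
        # sent as a plain user turn (an assumption about the conversation format).
        messages = messages + [
            {"role": "assistant", "content": text},
            {"role": "user", "content": get_exam_result(reply.exam)},
        ]
    return reply

# Usage with the `chat` helper and `messages` from the example above;
# `get_exam_result` is a hypothetical hook into your EHR data or a simulator:
# final = run_episode(chat, messages, get_exam_result=lambda exam: f"{exam}: <result>")
# print(final.diagnosis)
```

The `max_turns=12` cap is a caller-side safeguard chosen to mirror the turn limit described under Training Details; the model itself decides when to stop.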
### Serve with vLLM and use the OpenAI‑compatible Chat API
```bash
vllm serve Henrychur/DiagAgent-7B --served-model-name DiagAgent-7B
```
```py
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="NONE")
messages = [
    {"role": "system", "content": "You are a medical AI assistant... (same format as above)"},
    {"role": "user", "content": "Initial inquiry text..."},
]
resp = client.chat.completions.create(
    model="DiagAgent-7B",
    messages=messages,
    temperature=0.0,
    max_tokens=1024,
)
print(resp.choices[0].message.content)
```
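To reuse code written against the `chat(messages)` helper from the `transformers` example with the vLLM server, the client call can be wrapped in a function of the same shape. This is a minimal sketch; `make_openai_chat` is a hypothetical helper, and it relies only on the `client.chat.completions.create(...)` call already shown above. The `model` argument must match the `--served-model-name` passed to `vllm serve`.

```python
from typing import Any, Callable, Dict, List

def make_openai_chat(client: Any, model: str = "DiagAgent-7B") -> Callable[..., str]:
    """Wrap an OpenAI-compatible client in a chat(messages) function."""
    def chat(messages: List[Dict[str, str]], max_new_tokens: int = 1024, temperature: float = 0.0) -> str:
        resp = client.chat.completions.create(
            model=model,  # must match --served-model-name
            messages=messages,
            temperature=temperature,
            max_tokens=max_new_tokens,
        )
        # message.content may be None; normalize to a stripped string.
        return (resp.choices[0].message.content or "").strip()
    return chat

# Usage against the server started above (requires `pip install openai`):
# from openai import OpenAI
# chat = make_openai_chat(OpenAI(base_url="http://localhost:8000/v1", api_key="NONE"))
# print(chat(messages))
```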
## Evaluation Results
The following tables are taken directly from the project evaluation. For evaluation details and scripts, see the paper and the GitHub repository.
**Single‑Turn Evaluation**
| Model | Size | Year | Hit Ratio | Diagnosis Accuracy |
|-----------------|-------|--------|-----------|--------------------|
| **Basic LLM** | | | | |
| GPT-4o | - | 2024.8 | 18.46 | 70.17 |
| Claude-4-sonnet | - | 2025.5 | 30.62 | 73.25 |
| Qwen2.5 | 72B | 2024.9 | 17.61 | 72.72 |
| Llama3.3 | 70B | 2024.12| 18.29 | 64.65 |
| DeepSeek-v3 | 671B | 2025.3 | 19.14 | 69.64 |
| Qwen3 | 235B | 2025.7 | 19.82 | 70.81 |
| GPT-OSS | 120B | 2025.8 | 16.71 | 66.67 |
| OpenBioLLM | 70B | 2024.4 | 22.94 | 64.01 |
| Baichuan-M1 | 14B | 2025.2 | 18.50 | 77.39 |
| MedGemma | 27B | 2025.7 | 27.07 | 68.90 |
| **Agentic System** | | | | |
| MedAgents | - | 2024.1 | 18.31 | 68.26 |
| MDAgents | - | 2024.10| 19.24 | 71.66 |
| **Our Method** | | | | |
| DiagAgent | 7B | - | **71.12** | 85.03 |
| DiagAgent | 8B | - | 54.61 | 81.10 |
| DiagAgent | 14B | - | 67.54 | **86.73** |
**End‑to‑End Evaluation (Automatic Metrics)**
| Model | Size | Year | Avg. Turns | Precision | Recall | F1 | Accuracy |
|-----------------|-------|--------|------------|-----------|--------|--------|----------|
| **Basic LLM** | | | | | | | |
| GPT-4o | - | 2024.8 | 3.30 | 29.02 | 14.25 | 19.12 | 41.44 |
| Claude-4-sonnet | - | 2025.5 | 3.91 | 33.06 | 23.41 | 27.41 | 45.14 |
| Qwen2.5 | 72B | 2024.9 | 2.47 | 32.67 | 11.09 | 16.56 | 32.66 |
| Llama3.3 | 70B | 2024.12| 4.25 | 26.53 | 20.52 | 23.14 | 36.79 |
| DeepSeek-v3 | 671B | 2025.3 | 2.49 | 32.60 | 12.20 | 17.75 | 46.51 |
| Qwen3 | 235B | 2025.7 | 3.34 | 27.71 | 17.40 | 21.38 | 44.19 |
| GPT-OSS | 120B | 2025.8 | 4.08 | 25.94 | 16.16 | 19.92 | 43.02 |
| OpenBioLLM | 70B | 2024.4 | 2.59 | 31.94 | 13.75 | 19.22 | 32.88 |
| Baichuan-M1 | 14B | 2025.2 | 2.30 | 28.82 | 10.98 | 15.90 | 32.35 |
| MedGemma | 27B | 2025.7 | 4.10 | 32.50 | 20.04 | 24.80 | 42.39 |
| **Agentic System** | | | | | | | |
| MedAgents | - | 2024.1 | 2.31 | 29.86 | 11.21 | 16.30 | 44.40 |
| MDAgents | - | 2024.10| 2.40 | 30.12 | 11.25 | 16.38 | 43.55 |
| **Our Method** | | | | | | | |
| DiagAgent | 7B | - | 5.47 | **46.08** | 47.66 | 46.86 | 60.78 |
| DiagAgent | 8B | - | 5.71 | 41.58 | 44.56 | 43.02 | 53.85 |
| DiagAgent | 14B | - | 6.77 | 43.87 | **52.74** | **47.89** | **61.63** |
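Assuming the F1 column is the harmonic mean of the reported precision and recall (the usual definition; see the paper for how recommended exams are matched), it can be cross-checked from the other two columns. Since precision and recall are themselves rounded to two decimals, agreement is only expected to within about 0.01.

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (both in percent)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# (model, precision, recall, reported F1) copied from the table above.
rows = [
    ("DiagAgent-7B", 46.08, 47.66, 46.86),
    ("DiagAgent-8B", 41.58, 44.56, 43.02),
    ("DiagAgent-14B", 43.87, 52.74, 47.89),
]
for name, p, r, reported in rows:
    print(f"{name}: recomputed {f1(p, r):.2f}, reported {reported:.2f}")
```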
**End‑to‑End Evaluation (Rubric-based Metrics)**
| Model | Size | Year | Rubric Score (%) |
|------------------|------|--------|------------------|
| **Basic LLM** | | | |
| GPT-4o | - | 2024.8 | 18.02 |
| Claude-4-sonnet | - | 2025.5 | 25.84 |
| Qwen2.5 | 72B | 2024.9 | 15.63 |
| Llama3.3 | 70B | 2024.12| 16.69 |
| DeepSeek-v3 | 671B | 2025.3 | 19.07 |
| Qwen3 | 235B | 2025.7 | 24.49 |
| GPT-OSS | 120B | 2025.8 | 22.26 |
| OpenBioLLM | 70B | 2024.4 | 14.11 |
| Baichuan-M1 | 14B | 2025.2 | 17.43 |
| MedGemma | 27B | 2025.7 | 20.72 |
| **Agentic System** | | | |
| MedAgents | - | 2024.1 | 19.49 |
| MDAgents | - | 2024.10| 21.64 |
| **Our Method** | | | |
| DiagAgent | 14B | - | 32.86 |
## Training Details
DiagAgent‑7B is optimized with multi‑turn RL (GRPO) inside `DiagGym`.
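GRPO (Group Relative Policy Optimization) drops the learned critic: it samples a group of rollouts for the same case and uses the group's own reward statistics as the baseline, so each rollout's advantage is its reward standardized within the group. The sketch below shows only that normalization step (the clipped policy-gradient update and KL term are omitted); `group_advantages`, the `eps` value, and the use of the population standard deviation are illustrative choices, not taken from the DiagGym code.

```python
from statistics import mean, pstdev
from typing import List

def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantages: standardize each rollout's reward within its group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; some implementations use the sample std
    # eps keeps the division defined when all rewards in the group are equal.
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four rollouts of the same patient case with different episode rewards.
print(group_advantages([1.0, 0.0, 0.5, 0.5]))  # roughly [1.41, -1.41, 0.0, 0.0]
```

When every rollout in a group earns the same reward, all advantages are zero, so that group contributes no learning signal.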
- Trajectory construction:
  - Initial inquiry (structured patient history without final diagnosis)
  - Iterative steps: preliminary diagnosis → recommended exam + rationale → exam result
  - Final diagnosis focused on a single primary condition
- Reward design:
  - Diagnosis Accuracy (exact match)
  - Examination Recommendation F1 (overlap with reference EHR exams)
  - Turn Penalty (discourages >12 interaction turns)
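The trajectory layout described above maps naturally onto a chat transcript. The sketch below shows one way to represent it; `Step`, `Trajectory`, and `to_messages` are hypothetical, the assistant turns reuse the response format from the Quickstart system prompt, and returning exam results as `user` turns is an assumption rather than the documented DiagGym format.

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Step:
    preliminary_diagnosis: str
    exam: str
    rationale: str
    exam_result: str

@dataclass
class Trajectory:
    initial_inquiry: str                 # structured history, no final diagnosis
    steps: List[Step] = field(default_factory=list)
    final_diagnosis: str = ""            # a single primary condition
    final_reason: str = ""

    def to_messages(self, system_prompt: str) -> List[Dict[str, str]]:
        """Render the trajectory as system/user/assistant chat messages."""
        msgs = [{"role": "system", "content": system_prompt},
                {"role": "user", "content": self.initial_inquiry}]
        for s in self.steps:
            msgs.append({"role": "assistant", "content": (
                f"Current diagnosis: {s.preliminary_diagnosis}\n"
                "Based on the patient's initial presentation, the following "
                f"investigation(s) should be performed: {s.exam}\n"
                f"Reason: {s.rationale}")})
            # Assumption: each exam result comes back as a user turn.
            msgs.append({"role": "user", "content": s.exam_result})
        msgs.append({"role": "assistant", "content": (
            "The available information is sufficient to make a diagnosis.\n"
            f"Diagnosis: {self.final_diagnosis}\n"
            f"Reason: {self.final_reason}")})
        return msgs
```

With this layout, the number of interaction turns in an episode is simply `len(trajectory.steps)`, which is the quantity a turn penalty would act on.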
For the implementation and scripts, see `DiagAgent/train/rl/` in the GitHub repository.
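The three reward terms from the list above can be combined into one scalar per episode. This is a minimal sketch under loudly stated assumptions: diagnosis "exact match" is applied after lower-casing and whitespace collapsing, exam F1 is computed over sets of exam names, and the turn penalty is a flat deduction once an episode exceeds 12 turns. The weights, penalty size, and normalization are hypothetical; the actual reward lives in `DiagAgent/train/rl/`.

```python
from typing import Iterable

def _norm(text: str) -> str:
    """Lower-case and collapse whitespace (an assumed normalization)."""
    return " ".join(text.lower().split())

def exam_f1(recommended: Iterable[str], reference: Iterable[str]) -> float:
    """Set-based F1 between recommended exams and the reference EHR exams."""
    rec = {_norm(e) for e in recommended}
    ref = {_norm(e) for e in reference}
    hits = len(rec & ref)
    if hits == 0:
        return 0.0  # also covers an empty recommendation or reference list
    precision, recall = hits / len(rec), hits / len(ref)
    return 2 * precision * recall / (precision + recall)

def episode_reward(
    pred_dx: str,
    gold_dx: str,
    recommended: Iterable[str],
    reference: Iterable[str],
    num_turns: int,
    w_dx: float = 1.0,          # hypothetical weight
    w_exam: float = 1.0,        # hypothetical weight
    max_turns: int = 12,
    turn_penalty: float = 0.5,  # hypothetical penalty size
) -> float:
    """Diagnosis accuracy + exam-recommendation F1 - turn penalty."""
    accuracy = 1.0 if _norm(pred_dx) == _norm(gold_dx) else 0.0
    penalty = turn_penalty if num_turns > max_turns else 0.0
    return w_dx * accuracy + w_exam * exam_f1(recommended, reference) - penalty

print(episode_reward("Gastric adenocarcinoma", "gastric adenocarcinoma",
                     ["CBC", "Upper endoscopy"], ["Upper endoscopy", "CT abdomen"],
                     num_turns=4))  # 1.5
```

Scoring exams by F1 rather than recall alone means the agent is not rewarded for ordering every available test, which pairs naturally with the turn penalty.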
## Citation
```
@misc{qiu2025evolvingdiagnosticagentsvirtual,
title={Evolving Diagnostic Agents in a Virtual Clinical Environment},
author={Pengcheng Qiu and Chaoyi Wu and Junwei Liu and Qiaoyu Zheng and Yusheng Liao and Haowen Wang and Yun Yue and Qianrui Fan and Shuai Zhen and Jian Wang and Jinjie Gu and Yanfeng Wang and Ya Zhang and Weidi Xie},
year={2025},
eprint={2510.24654},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2510.24654},
}
```
## Contact
- Email: henrychur@sjtu.edu.cn
- GitHub: https://github.com/MAGIC-AI4Med/DiagGym