---
license: apache-2.0
language:
- en
base_model:
- Qwen/Qwen2.5-7B-Instruct
tags:
- medical
- diagnosis
- RL
---
# DiagAgent-7B: RL-Optimized Diagnostic Agent

<div align="center">
  <img src="https://raw.githubusercontent.com/MAGIC-AI4Med/DiagGym/main/assets/logo.png" width="150"/>
</div>


DiagAgent‑7B is a reinforcement learning‑optimized large language model for interactive, multi‑turn diagnostic reasoning. Unlike one‑shot medical LLMs, it can:
- recommend the most informative examinations,
- update the working diagnosis as new evidence arrives,
- decide when to stop and commit to a final diagnosis.

DiagAgent‑7B is trained end‑to‑end inside the `DiagGym` virtual clinical environment with multi‑turn RL (GRPO), enabling safe, closed‑loop learning without real‑world risk.

For details, see our paper: https://arxiv.org/abs/2510.24654

## Quickstart

### Run locally with `transformers`

```py
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Henrychur/DiagAgent-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    # Requires the `flash-attn` package and a supported GPU; remove this line to use the default attention.
    attn_implementation="flash_attention_2"
)

def chat(messages, max_new_tokens=1024, temperature=0.0):
    # Render the conversation with the model's chat template and open an assistant turn.
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Greedy decoding when temperature == 0; pass `temperature` only when actually sampling.
    sampling = {"do_sample": True, "temperature": temperature} if temperature > 0 else {"do_sample": False}
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        **sampling
    )
    # Decode only the newly generated tokens, not the prompt.
    gen_ids = outputs[0][inputs.input_ids.shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

SYSTEM_PROMPT = (
    "You are a medical AI assistant. Analyze patient information, suggest relevant tests, "
    "and provide a final diagnosis when sufficient information is available.\n\n"
    "RESPONSE FORMAT:\n"
    "If more information is needed:\n"
    "```\n"
    "Current diagnosis: <your current best diagnosis>\n"
    "Based on the patient's initial presentation, the following investigation(s) should be performed: <one additional test>\n"
    "Reason: <reason for the test>\n"
    "```\n"
    "If sufficient information exists for diagnosis:\n"
    "```\n"
    "The available information is sufficient to make a diagnosis.\n"
    "Diagnosis: <final diagnosis>\n"
    "Reason: <brief justification>\n"
    "```"
)

initial_inquiry = (
    "- Patient Information: ___ y/o F\n"
    "- Chief Complaint: Early satiety, weight loss, abdominal pain\n"
    "- HPI: 1-month weight loss (10 lbs), early satiety, fatigue; prior emesis; reduced intake; denies fever/chills.\n"
    "- PMH: Asthma, hyperlipidemia, HTN, osteoarthritis, polymyalgia rheumatica, CAD (NSTEMI), osteoporosis, H. pylori, s/p TAH/USO.\n"
    "- Allergy: Lisinopril."
)

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": initial_inquiry}
]

print(chat(messages))
```

Tips:
- Use `device_map="auto"` to place the weights across the available GPUs (offloading to CPU if they do not fit); on constrained VRAM, consider 4-bit loading with `bitsandbytes`.
- Keep `temperature` low (0–0.3) for more deterministic decision‑making.
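Because `SYSTEM_PROMPT` asks for one of two fixed reply templates, a reply can be turned into structured fields with a few regular expressions. The helper below is a minimal sketch, not part of the model or the DiagGym repository: `parse_reply` and `AgentStep` are hypothetical names, and it assumes the model follows the template line by line, returning `None` for any field it cannot find.

```python
import re
from dataclasses import dataclass
from typing import Optional

# Hypothetical container for one parsed DiagAgent reply.
@dataclass
class AgentStep:
    is_final: bool             # True when the model commits to a final diagnosis
    diagnosis: Optional[str]   # "Current diagnosis" or the final "Diagnosis"
    test: Optional[str]        # recommended investigation (None for a final reply)
    reason: Optional[str]

# Sentence the system prompt asks the model to emit when it stops.
_STOP = "The available information is sufficient to make a diagnosis."

def _field(pattern: str, text: str) -> Optional[str]:
    # Return the first capture group of a line-anchored match, stripped.
    m = re.search(pattern, text, flags=re.MULTILINE)
    return m.group(1).strip() if m else None

def parse_reply(text: str) -> AgentStep:
    # Drop ``` fences the model may copy from the prompt template.
    body = text.replace("```", "")
    reason = _field(r"^[ \t]*Reason:[ \t]*(.+)$", body)
    if _STOP in body:
        return AgentStep(True, _field(r"^[ \t]*Diagnosis:[ \t]*(.+)$", body), None, reason)
    return AgentStep(
        False,
        _field(r"^[ \t]*Current diagnosis:[ \t]*(.+)$", body),
        _field(r"should be performed:[ \t]*(.+)$", body),
        reason,
    )
```

For example, `parse_reply(chat(messages))` yields `is_final=False` plus the recommended test while the model is still gathering evidence, and `is_final=True` with the final diagnosis once it emits the stop sentence.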

### Serve with vLLM and use the OpenAI‑compatible Chat API

```bash
vllm serve Henrychur/DiagAgent-7B --served-model-name DiagAgent-7B
```

```py
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="NONE")

messages = [
    {"role": "system", "content": "You are a medical AI assistant... (same format as above)"},
    {"role": "user", "content": "Initial inquiry text..."}
]

resp = client.chat.completions.create(
    model="DiagAgent-7B",
    messages=messages,
    temperature=0.0,
    max_tokens=1024
)
print(resp.choices[0].message.content)
```
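Both snippets above make a single call. Running DiagAgent interactively means feeding the result of each recommended examination back to the model until it emits the stop sentence. The loop below is a hypothetical sketch: `run_episode`, `chat_fn`, and `get_exam_result` are illustrative names, and the wording of the follow-up user message that carries the result (`"<test>: <result>"`) is an assumption, not the format used in DiagGym. The default `max_turns=12` mirrors the turn threshold mentioned under Training Details.

```python
import re
from typing import Callable, Dict, List, Tuple

Message = Dict[str, str]

# Sentence the system prompt asks the model to emit when it commits to a diagnosis.
STOP = "The available information is sufficient to make a diagnosis."
TEST_RE = re.compile(r"should be performed:[ \t]*(.+)")

def run_episode(
    chat_fn: Callable[[List[Message]], str],
    get_exam_result: Callable[[str], str],
    system_prompt: str,
    initial_inquiry: str,
    max_turns: int = 12,
) -> Tuple[str, List[Message]]:
    """Alternate model turns and exam results until the model stops or max_turns is reached."""
    messages: List[Message] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_inquiry},
    ]
    reply = ""
    for _ in range(max_turns):
        reply = chat_fn(messages)
        messages.append({"role": "assistant", "content": reply})
        if STOP in reply:
            break  # the model committed to a final diagnosis
        match = TEST_RE.search(reply)
        if match is None:
            break  # reply did not follow the template; stop rather than guess
        test = match.group(1).strip()
        # Assumed follow-up format: plain text carrying the requested result.
        result = get_exam_result(test)
        messages.append({"role": "user", "content": f"{test}: {result}"})
    return reply, messages
```

With the vLLM server, `chat_fn` can be `lambda m: client.chat.completions.create(model="DiagAgent-7B", messages=m, temperature=0.0, max_tokens=1024).choices[0].message.content`; with `transformers`, pass the `chat` function directly. `get_exam_result` stands in for whatever supplies results, such as a clinician, an EHR lookup, or a simulator.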



## Evaluation Results

The following tables are taken directly from the project evaluation. For evaluation details and scripts, see the paper and the GitHub repository.

**Single‑Turn Evaluation**

| Model           | Size  | Year   | Hit Ratio | Diagnosis Accuracy |
|-----------------|-------|--------|-----------|--------------------|
| **Basic LLM**   |       |        |           |                    |
| GPT-4o          | -     | 2024.8 | 18.46     | 70.17              |
| Claude-4-sonnet | -     | 2025.5 | 30.62     | 73.25              |
| Qwen2.5         | 72B   | 2024.9 | 17.61     | 72.72              |
| Llama3.3        | 70B   | 2024.12| 18.29     | 64.65              |
| DeepSeek-v3     | 671B  | 2025.3 | 19.14     | 69.64              |
| Qwen3           | 235B  | 2025.7 | 19.82     | 70.81              |
| GPT-OSS         | 120B  | 2025.8 | 16.71     | 66.67              |
| OpenBioLLM      | 70B   | 2024.4 | 22.94     | 64.01              |
| Baichuan-M1     | 14B   | 2025.2 | 18.50     | 77.39              |
| MedGemma        | 27B   | 2025.7 | 27.07     | 68.90              |
| **Agentic System** |    |        |           |                    |
| MedAgents       | -     | 2024.1 | 18.31     | 68.26              |
| MDAgents        | -     | 2024.10| 19.24     | 71.66              |
| **Our Method**  |       |        |           |                    |
| DiagAgent       | 7B    | -      | **71.12** | 85.03              |
| DiagAgent       | 8B    | -      | 54.61     | 81.10              |
| DiagAgent       | 14B   | -      | 67.54     | **86.73**          |

**End‑to‑End Evaluation (Automatic Metrics)**

| Model           | Size  | Year   | Avg. Turns | Precision | Recall | F1     | Accuracy |
|-----------------|-------|--------|------------|-----------|--------|--------|----------|
| **Basic LLM**   |       |        |            |           |        |        |          |
| GPT-4o          | -     | 2024.8 | 3.30       | 29.02     | 14.25  | 19.12  | 41.44    |
| Claude-4-sonnet | -     | 2025.5 | 3.91       | 33.06     | 23.41  | 27.41  | 45.14    |
| Qwen2.5         | 72B   | 2024.9 | 2.47       | 32.67     | 11.09  | 16.56  | 32.66    |
| Llama3.3        | 70B   | 2024.12| 4.25       | 26.53     | 20.52  | 23.14  | 36.79    |
| DeepSeek-v3     | 671B  | 2025.3 | 2.49       | 32.60     | 12.20  | 17.75  | 46.51    |
| Qwen3           | 235B  | 2025.7 | 3.34       | 27.71     | 17.40  | 21.38  | 44.19    |
| GPT-OSS         | 120B  | 2025.8 | 4.08       | 25.94     | 16.16  | 19.92  | 43.02    |
| OpenBioLLM      | 70B   | 2024.4 | 2.59       | 31.94     | 13.75  | 19.22  | 32.88    |
| Baichuan-M1     | 14B   | 2025.2 | 2.30       | 28.82     | 10.98  | 15.90  | 32.35    |
| MedGemma        | 27B   | 2025.7 | 4.10       | 32.50     | 20.04  | 24.80  | 42.39    |
| **Agentic System** |    |        |            |           |        |        |          |
| MedAgents       | -     | 2024.1 | 2.31       | 29.86     | 11.21  | 16.30  | 44.40    |
| MDAgents        | -     | 2024.10| 2.40       | 30.12     | 11.25  | 16.38  | 43.55    |
| **Our Method**  |       |        |            |           |        |        |          |
| DiagAgent       | 7B    | -      | 5.47       | **46.08** | 47.66  | 46.86  | 60.78    |
| DiagAgent       | 8B    | -      | 5.71       | 41.58     | 44.56  | 43.02  | 53.85    |
| DiagAgent       | 14B   | -      | 6.77       | 43.87     | **52.74** | **47.89** | **61.63** |
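Up to rounding, the F1 column matches the standard harmonic mean of the Precision and Recall columns. Worked through for DiagAgent-7B:

$$
F_1 = \frac{2PR}{P+R} = \frac{2 \times 46.08 \times 47.66}{46.08 + 47.66} = \frac{4392.35}{93.74} \approx 46.86
$$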

**End‑to‑End Evaluation (Rubric-based Metrics)**


| Model            | Size | Year   | Rubric Score (%) |
|------------------|------|--------|------------------|
| **Basic LLM**    |      |        |                  |
| GPT-4o           | -    | 2024.8 | 18.02            |
| Claude-4-sonnet  | -    | 2025.5 | 25.84            |
| Qwen2.5          | 72B  | 2024.9 | 15.63            |
| Llama3.3         | 70B  | 2024.12| 16.69            |
| DeepSeek-v3      | 671B | 2025.3 | 19.07            |
| Qwen3            | 235B | 2025.7 | 24.49            |
| GPT-OSS          | 120B | 2025.8 | 22.26            |
| OpenBioLLM       | 70B  | 2024.4 | 14.11            |
| Baichuan-M1      | 14B  | 2025.2 | 17.43            |
| MedGemma         | 27B  | 2025.7 | 20.72            |
| **Agentic System** |    |        |                  |
| MedAgents        | -    | 2024.1 | 19.49            |
| MDAgents         | -    | 2024.10| 21.64            |
| **Our Method**   |      |        |                  |
| DiagAgent        | 14B  | -      | 32.86            |

## Training Details

DiagAgent‑7B is optimized with multi‑turn RL (GRPO) inside `DiagGym`.

- Trajectory construction:
  - Initial inquiry (structured patient history without final diagnosis)
  - Iterative steps: preliminary diagnosis → recommended exam + rationale → exam result
  - Final diagnosis focused on a single primary condition
- Reward design:
  - Diagnosis Accuracy (exact match)
  - Examination Recommendation F1 (overlap with reference EHR exams)
  - Turn Penalty (discourages >12 interaction turns)
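The three reward terms can be sketched as follows. This is a hypothetical illustration, not the training code in `DiagAgent/train/rl/`: the weights `w_acc`, `w_f1`, `w_turn`, the linear per-turn penalty beyond 12 turns, and the case/whitespace normalization used for "exact match" are all assumptions. The examination F1 treats the recommended and reference examinations as sets.

```python
from typing import Iterable

def _norm(s: str) -> str:
    # Assumed normalization for "exact match": case- and whitespace-insensitive.
    return " ".join(s.lower().split())

def exam_f1(recommended: Iterable[str], reference: Iterable[str]) -> float:
    """Set-overlap F1 between recommended exams and the exams in the reference EHR."""
    rec = {_norm(e) for e in recommended}
    ref = {_norm(e) for e in reference}
    hits = len(rec & ref)
    if hits == 0:
        return 0.0  # also covers empty recommendation or reference sets
    precision = hits / len(rec)
    recall = hits / len(ref)
    return 2 * precision * recall / (precision + recall)

def episode_reward(
    predicted_dx: str,
    reference_dx: str,
    recommended: Iterable[str],
    reference_exams: Iterable[str],
    num_turns: int,
    max_turns: int = 12,
    w_acc: float = 1.0,   # hypothetical weight for diagnosis accuracy
    w_f1: float = 1.0,    # hypothetical weight for examination F1
    w_turn: float = 0.1,  # hypothetical penalty per turn beyond max_turns
) -> float:
    # Diagnosis accuracy: 1 for an exact (normalized) match, else 0.
    accuracy = 1.0 if _norm(predicted_dx) == _norm(reference_dx) else 0.0
    f1 = exam_f1(recommended, reference_exams)
    # Turn penalty: zero up to max_turns, then grows linearly.
    penalty = w_turn * max(0, num_turns - max_turns)
    return w_acc * accuracy + w_f1 * f1 - penalty
```

A weighted sum keeps each term interpretable; the actual combination and scaling used in training may differ.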

For implementation and scripts, see `DiagAgent/train/rl/` in the GitHub repository.

## Citation
```
@misc{qiu2025evolvingdiagnosticagentsvirtual,
      title={Evolving Diagnostic Agents in a Virtual Clinical Environment}, 
      author={Pengcheng Qiu and Chaoyi Wu and Junwei Liu and Qiaoyu Zheng and Yusheng Liao and Haowen Wang and Yun Yue and Qianrui Fan and Shuai Zhen and Jian Wang and Jinjie Gu and Yanfeng Wang and Ya Zhang and Weidi Xie},
      year={2025},
      eprint={2510.24654},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2510.24654}, 
}
```


## Contact

- Email: henrychur@sjtu.edu.cn
- GitHub: https://github.com/MAGIC-AI4Med/DiagGym