Question Answering
Transformers
Safetensors

Improve model card: Add library name and pipeline tag

#1
by nielsr HF Staff - opened
Files changed (1) hide show
  1. README.md +107 -4
README.md CHANGED
@@ -1,9 +1,10 @@
1
  ---
2
  license: apache-2.0
 
 
3
  ---
4
 
5
  # MedReason: Eliciting Factual Medical Reasoning Steps in LLMs via Knowledge Graphs
6
-
7
  <p align="center">
8
  📃 <a href="https://arxiv.org/abs/2504.00993" target="_blank">Paper</a> |🤗 <a href="https://huggingface.co/UCSC-VLAA/MedReason-8B" target="_blank">MedReason-8B</a> | 📚 <a href="https://huggingface.co/datasets/UCSC-VLAA/MedReason" target="_blank">MedReason Data</a>
9
  </p>
@@ -11,13 +12,35 @@ license: apache-2.0
11
 
12
  ## ⚡Introduction
13
 
 
 
14
  **MedReason** is a large-scale high-quality medical reasoning dataset designed to enable faithful and explainable medical problem-solving in large language models (LLMs).
15
 
16
  - We utilize a structured medical knowledge graph (KG) to convert clinical QA pairs into logical chains of reasoning, or “thinking paths”.
17
  - Our pipeline generates detailed reasoning for various medical questions from 7 medical datasets, resulting in a dataset of **32,682** question-answer pairs, each with detailed, step-by-step explanations.
18
  - By finetuning with proposed [MedReason dataset](https://huggingface.co/datasets/UCSC-VLAA/MedReason), our best model [MedReason-8B](https://huggingface.co/UCSC-VLAA/MedReason-8B), achieves *state-of-the-art* performance.
19
 
20
- We open-sourced our model here.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  ## 👨‍⚕️ Model
23
 
@@ -49,6 +72,87 @@ outputs = model.generate(**inputs, max_new_tokens=2048)
49
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
50
  ```
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  ## 🙏🏼 Acknowledgement
53
 
54
  We gratefully acknowledge the inspiring work of [HuatuoGPT-o1](https://github.com/FreedomIntelligence/HuatuoGPT-o1), which laid important groundwork for this research. We also thank the developers of the excellent tools [curator](https://github.com/bespokelabsai/curator/), [trl](https://github.com/huggingface/trl), and [sglang](https://github.com/sgl-project/sglang) for making this work possible.
@@ -65,5 +169,4 @@ We gratefully acknowledge the inspiring work of [HuatuoGPT-o1](https://github.co
65
  primaryClass={cs.CL},
66
  url={https://arxiv.org/abs/2504.00993},
67
  }
68
- ```
69
-
 
1
  ---
2
  license: apache-2.0
3
+ library_name: transformers
4
+ pipeline_tag: question-answering
5
  ---
6
 
7
  # MedReason: Eliciting Factual Medical Reasoning Steps in LLMs via Knowledge Graphs
 
8
  <p align="center">
9
  📃 <a href="https://arxiv.org/abs/2504.00993" target="_blank">Paper</a> |🤗 <a href="https://huggingface.co/UCSC-VLAA/MedReason-8B" target="_blank">MedReason-8B</a> | 📚 <a href="https://huggingface.co/datasets/UCSC-VLAA/MedReason" target="_blank">MedReason Data</a>
10
  </p>
 
12
 
13
  ## ⚡Introduction
14
 
15
+ <img src="./assets/main.png" alt="main" style="zoom: 33%;" />
16
+
17
  **MedReason** is a large-scale high-quality medical reasoning dataset designed to enable faithful and explainable medical problem-solving in large language models (LLMs).
18
 
19
  - We utilize a structured medical knowledge graph (KG) to convert clinical QA pairs into logical chains of reasoning, or “thinking paths”.
20
  - Our pipeline generates detailed reasoning for various medical questions from 7 medical datasets, resulting in a dataset of **32,682** question-answer pairs, each with detailed, step-by-step explanations.
21
  - By finetuning with proposed [MedReason dataset](https://huggingface.co/datasets/UCSC-VLAA/MedReason), our best model [MedReason-8B](https://huggingface.co/UCSC-VLAA/MedReason-8B), achieves *state-of-the-art* performance.
22
 
23
+ We open-sourced our models, data, and code here.
24
+
25
+ ## 📚 Data
26
+
27
+ - **Data Access**
28
+
29
+ | Data | Description | Link |
30
+ | --------- | --------------------------------- | ----------------------------------------------------------- |
31
+ | MedReason | Our quality filtered data for SFT | [Link](https://huggingface.co/datasets/UCSC-VLAA/MedReason) |
32
+
33
+ - **Data Generation**
34
+
35
+ We provide the code for generating Chain-of-Thought reasoning based on medical QA pairs and knowledge-graph (KG) in `./src/data_generation`
36
+
37
+ 1. Set the file path of each datasets in `./configs/dataset_configs.yml`
38
+ 2. Fill your Azure endpoint and API key in `./src/data_generation/utils.py`
39
+ 3. Run the following script
40
+
41
+ ```bash
42
+ python ./src/data_generation/Generate_Reasoning.py --dataset medqa --sample <number_of_samples> --start_idx 0 --batch_size 1&
43
+ ```
44
 
45
  ## 👨‍⚕️ Model
46
 
 
72
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
73
  ```
74
 
75
+ ## 🚀 Training Piepline
76
+
77
+ Simply Supervised-Finetuning (SFT) using MedReason data improves the LLM’s medical reasoning capability.
78
+
79
+ Fine-tune the model on 8-GPU:
80
+
81
+ ```bash
82
+ # based on Huatuo-o1-8B
83
+ accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
84
+ --num_processes 8 \
85
+ --num_machines 1 \
86
+ --machine_rank 0 \
87
+ --deepspeed_multinode_launcher standard ./src/model_training/SFT.py \
88
+ --model_path FreedomIntelligence/HuatuoGPT-o1-8B \
89
+ --data_path /path/to/your/data \
90
+ --n_epochs 3 \
91
+ --experiment_name huatuo_o1_medreason_8B \
92
+ --base_model Llama
93
+
94
+ # based on DeepSeek-distilled-Llama-8B
95
+ accelerate launch --config_file ./configs/deepspeed_zero3.yaml \
96
+ --num_processes 8 \
97
+ --num_machines 1 \
98
+ --machine_rank 0 \
99
+ --deepspeed_multinode_launcher standard ./src/model_training/SFT.py \
100
+ --model_path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
101
+ --data_path /path/to/your/data\
102
+ --n_epochs 3 \
103
+ --experiment_name distilled_llama_medreason_8B \
104
+ --base_model Llama
105
+ ```
106
+
107
+ ## 🧐 Evaluation
108
+
109
+ - **Qualitative Results:**
110
+
111
+ Case Study on Medbullets Benchmark. **MedReason-8B** generates accurate reasoning with reliable knowledge.
112
+
113
+ <img src="./assets/case_v6.png" alt="case_v6" style="zoom: 40%;" />
114
+
115
+ - **Performance on medical benchmarks**:
116
+
117
+ Results of instruction-tuned LLMs fine-tuned with MedReason data:
118
+
119
+ <img src="./assets/tab1.png" alt="tab1" style="zoom:50%;" />
120
+
121
+ Performance of MedReason-8B on challenging and common medical QA benchmarks:
122
+
123
+ <img src="./assets/tab3.png" alt="tab3" style="zoom:50%;" />
124
+
125
+ - **Run evaluation**:
126
+ 1. You first need to install [Sglang](https://github.com/sgl-project/sglang). After installation, deploy the model you want to test using Sglang with the following command:
127
+
128
+ ```bash
129
+ # deploy on 8 GPUs
130
+ log_num=0
131
+ model_name=UCSC-VLAA/MedReason-8B
132
+ port=28${log_num}35
133
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m sglang.launch_server --model-path $model_name --port $port --mem-fraction-static 0.8 --dp 8 --tp 1 > sglang${log_num}.log 2>&1 &
134
+ ```
135
+
136
+ 2. Wait for the model to be deployed. After deployment, you can run the following code for evaluation. We use prompts that allow the model to respond freely. We find that the extracted results are consistently reliable and broadly cover the intended scope. You can also set the `--strict_prompt` option to use stricter prompts for more precise answer extraction.
137
+
138
+ ```bash
139
+ log_num=0
140
+ task_floder=MedReason-8B-results
141
+ model_name=UCSC-VLAA/MedReason-8B
142
+ port=28${log_num}35
143
+
144
+ eval_file=./eval_data/medbullets_op4.jsonl
145
+ python ./src/evaluation/eval.py --model_name $model_name --eval_file $eval_file --port $port --strict_prompt --batch_size 1000 --max_new_tokens 2000 --task_floder $task_floder
146
+ ```
147
+
148
+ 3. After completing the evaluation, run the following code to stop the Sglang service and release GPU memory.
149
+
150
+ ```bash
151
+ pkill -f sglang
152
+ pkill -f multiprocessing.spawn
153
+ ```
154
+
155
+
156
  ## 🙏🏼 Acknowledgement
157
 
158
  We gratefully acknowledge the inspiring work of [HuatuoGPT-o1](https://github.com/FreedomIntelligence/HuatuoGPT-o1), which laid important groundwork for this research. We also thank the developers of the excellent tools [curator](https://github.com/bespokelabsai/curator/), [trl](https://github.com/huggingface/trl), and [sglang](https://github.com/sgl-project/sglang) for making this work possible.
 
169
  primaryClass={cs.CL},
170
  url={https://arxiv.org/abs/2504.00993},
171
  }
172
+ ```