Muhammadidrees committed on
Commit 373f237 · verified · 1 Parent(s): 7ba8f15

Upload 15 files

Files changed (16)
  1. .gitattributes +1 -0
  2. Coreectcodewithoutfronted.py +141 -0
  3. LICENSE +201 -0
  4. README.md +182 -12
  5. alpaca_data.json +3 -0
  6. chat.py +100 -0
  7. chatdoctor5k.json +0 -0
  8. format_dataset.csv +0 -0
  9. frontend.py +313 -0
  10. frontend_VOic.py +459 -0
  11. requirements.txt +10 -0
  12. teak.py +103 -0
  13. test.py +328 -0
  14. train.py +231 -0
  15. train_lora.py +321 -0
  16. utils.py +174 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ alpaca_data.json filter=lfs diff=lfs merge=lfs -text
Coreectcodewithoutfronted.py ADDED
@@ -0,0 +1,141 @@
+ import os
+ import gc
+ import torch
+ from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
+
+ # =============================
+ # Configuration
+ # =============================
+ MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
+ MAX_NEW_TOKENS = 200
+ TEMPERATURE = 0.5
+ TOP_K = 50
+ REPETITION_PENALTY = 1.1
+
+ # Detect device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Loading model from {MODEL_PATH} on {device}...")
+
+ # =============================
+ # Load Tokenizer and Model
+ # =============================
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
+ model = LlamaForCausalLM.from_pretrained(
+     MODEL_PATH,
+     device_map="auto",
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True
+ )
+
+ generator = model.generate
+ print("✅ ChatDoctor model loaded successfully!\n")
+
+ # =============================
+ # Stopping Criteria
+ # =============================
+ class StopOnTokens(StoppingCriteria):
+     def __init__(self, stop_ids):
+         self.stop_ids = stop_ids
+
+     def __call__(self, input_ids, scores, **kwargs):
+         for stop_id_seq in self.stop_ids:
+             if len(stop_id_seq) == 1:
+                 if input_ids[0][-1] == stop_id_seq[0]:
+                     return True
+             else:
+                 if len(input_ids[0]) >= len(stop_id_seq):
+                     if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
+                         return True
+         return False
+
+ # =============================
+ # Chat History
+ # =============================
+ history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]
+
+ # =============================
+ # Get Response Function
+ # =============================
+ def get_response(user_input):
+     global history
+     human_invitation = "Patient: "
+     doctor_invitation = "ChatDoctor: "
+
+     # Add user input to history
+     history.append(human_invitation + user_input)
+
+     # Build conversation prompt
+     prompt = "\n".join(history) + "\n" + doctor_invitation
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+     # Define stop words and their token IDs
+     stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
+     stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
+     stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
+
+     # Generate model response
+     with torch.no_grad():
+         output_ids = generator(
+             input_ids,
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=True,
+             temperature=TEMPERATURE,
+             top_k=TOP_K,
+             repetition_penalty=REPETITION_PENALTY,
+             stopping_criteria=stopping_criteria,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+     # Decode and clean response
+     full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     response = full_output[len(prompt):].strip()
+
+     # Remove any "Patient:" that might have slipped through
+     for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
+         if stop_word in response:
+             response = response.split(stop_word)[0].strip()
+             break
+
+     # Remove any leading/trailing punctuation artifacts
+     response = response.strip()
+
+     history.append(doctor_invitation + response)
+
+     # Free memory
+     del input_ids, output_ids
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     return response
+
+ # =============================
+ # Chat Loop
+ # =============================
+ if __name__ == "__main__":
+     print("\n=== ChatDoctor is ready! ===")
+     print("You (the human) = Patient")
+     print("AI = ChatDoctor")
+     print("Type 'exit' or 'quit' to end the chat.\n")
+
+     print("ChatDoctor: Hi there! How can I help you today?\n")
+
+     while True:
+         try:
+             user_input = input("Patient: ").strip()
+             if user_input.lower() in ["exit", "quit"]:
+                 print("ChatDoctor: Take care! Goodbye")
+                 break
+
+             if not user_input:
+                 continue
+
+             response = get_response(user_input)
+             print("ChatDoctor:", response, "\n")
+
+         except KeyboardInterrupt:
+             print("\nChatDoctor: Take care! Goodbye")
+             break
+         except Exception as e:
+             print(f"Error: {e}")
+             print("Please try again.\n")
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,182 @@
- ---
- title: JAYConverstionalAI
- emoji: 🏃
- colorFrom: green
- colorTo: blue
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <p align="center" width="80%">
+ <img src="fig/logo.png" style="width: 40%; min-width: 300px; display: block; margin: auto;">
+ </p>
+
+
+ # [ChatDoctor: A Medical Chat Model Fine-Tuned on a Large Language Model Meta-AI (LLaMA) Using Medical Domain Knowledge](https://www.cureus.com/articles/152858-chatdoctor-a-medical-chat-model-fine-tuned-on-a-large-language-model-meta-ai-llama-using-medical-domain-knowledge#!/)
+ Yunxiang Li<sup>1</sup>, Zihan Li<sup>2</sup>, Kai Zhang<sup>3</sup>, Ruilong Dan<sup>4</sup>, Steve Jiang<sup>1</sup>, You Zhang<sup>1</sup>
+ <h5>1 UT Southwestern Medical Center, USA</h5>
+ <h5>2 University of Illinois at Urbana-Champaign, USA</h5>
+ <h5>3 Ohio State University, USA</h5>
+ <h5>4 Hangzhou Dianzi University, China</h5>
+
+
+ [![License](https://img.shields.io/badge/License-Apache_2.0-green.svg)](https://github.com/HUANGLIZI/ChatDoctor/blob/main/LICENSE)
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/release/python-390/)
+ [![Page](https://img.shields.io/badge/Web-Page-yellow)](https://www.yunxiangli.top/ChatDoctor/)
+
+ ## Resources List
+ Autonomous ChatDoctor with disease database: [Demo](https://huggingface.co/spaces/kenton-li/chatdoctor_csv).
+
+ 100k real conversations between patients and doctors from HealthCareMagic.com: [HealthCareMagic-100k](https://drive.google.com/file/d/1lyfqIwlLSClhgrCutWuEe_IACNq6XNUt/view?usp=sharing).
+
+ Real conversations between patients and doctors from icliniq.com: [icliniq-10k](https://drive.google.com/file/d/1ZKbqgYqWc7DJHs3N9TQYQVPdDQmZaClA/view?usp=sharing).
+
+ Checkpoints of ChatDoctor: [link](https://drive.google.com/drive/folders/11-qPzz9ZdHD6pc47wBSOUSU61MaDPyRh?usp=sharing).
+
+ Stanford Alpaca data for basic conversational capabilities: [Alpaca link](https://github.com/Kent0n-Li/ChatDoctor/blob/main/alpaca_data.json).
+
+
+ <p align="center" width="100%">
+ <img src="fig/overview.PNG" style="width: 70%; min-width: 300px; display: block; margin: auto;">
+ </p>
+
+ <p align="center" width="100%">
+ <img src="fig/wiki.PNG" style="width: 70%; min-width: 300px; display: block; margin: auto;">
+ </p>
+
+
+ ## Setup
+ In a conda environment with PyTorch available, run:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Interactive Demo Page
+ Demo page: https://huggingface.co/spaces/kenton-li/chatdoctor_csv
+ Note that our model has not yet achieved 100% accurate output; please do not apply it to real clinical scenarios.
+
+ To try the online demo, please register for Hugging Face and fill out this [form](https://forms.office.com/Pages/ResponsePage.aspx?id=lYZBnaxxMUy1ssGWyOw8ij06Cb8qnDJKvu2bVpV1-ANURUU0TllBWVVHUjQ1MDJUNldGTTZWV1c5UC4u).
+
+ ## Data and Model
+ ### 1. ChatDoctor Dataset
+ You can download the following training datasets:
+
+ 100k real conversations between patients and doctors from HealthCareMagic.com: [HealthCareMagic-100k](https://drive.google.com/file/d/1lyfqIwlLSClhgrCutWuEe_IACNq6XNUt/view?usp=sharing).
+
+ 10k real conversations between patients and doctors from icliniq.com: [icliniq-10k](https://drive.google.com/file/d/1ZKbqgYqWc7DJHs3N9TQYQVPdDQmZaClA/view?usp=sharing).
+
+ 5k generated conversations between patients and physicians from ChatGPT: [GenMedGPT-5k](https://drive.google.com/file/d/1nDTKZ3wZbZWTkFMBkxlamrzbNz0frugg/view?usp=sharing) and the [disease database](https://github.com/Kent0n-Li/ChatDoctor/blob/main/format_dataset.csv).
+
+ Our model was first fine-tuned on Stanford Alpaca's data to give it basic conversational capabilities: [Alpaca link](https://github.com/Kent0n-Li/ChatDoctor/blob/main/alpaca_data.json).
+
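For orientation, `alpaca_data.json` follows the Stanford Alpaca schema of `instruction`/`input`/`output` records, and the ChatDoctor conversation files are distributed in the same shape. A minimal sketch for inspecting a downloaded file; the field names are the Alpaca convention and worth verifying against your copy:

```python
import json

# Load an Alpaca-style dataset and peek at one record.
# Field names (instruction / input / output) follow the Stanford Alpaca schema;
# confirm them against your downloaded file before training.
with open("alpaca_data.json", "r", encoding="utf-8") as f:
    records = json.load(f)

print(f"{len(records)} records")
example = records[0]
print("instruction:", example["instruction"])
print("input:", example.get("input", ""))   # often empty
print("output:", example["output"])
```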
+ ### 2. Model Weights
+
+ Place the model weight files in the ./pretrained folder.
+
+ ## How to fine-tune
+
+ ```
+ torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
+     --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
+     --data_path ./HealthCareMagic-100k.json \
+     --bf16 True \
+     --output_dir pretrained \
+     --num_train_epochs 1 \
+     --per_device_train_batch_size 4 \
+     --per_device_eval_batch_size 4 \
+     --gradient_accumulation_steps 8 \
+     --evaluation_strategy "no" \
+     --save_strategy "steps" \
+     --save_steps 2000 \
+     --save_total_limit 1 \
+     --learning_rate 2e-6 \
+     --weight_decay 0. \
+     --warmup_ratio 0.03 \
+     --lr_scheduler_type "cosine" \
+     --logging_steps 1 \
+     --fsdp "full_shard auto_wrap" \
+     --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
+     --tf32 True
+ ```
+
+ Fine-tuning with LoRA:
+ ```
+ WORLD_SIZE=6 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --nproc_per_node=6 --master_port=4567 train_lora.py \
+     --base_model './weights-alpaca/' \
+     --data_path 'HealthCareMagic-100k.json' \
+     --output_dir './lora_models/' \
+     --batch_size 32 \
+     --micro_batch_size 4 \
+     --num_epochs 1 \
+     --learning_rate 3e-5 \
+     --cutoff_len 256 \
+     --val_set_size 120 \
+     --adapter_name lora
+ ```
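To chat with the LoRA-tuned model, the saved adapter has to be attached to the base weights first. A minimal sketch, assuming `train_lora.py` writes a standard PEFT adapter to `./lora_models/` (paths are illustrative):

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

# Load the base model used for LoRA training, then attach the saved adapter.
base = LlamaForCausalLM.from_pretrained(
    "./weights-alpaca/", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = LlamaTokenizer.from_pretrained("./weights-alpaca/")

model = PeftModel.from_pretrained(base, "./lora_models/")
model = model.merge_and_unload()  # optionally fold LoRA weights into the base for plain inference
model.eval()
```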
+
+ ## How to run inference
+ You can build a ChatDoctor model on your own machine and chat with it:
+ ```
+ python chat.py
+ ```
+
+ ```
+ cd Autonomous_ChatDoctor_csv
+ python chat_csv.py
+ ```
+
+ ```
+ cd Autonomous_ChatDoctor_Wikipedia
+ python chat_wiki.py
+ ```
+
+ ## Examples
+ We fine-tuned our model on HealthCareMagic-100k and tested it on iCliniq-10k.
+
+ ```
+ Patient question from icliniq: Hi doctor,I am getting one-sided headache. I do not get them often. They are behind my right eye and temple area. Now this time I am also fighting with heavy cold and cough. It came on late morning today and it hurts more when I cough or shake my head or bend over. Prior to this, I had the same headache about four months ago. At that time it came on midafternoon and lasted till next morning. Prior to that, I had the same headache exactly on the right side and lasted for a day. I am on CCB and Protonix. I am scared of brain tumor. Please help
+
+ Real doctor answer from icliniq: Hi. Based on what you have told, this is more likely to be due to sinusitis. Sinuses are small air filled cavities in your skull bone. They can get infected and when they are infected, you will get headaches behind eyes. Such headaches will be more common during cold and cough. Brain tumor headache will be on both sides, more in forehead and you will also have vomiting in that case. Based on what you have told, this is more likely due to simple sinus infection instead of brain tumor. CT brain and CT paranasal sinuses.
+
+ Answer from ChatGPT: Hello! I'm sorry to hear you're experiencing these headaches. Based on the location of the pain, it sounds like it may be a tension headache or a migraine. These types of headaches can sometimes be triggered by colds or other illnesses. However, it's always best to get checked out by your doctor to rule out any other underlying conditions. In terms of a brain tumor, that is a very rare cause of headaches. Most people who experience headaches do not have a brain tumor. That being said, if you have any concerning symptoms such as vision changes, weakness, or numbness, it is important to seek medical attention right away. In the meantime, you can try taking over-the-counter pain relievers such as ibuprofen or acetaminophen to help with the headache. It's also important to get plenty of rest and stay hydrated. If the headaches persist or become more frequent, please make an appointment with your doctor. Take care!
+
+ Answer from ChatDoctor: Hi, Thank you for posting your query. First, I would like to reassure you that there is no need to worry about brain tumor. Also, we should keep migraine apart as well. The most likely cause of your headaches could be sinusitis. This can be confirmed by doing an X-ray of face. Treatment would consist of using decongestant tablets/ nasal sprays. Best wishes, Chat Doctor.
+ ```
+ ChatDoctor accurately mentions sinusitis, which also appears in the real doctor's answer from iCliniq.
+
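Since requirements.txt in this commit pulls in `rouge_score`, here is a minimal sketch of how one could score a generated answer against the reference iCliniq answer; the strings are illustrative stand-ins, and this is not the authors' evaluation pipeline:

```python
from rouge_score import rouge_scorer

# Illustrative stand-ins for a reference answer and a model answer.
reference = "Based on what you have told, this is more likely to be due to sinusitis."
prediction = "The most likely cause of your headaches could be sinusitis."

scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
scores = scorer.score(reference, prediction)
print("ROUGE-1 F1:", scores["rouge1"].fmeasure)
print("ROUGE-L F1:", scores["rougeL"].fmeasure)
```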
+
+ ## Overview
+ ChatDoctor is a next-generation AI doctor model based on the [LLaMA](https://github.com/facebookresearch/llama) model. The goal of this project is to provide patients with an intelligent and reliable healthcare companion that can answer their medical queries and provide personalized medical advice.
+
+ ChatDoctor is an advanced language model designed specifically for medical applications. It has been trained on a large corpus of medical literature and has a deep understanding of medical terminology, procedures, and diagnoses. This foundation enables ChatDoctor to analyze patients' symptoms and medical history, provide accurate diagnoses, and suggest appropriate treatment options.
+
+ The ChatDoctor model is designed to simulate a conversation between a doctor and a patient, using natural language processing (NLP) and machine learning techniques. Patients can interact with the model through a chat interface, asking questions about their health, symptoms, or medical conditions. The model then analyzes the input and provides a response tailored to the patient's unique situation.
+
+ One of the key features of the ChatDoctor model is its ability to learn and adapt over time. As more patients interact with the model, it continues to refine its responses and improve its accuracy, so patients can expect increasingly personalized and accurate medical advice.
+
+
+ ## Patient-Physician Conversation Dataset
+ The first step in fine-tuning is to collect a dataset of patient-physician conversations. In such conversations, the patient's descriptions of disease symptoms are often colloquial and cursory. Manually constructing a synthesized patient-physician conversation dataset tends to yield insufficient diversity and over-specialized descriptions that are far removed from real scenarios, so collecting real patient-physician conversations is a better solution. We therefore collected about 100k real doctor-patient conversations from the online medical consultation website HealthCareMagic (www.healthcaremagic.com). We filtered these data both manually and automatically, removed identifying information about the doctors and patients, and used language tools to correct grammatical errors; we named this dataset HealthCareMagic-100k. In addition, we collected approximately 10k patient-physician conversations from the online medical consultation website iCliniq to evaluate the performance of our model.
+
+
+ ## Autonomous ChatDoctor Based on a Knowledge Brain
+ Equipped with an external knowledge brain, i.e., Wikipedia or our constructed database encompassing over 700 diseases, ChatDoctor can retrieve the corresponding knowledge and reliable sources to answer patients' inquiries more accurately. After constructing the external knowledge brain, we need ChatDoctor to retrieve the knowledge it needs autonomously, which can generally be achieved in a large language model by constructing appropriate prompts. To automate this process, we design keyword-mining prompts that have ChatDoctor extract the key terms used for knowledge seeking. The top-ranked relevant passages are then retrieved from the knowledge brain with a term-matching retrieval system. For the disease database, since the model cannot read all the data at once, we first let the model read the data in batches and select the entries that might help answer the patient's question; finally, all the entries selected by the model are given to the model for a final answer. This approach helps ensure that patients receive well-informed and precise responses backed by credible references.
+
+
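To make the term-matching step concrete, here is a minimal sketch of the retrieval idea described above. The CSV path, column layout, and keyword-overlap scoring are assumptions for illustration, not the project's actual retrieval code:

```python
import csv

def retrieve_entries(keywords, csv_path="format_dataset.csv", top_k=3):
    """Score each disease-database row by keyword overlap and return the best rows."""
    keywords = {k.lower() for k in keywords}
    scored = []
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.reader(f):
            text = " ".join(row).lower()
            score = sum(1 for k in keywords if k in text)
            if score:
                scored.append((score, row))
    scored.sort(key=lambda item: item[0], reverse=True)
    return [row for _, row in scored[:top_k]]

# Example: terms a keyword-mining prompt might extract from a patient query.
print(retrieve_entries(["headache", "sinusitis", "eye pain"]))
```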
+ ## Limitations
+ We emphasize that ChatDoctor is for academic research only; any commercial or clinical use is prohibited. Three factors drive this decision: first, ChatDoctor is based on LLaMA, which carries a non-commercial license, so we necessarily inherit that restriction; second, our model is not licensed for healthcare-related purposes; and third, we have not designed sufficient safety measures, and the current model still cannot guarantee the full correctness of medical diagnoses.
+
+
+ ## Reference
+
+ ChatDoctor: A Medical Chat Model Fine-tuned on LLaMA Model using Medical Domain Knowledge
+
+ ```
+ @article{li2023chatdoctor,
+   title={ChatDoctor: A Medical Chat Model Fine-Tuned on a Large Language Model Meta-AI (LLaMA) Using Medical Domain Knowledge},
+   author={Li, Yunxiang and Li, Zihan and Zhang, Kai and Dan, Ruilong and Jiang, Steve and Zhang, You},
+   journal={Cureus},
+   volume={15},
+   number={6},
+   year={2023},
+   publisher={Cureus}
+ }
+ ```
alpaca_data.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a0a9920c72e27b32013e5c4ad7727d9eede8eaab75c3f4b7eb62eda019561d7
+ size 23034003
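Note: this is a Git LFS pointer, not the dataset itself; the roughly 23 MB JSON payload lives in LFS storage. After cloning the repository, running `git lfs install` once and then `git lfs pull` fetches the actual file.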
chat.py ADDED
@@ -0,0 +1,100 @@
+ # chat.py
+ import os
+ import gc
+ import torch
+ from transformers import LlamaTokenizer, LlamaForCausalLM
+
+ # =============================
+ # Configuration
+ # =============================
+ MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
+ MAX_NEW_TOKENS = 200
+ TEMPERATURE = 0.5
+ TOP_K = 50
+ REPETITION_PENALTY = 1.1
+
+ # Detect device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Loading model from {MODEL_PATH} on {device}...")
+
+ # =============================
+ # Load Tokenizer and Model
+ # =============================
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
+
+ model = LlamaForCausalLM.from_pretrained(
+     MODEL_PATH,
+     device_map="auto",          # automatically dispatch weights to GPU
+     torch_dtype=torch.float16,  # half precision for faster inference
+     low_cpu_mem_usage=True      # optimize CPU memory
+ )
+
+ # DO NOT call model.to(device) when using device_map="auto"
+ generator = model.generate
+ print("✅ Model loaded successfully!\n")
+
+ # =============================
+ # Chat History
+ # =============================
+ history = ["ChatDoctor: I am ChatDoctor, what medical questions do you have?"]
+
+ # =============================
+ # Response Function
+ # =============================
+ def get_response(user_input):
+     global history
+     human_invitation = "Patient: "
+     doctor_invitation = "ChatDoctor: "
+
+     # Append user input
+     history.append(human_invitation + user_input)
+
+     # Build prompt
+     prompt = "\n".join(history) + "\n" + doctor_invitation
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+     # Generate response
+     with torch.no_grad():
+         output_ids = generator(
+             input_ids,
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=True,
+             temperature=TEMPERATURE,
+             top_k=TOP_K,
+             repetition_penalty=REPETITION_PENALTY
+         )
+
+     # Decode response
+     full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     response = full_output[len(prompt):].strip()
+
+     # Clean if the model repeats the patient prompt
+     if response.startswith("Patient:"):
+         response = response[len("Patient:"):].strip()
+
+     # Append model response to history
+     history.append(doctor_invitation + response)
+
+     # Free memory
+     del input_ids, output_ids
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     return response
+
+ # =============================
+ # CLI Chat
+ # =============================
+ if __name__ == "__main__":
+     print("\n=== ChatDoctor is ready! Type your questions. ===\n")
+     while True:
+         try:
+             user_input = input("Patient: ").strip()
+             if user_input.lower() in ["exit", "quit"]:
+                 print("Exiting ChatDoctor. Goodbye!")
+                 break
+             response = get_response(user_input)
+             print("ChatDoctor: " + response + "\n")
+         except KeyboardInterrupt:
+             print("\nExiting ChatDoctor. Goodbye!")
+             break
chatdoctor5k.json ADDED
The diff for this file is too large to render. See raw diff
 
format_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
frontend.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ import gc
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
6
+
7
+ # =============================
8
+ # Configuration
9
+ # =============================
10
+ MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
11
+ MAX_NEW_TOKENS = 200
12
+ TEMPERATURE = 0.5
13
+ TOP_K = 50
14
+ REPETITION_PENALTY = 1.1
15
+
16
+ # Detect device
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ print(f"Loading model from {MODEL_PATH} on {device}...")
19
+
20
+ # =============================
21
+ # Load Tokenizer and Model
22
+ # =============================
23
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
24
+ model = LlamaForCausalLM.from_pretrained(
25
+ MODEL_PATH,
26
+ device_map="auto",
27
+ torch_dtype=torch.float16,
28
+ low_cpu_mem_usage=True
29
+ )
30
+
31
+ generator = model.generate
32
+ print("✅ ChatDoctor model loaded successfully!\n")
33
+
34
+ # =============================
35
+ # Stopping Criteria
36
+ # =============================
37
+ class StopOnTokens(StoppingCriteria):
38
+ def __init__(self, stop_ids):
39
+ self.stop_ids = stop_ids
40
+
41
+ def __call__(self, input_ids, scores, **kwargs):
42
+ for stop_id_seq in self.stop_ids:
43
+ if len(stop_id_seq) == 1:
44
+ if input_ids[0][-1] == stop_id_seq[0]:
45
+ return True
46
+ else:
47
+ if len(input_ids[0]) >= len(stop_id_seq):
48
+ if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
49
+ return True
50
+ return False
51
+
52
+ # =============================
53
+ # Chat History (Global)
54
+ # =============================
55
+ conversation_history = []
56
+
57
+ # =============================
58
+ # Get Response Function
59
+ # =============================
60
+ def get_response(user_input, history_context):
61
+ """Generate response from ChatDoctor model"""
62
+ human_invitation = "Patient: "
63
+ doctor_invitation = "ChatDoctor: "
64
+
65
+ # Build conversation from history
66
+ history_text = []
67
+ for human, assistant in history_context:
68
+ if human:
69
+ history_text.append(human_invitation + human)
70
+ if assistant:
71
+ history_text.append(doctor_invitation + assistant)
72
+
73
+ # Add current user input
74
+ history_text.append(human_invitation + user_input)
75
+
76
+ # Build conversation prompt
77
+ prompt = "\n".join(history_text) + "\n" + doctor_invitation
78
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
79
+
80
+ # Define stop words and their token IDs
81
+ stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
82
+ stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
83
+ stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
84
+
85
+ # Generate model response
86
+ with torch.no_grad():
87
+ output_ids = generator(
88
+ input_ids,
89
+ max_new_tokens=MAX_NEW_TOKENS,
90
+ do_sample=True,
91
+ temperature=TEMPERATURE,
92
+ top_k=TOP_K,
93
+ repetition_penalty=REPETITION_PENALTY,
94
+ stopping_criteria=stopping_criteria,
95
+ pad_token_id=tokenizer.eos_token_id,
96
+ eos_token_id=tokenizer.eos_token_id
97
+ )
98
+
99
+ # Decode and clean response
100
+ full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
101
+ response = full_output[len(prompt):].strip()
102
+
103
+ # Remove any "Patient:" that might have slipped through
104
+ for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
105
+ if stop_word in response:
106
+ response = response.split(stop_word)[0].strip()
107
+ break
108
+
109
+ response = response.strip()
110
+
111
+ # Free memory
112
+ del input_ids, output_ids
113
+ gc.collect()
114
+ torch.cuda.empty_cache()
115
+
116
+ return response
117
+
118
+ # =============================
119
+ # Gradio Chat Function
120
+ # =============================
121
+ def chat_function(message, history):
122
+ """Gradio chat interface function"""
123
+ if not message.strip():
124
+ return ""
125
+
126
+ try:
127
+ response = get_response(message, history)
128
+ return response
129
+ except Exception as e:
130
+ return f"Error: {str(e)}"
131
+
132
+ # =============================
133
+ # Custom CSS
134
+ # =============================
135
+ custom_css = """
136
+ #header {
137
+ text-align: center;
138
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
139
+ color: white;
140
+ padding: 20px;
141
+ border-radius: 10px;
142
+ margin-bottom: 20px;
143
+ }
144
+
145
+ #header h1 {
146
+ margin: 0;
147
+ font-size: 2.5em;
148
+ }
149
+
150
+ #header p {
151
+ margin: 10px 0 0 0;
152
+ font-size: 1.1em;
153
+ opacity: 0.9;
154
+ }
155
+
156
+ .disclaimer {
157
+ background-color: #fff3cd;
158
+ border: 1px solid #ffc107;
159
+ border-radius: 8px;
160
+ padding: 15px;
161
+ margin: 20px 0;
162
+ color: #856404;
163
+ }
164
+
165
+ .disclaimer h3 {
166
+ margin-top: 0;
167
+ color: #856404;
168
+ }
169
+
170
+ footer {
171
+ text-align: center;
172
+ margin-top: 30px;
173
+ color: #666;
174
+ font-size: 0.9em;
175
+ }
176
+ """
177
+
178
+ # =============================
179
+ # Gradio Interface
180
+ # =============================
181
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
182
+ # Header
183
+ gr.HTML("""
184
+ <div id="header">
185
+ <h1>🩺 ChatDoctor AI Assistant</h1>
186
+ <p>Your AI-powered medical conversation partner</p>
187
+ </div>
188
+ """)
189
+
190
+ # Disclaimer
191
+ gr.HTML("""
192
+ <div class="disclaimer">
193
+ <h3>⚠️ Medical Disclaimer</h3>
194
+ <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
195
+ It is NOT a substitute for professional medical advice, diagnosis, or treatment.
196
+ Always seek the advice of your physician or other qualified health provider with any questions
197
+ you may have regarding a medical condition. Never disregard professional medical advice or
198
+ delay in seeking it because of something you have read here.</p>
199
+ </div>
200
+ """)
201
+
202
+ # Chatbot Interface
203
+ chatbot = gr.Chatbot(
204
+ height=500,
205
+ placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. How can I assist you today?</p></div>",
206
+ show_label=False,
207
+ avatar_images=(None, "🤖"),
208
+ )
209
+
210
+ with gr.Row():
211
+ msg = gr.Textbox(
212
+ placeholder="Type your message here... (e.g., 'I have a headache')",
213
+ show_label=False,
214
+ scale=9,
215
+ container=False
216
+ )
217
+ submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
218
+
219
+ with gr.Row():
220
+ clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
221
+ retry_btn = gr.Button("🔄 Retry", scale=1)
222
+
223
+ # Examples
224
+ gr.Examples(
225
+ examples=[
226
+ "I have a persistent headache for 3 days. What should I do?",
227
+ "What are the symptoms of diabetes?",
228
+ "How can I improve my sleep quality?",
229
+ "I have a fever and sore throat. Should I be concerned?",
230
+ "What are some natural ways to reduce stress?",
231
+ ],
232
+ inputs=msg,
233
+ label="💡 Example Questions"
234
+ )
235
+
236
+ # Settings (collapsed by default)
237
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
238
+ temperature_slider = gr.Slider(
239
+ minimum=0.1,
240
+ maximum=1.0,
241
+ value=TEMPERATURE,
242
+ step=0.1,
243
+ label="Temperature (Creativity)",
244
+ info="Higher values make responses more creative but less focused"
245
+ )
246
+ max_tokens_slider = gr.Slider(
247
+ minimum=50,
248
+ maximum=500,
249
+ value=MAX_NEW_TOKENS,
250
+ step=50,
251
+ label="Max Response Length",
252
+ info="Maximum number of tokens in response"
253
+ )
254
+ top_k_slider = gr.Slider(
255
+ minimum=1,
256
+ maximum=100,
257
+ value=TOP_K,
258
+ step=1,
259
+ label="Top K",
260
+ info="Limits vocabulary selection"
261
+ )
262
+
263
+ # Footer
264
+ gr.HTML("""
265
+ <footer>
266
+ <p>Powered by ChatDoctor Model | Built with Gradio</p>
267
+ <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
268
+ </footer>
269
+ """)
270
+
271
+ # Event handlers
272
+ def user_message(user_msg, history):
273
+ return "", history + [[user_msg, None]]
274
+
275
+ def bot_response(history, temp, max_tok, top_k_val):
276
+ global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
277
+ TEMPERATURE = temp
278
+ MAX_NEW_TOKENS = int(max_tok)
279
+ TOP_K = int(top_k_val)
280
+
281
+ user_msg = history[-1][0]
282
+ bot_msg = chat_function(user_msg, history[:-1])
283
+ history[-1][1] = bot_msg
284
+ return history
285
+
286
+ # Connect events
287
+ msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
288
+ bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
289
+ )
290
+
291
+ submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
292
+ bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
293
+ )
294
+
295
+ clear_btn.click(lambda: None, None, chatbot, queue=False)
296
+
297
+ def retry_last():
298
+ return None
299
+
300
+ retry_btn.click(retry_last, None, chatbot, queue=False)
301
+
302
+ # =============================
303
+ # Launch Interface
304
+ # =============================
305
+ if __name__ == "__main__":
306
+ print("\n🚀 Launching ChatDoctor Gradio Interface...")
307
+ demo.queue()
308
+ demo.launch(
309
+ server_name="0.0.0.0", # Accessible from network
310
+ server_port=7860,
311
+ share=False, # Set to True to create public link
312
+ show_error=True
313
+ )
frontend_VOic.py ADDED
@@ -0,0 +1,459 @@
1
+ import os
2
+ import gc
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
6
+
7
+ # =============================
8
+ # Configuration
9
+ # =============================
10
+ MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
11
+ MAX_NEW_TOKENS = 200
12
+ TEMPERATURE = 0.5
13
+ TOP_K = 50
14
+ REPETITION_PENALTY = 1.1
15
+
16
+ # Detect device
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ print(f"Loading model from {MODEL_PATH} on {device}...")
19
+
20
+ # =============================
21
+ # Load Tokenizer and Model
22
+ # =============================
23
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
24
+ model = LlamaForCausalLM.from_pretrained(
25
+ MODEL_PATH,
26
+ device_map="auto",
27
+ torch_dtype=torch.float16,
28
+ low_cpu_mem_usage=True
29
+ )
30
+
31
+ generator = model.generate
32
+ print("✅ ChatDoctor model loaded successfully!\n")
33
+
34
+ # =============================
35
+ # Stopping Criteria
36
+ # =============================
37
+ class StopOnTokens(StoppingCriteria):
38
+ def __init__(self, stop_ids):
39
+ self.stop_ids = stop_ids
40
+
41
+ def __call__(self, input_ids, scores, **kwargs):
42
+ for stop_id_seq in self.stop_ids:
43
+ if len(stop_id_seq) == 1:
44
+ if input_ids[0][-1] == stop_id_seq[0]:
45
+ return True
46
+ else:
47
+ if len(input_ids[0]) >= len(stop_id_seq):
48
+ if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
49
+ return True
50
+ return False
51
+
52
+ # =============================
53
+ # Get Response Function
54
+ # =============================
55
+ def get_response(user_input, history_context):
56
+ """Generate response from ChatDoctor model"""
57
+ human_invitation = "Patient: "
58
+ doctor_invitation = "ChatDoctor: "
59
+
60
+ # Build conversation from history
61
+ history_text = []
62
+ for human, assistant in history_context:
63
+ if human:
64
+ history_text.append(human_invitation + human)
65
+ if assistant:
66
+ history_text.append(doctor_invitation + assistant)
67
+
68
+ # Add current user input
69
+ history_text.append(human_invitation + user_input)
70
+
71
+ # Build conversation prompt
72
+ prompt = "\n".join(history_text) + "\n" + doctor_invitation
73
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
74
+
75
+ # Define stop words and their token IDs
76
+ stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
77
+ stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
78
+ stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
79
+
80
+ # Generate model response
81
+ with torch.no_grad():
82
+ output_ids = generator(
83
+ input_ids,
84
+ max_new_tokens=MAX_NEW_TOKENS,
85
+ do_sample=True,
86
+ temperature=TEMPERATURE,
87
+ top_k=TOP_K,
88
+ repetition_penalty=REPETITION_PENALTY,
89
+ stopping_criteria=stopping_criteria,
90
+ pad_token_id=tokenizer.eos_token_id,
91
+ eos_token_id=tokenizer.eos_token_id
92
+ )
93
+
94
+ # Decode and clean response
95
+ full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
96
+ response = full_output[len(prompt):].strip()
97
+
98
+ # Remove any "Patient:" that might have slipped through
99
+ for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
100
+ if stop_word in response:
101
+ response = response.split(stop_word)[0].strip()
102
+ break
103
+
104
+ response = response.strip()
105
+
106
+ # Free memory
107
+ del input_ids, output_ids
108
+ gc.collect()
109
+ torch.cuda.empty_cache()
110
+
111
+ return response
112
+
113
+ # =============================
114
+ # Gradio Chat Function
115
+ # =============================
116
+ def chat_function(message, history):
117
+ """Gradio chat interface function"""
118
+ if not message.strip():
119
+ return ""
120
+
121
+ try:
122
+ response = get_response(message, history)
123
+ return response
124
+ except Exception as e:
125
+ return f"Error: {str(e)}"
126
+
127
+ # =============================
128
+ # Text-to-Speech Function
129
+ # =============================
130
+ def text_to_speech(text):
131
+ """Convert text response to speech"""
132
+ try:
133
+ from gtts import gTTS
134
+ import tempfile
135
+
136
+ if not text or text.startswith("Error:"):
137
+ return None
138
+
139
+ # Create speech
140
+ tts = gTTS(text=text, lang='en', slow=False)
141
+
142
+ # Save to temporary file
143
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
144
+ tts.save(temp_file.name)
145
+
146
+ return temp_file.name
147
+ except Exception as e:
148
+ print(f"TTS Error: {e}")
149
+ return None
150
+
151
+ # =============================
152
+ # Custom CSS
153
+ # =============================
154
+ custom_css = """
155
+ #header {
156
+ text-align: center;
157
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
158
+ color: white;
159
+ padding: 20px;
160
+ border-radius: 10px;
161
+ margin-bottom: 20px;
162
+ }
163
+
164
+ #header h1 {
165
+ margin: 0;
166
+ font-size: 2.5em;
167
+ }
168
+
169
+ #header p {
170
+ margin: 10px 0 0 0;
171
+ font-size: 1.1em;
172
+ opacity: 0.9;
173
+ }
174
+
175
+ .disclaimer {
176
+ background-color: #fff3cd;
177
+ border: 1px solid #ffc107;
178
+ border-radius: 8px;
179
+ padding: 15px;
180
+ margin: 20px 0;
181
+ color: #856404;
182
+ }
183
+
184
+ .disclaimer h3 {
185
+ margin-top: 0;
186
+ color: #856404;
187
+ }
188
+
189
+ .voice-section {
190
+ background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
191
+ padding: 20px;
192
+ border-radius: 10px;
193
+ margin: 20px 0;
194
+ }
195
+
196
+ footer {
197
+ text-align: center;
198
+ margin-top: 30px;
199
+ color: #666;
200
+ font-size: 0.9em;
201
+ }
202
+ """
203
+
204
+ # =============================
205
+ # Gradio Interface
206
+ # =============================
207
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
208
+ # Header
209
+ gr.HTML("""
210
+ <div id="header">
211
+ <h1>🩺 ChatDoctor AI Assistant</h1>
212
+ <p>Your AI-powered medical conversation partner with Voice Support</p>
213
+ </div>
214
+ """)
215
+
216
+ # Disclaimer
217
+ gr.HTML("""
218
+ <div class="disclaimer">
219
+ <h3>⚠️ Medical Disclaimer</h3>
220
+ <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
221
+ It is NOT a substitute for professional medical advice, diagnosis, or treatment.
222
+ Always seek the advice of your physician or other qualified health provider with any questions
223
+ you may have regarding a medical condition. Never disregard professional medical advice or
224
+ delay in seeking it because of something you have read here.</p>
225
+ </div>
226
+ """)
227
+
228
+ with gr.Row():
229
+ with gr.Column(scale=7):
+             # Chatbot Interface
+             chatbot = gr.Chatbot(
+                 height=500,
+                 placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. Type or speak your question!</p></div>",
+                 show_label=False,
+                 avatar_images=(None, "🤖"),
+             )
+
+             with gr.Row():
+                 msg = gr.Textbox(
+                     placeholder="Type your message here... (e.g., 'I have a headache')",
+                     show_label=False,
+                     scale=9,
+                     container=False
+                 )
+                 submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
+
+             with gr.Row():
+                 clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
+                 retry_btn = gr.Button("🔄 Retry", scale=1)
+
+         with gr.Column(scale=3):
+             # Voice Input Section
+             gr.HTML("<div class='voice-section'><h3 style='color: white; text-align: center; margin-top: 0;'>🎤 Voice Features</h3></div>")
+
+             audio_input = gr.Audio(
+                 sources=["microphone"],
+                 type="filepath",
+                 label="🎙️ Speak Your Question",
+                 show_download_button=False
+             )
+
+             transcribed_text = gr.Textbox(
+                 label="📝 Transcribed Text",
+                 placeholder="Your speech will appear here...",
+                 interactive=False,
+                 lines=3
+             )
+
+             send_voice_btn = gr.Button("Send Voice Message 🔊", variant="primary")
+
+             gr.Markdown("---")
+
+             # Voice Output
+             tts_enabled = gr.Checkbox(
+                 label="🔊 Enable Text-to-Speech for responses",
+                 value=True,
+                 info="Hear the doctor's response"
+             )
+
+             audio_output = gr.Audio(
+                 label="🔈 AI Response Audio",
+                 autoplay=False,
+                 visible=True
+             )
+
+     # Examples
+     gr.Examples(
+         examples=[
+             "I have a persistent headache for 3 days. What should I do?",
+             "What are the symptoms of diabetes?",
+             "How can I improve my sleep quality?",
+             "I have a fever and sore throat. Should I be concerned?",
+             "What are some natural ways to reduce stress?",
+         ],
+         inputs=msg,
+         label="💡 Example Questions"
+     )
+
+     # Settings (collapsed by default)
+     with gr.Accordion("⚙️ Advanced Settings", open=False):
+         temperature_slider = gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=TEMPERATURE,
+             step=0.1,
+             label="Temperature (Creativity)",
+             info="Higher values make responses more creative but less focused"
+         )
+         max_tokens_slider = gr.Slider(
+             minimum=50,
+             maximum=500,
+             value=MAX_NEW_TOKENS,
+             step=50,
+             label="Max Response Length",
+             info="Maximum number of tokens in response"
+         )
+         top_k_slider = gr.Slider(
+             minimum=1,
+             maximum=100,
+             value=TOP_K,
+             step=1,
+             label="Top K",
+             info="Limits vocabulary selection"
+         )
+
+     # Footer
+     gr.HTML("""
+     <footer>
+         <p>Powered by ChatDoctor Model | Built with Gradio | Voice-Enabled 🎤</p>
+         <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
+     </footer>
+     """)
+
+     # =============================
+     # Event Handlers
+     # =============================
+
+     def user_message(user_msg, history):
+         # Clear the textbox, append the user's turn, and reset the audio player
+         return "", history + [[user_msg, None]], None
+
+     def bot_response(history, temp, max_tok, top_k_val, tts_enabled_val):
+         global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
+         TEMPERATURE = temp
+         MAX_NEW_TOKENS = int(max_tok)
+         TOP_K = int(top_k_val)
+
+         user_msg = history[-1][0]
+         bot_msg = chat_function(user_msg, history[:-1])
+         history[-1][1] = bot_msg
+
+         # Generate audio if TTS is enabled
+         tts_audio = None
+         if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
+             tts_audio = text_to_speech(bot_msg)
+
+         return history, tts_audio
+
+     def transcribe_audio(audio_file):
+         """Transcribe audio to text using Whisper."""
+         if audio_file is None:
+             return ""
+
+         try:
+             import whisper
+             # Note: this reloads the Whisper model on every call; caching it in a
+             # module-level variable would make repeated transcriptions faster.
+             model = whisper.load_model("base")
+             result = model.transcribe(audio_file)
+             return result["text"]
+         except ImportError:
+             return "Error: Please install whisper: pip install openai-whisper"
+         except Exception as e:
+             return f"Transcription error: {str(e)}"
+
+     def process_voice_input(audio_file, history, temp, max_tok, top_k_val, tts_enabled_val):
+         """Process voice input: transcribe -> send -> get response."""
+         if audio_file is None:
+             return history, "", None, None
+
+         # Transcribe
+         transcribed = transcribe_audio(audio_file)
+
+         if transcribed.startswith("Error:"):
+             return history, transcribed, None, None
+
+         # Add to chat
+         history = history + [[transcribed, None]]
+
+         # Get response
+         global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
+         TEMPERATURE = temp
+         MAX_NEW_TOKENS = int(max_tok)
+         TOP_K = int(top_k_val)
+
+         bot_msg = chat_function(transcribed, history[:-1])
+         history[-1][1] = bot_msg
+
+         # Generate audio if TTS is enabled (kept distinct from the input file)
+         tts_audio = None
+         if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
+             tts_audio = text_to_speech(bot_msg)
+
+         return history, transcribed, None, tts_audio
+
+     # Text input events
+     msg.submit(
+         user_message,
+         [msg, chatbot],
+         [msg, chatbot, audio_output],
+         queue=False
+     ).then(
+         bot_response,
+         [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
+         [chatbot, audio_output]
+     )
+
+     submit_btn.click(
+         user_message,
+         [msg, chatbot],
+         [msg, chatbot, audio_output],
+         queue=False
+     ).then(
+         bot_response,
+         [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
+         [chatbot, audio_output]
+     )
+
+     # Voice input events
+     audio_input.change(
+         transcribe_audio,
+         [audio_input],
+         [transcribed_text]
+     )
+
+     send_voice_btn.click(
+         process_voice_input,
+         [audio_input, chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
+         [chatbot, transcribed_text, audio_input, audio_output]
+     )
+
+     # Clear and retry (Retry currently also clears the chat; regenerating
+     # the last answer is left as a TODO)
+     clear_btn.click(lambda: (None, None, None), None, [chatbot, audio_output, transcribed_text], queue=False)
+
+     retry_btn.click(lambda: None, None, chatbot, queue=False)
+
+ # =============================
+ # Launch Interface
+ # =============================
+ if __name__ == "__main__":
+     print("\n🚀 Launching ChatDoctor Gradio Interface with Voice Support...")
+     print("\n📦 Required packages:")
+     print("    pip install gradio gTTS openai-whisper")
+     print("\nNote: Whisper will download models on first use (~100MB for base model)\n")
+
+     demo.queue()
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False,
+         show_error=True
+     )
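Note: the handlers above call text_to_speech, which this hunk assumes is defined earlier in frontend_VOic.py. For reference, a minimal gTTS-based sketch of such a helper (illustrative only, not the file's actual implementation):

import tempfile
from gtts import gTTS

def text_to_speech(text):
    """Render text to an MP3 file and return its path, or None on failure."""
    try:
        tts = gTTS(text=text, lang="en")
        out = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
        tts.save(out.name)
        return out.name
    except Exception:
        return None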
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ numpy
+ rouge_score
+ fire
+ openai
+ git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
+ torch
+ sentencepiece
+ tokenizers==0.13.3
+ wandb
+ accelerate
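Note: the pinned transformers commit predates the final LLaMA class names; within this repo, train_lora.py imports LLaMATokenizer (the fork's spelling) while test.py and teak.py import LlamaTokenizer (the spelling in released transformers). If you move between the fork and a release, a small compatibility shim (a sketch) avoids import errors:

try:
    from transformers import LlamaTokenizer  # released transformers versions
except ImportError:
    from transformers import LLaMATokenizer as LlamaTokenizer  # pinned fork spelling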
teak.py ADDED
@@ -0,0 +1,103 @@
+ import os
+ import torch
+ from transformers import LlamaTokenizer, LlamaForCausalLM
+
+ model = None
+ tokenizer = None
+ generator = None
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ def load_model(model_name, eight_bit=0, device_map="auto"):
+     global model, tokenizer, generator
+
+     print("Loading " + model_name + "...")
+
+     if device_map == "zero":
+         device_map = "balanced_low_0"
+
+     # Pick the device and report available GPUs
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+     print("gpu_count", gpu_count)
+
+     tokenizer = LlamaTokenizer.from_pretrained(model_name)
+     model = LlamaForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+         low_cpu_mem_usage=True,
+         load_in_8bit=False,
+         cache_dir="cache"
+     ).to(device)
+
+     generator = model.generate
+
+ load_model(r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained")
+
+ First_chat = "ChatDoctor: I am ChatDoctor, what medical questions do you have?"
+ print(First_chat)
+ history = [First_chat]
+
+ def go():
+     invitation = "ChatDoctor: "
+     human_invitation = "Patient: "
+
+     # Read the patient's message
+     msg = input(human_invitation)
+     print("")
+
+     history.append(human_invitation + msg)
+
+     fulltext = "If you are a doctor, please answer the medical questions based on the patient's description. \n\n" + "\n\n".join(history) + "\n\n" + invitation
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     gen_in = tokenizer(fulltext, return_tensors="pt").input_ids.to(device)
+     with torch.no_grad():
+         generated_ids = generator(
+             gen_in,
+             max_new_tokens=200,
+             use_cache=True,
+             pad_token_id=tokenizer.eos_token_id,
+             num_return_sequences=1,
+             do_sample=True,
+             repetition_penalty=1.1,  # 1.0 disables the penalty
+             temperature=0.5,  # default: 1.0
+             top_k=50,  # default: 50
+             top_p=1.0,  # default: 1.0
+             early_stopping=True,
+         )
+         # batch_decode returns one string per sequence; we generated exactly one
+         generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     # Keep only the newly generated text and cut it off at the next patient turn
+     response = generated_text[len(fulltext):]
+     response = response.split(human_invitation)[0]
+     response = response.strip()
+
+     print(invitation + response)
+     print("")
+
+     history.append(invitation + response)
+
+ while True:
+     go()
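Note: teak.py appends every turn to history without bound, so long sessions will eventually overflow the model's context window. A minimal sketch of a history cap (trim_history is a hypothetical helper, not part of the file) that drops the oldest turns when the joined prompt grows too long:

def trim_history(history, tokenizer, max_prompt_tokens=1500):
    # Drop the oldest turns until the joined history fits the token budget.
    while len(history) > 1 and len(tokenizer("\n\n".join(history)).input_ids) > max_prompt_tokens:
        history.pop(0)
    return history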
test.py ADDED
@@ -0,0 +1,328 @@
+ import os
+ import gc
+ import torch
+ import gradio as gr
+ from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
+
+ # =============================
+ # Configuration
+ # =============================
+ MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
+ MAX_NEW_TOKENS = 200
+ TEMPERATURE = 0.5
+ TOP_K = 50
+ REPETITION_PENALTY = 1.1
+
+ # Detect device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Loading model from {MODEL_PATH} on {device}...")
+
+ # =============================
+ # Load Tokenizer and Model
+ # =============================
+ tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
+ model = LlamaForCausalLM.from_pretrained(
+     MODEL_PATH,
+     device_map="auto",
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True
+ )
+
+ generator = model.generate
+ print("✅ ChatDoctor model loaded successfully!\n")
+
+ # =============================
+ # System Prompt
+ # =============================
+ SYSTEM_PROMPT = """
+ You are ChatDoctor — a friendly, professional, and caring virtual doctor.
+ Whenever a patient describes their symptoms:
+ 1. Always include a recommendation for diet, fluids, and proteins appropriate for recovery.
+    - Fruits: citrus (orange, lemon), kiwi, papaya
+    - Vegetables: leafy greens, carrots, spinach
+    - Fluids: warm soups, herbal teas, coconut water
+    - Proteins: boiled eggs, lentils, fish, chicken soup
+    - Extras: garlic, ginger, turmeric
+ 2. Recommend safe over-the-counter medicines if applicable (e.g., paracetamol for fever).
+ 3. Ask follow-up questions if needed to understand the patient's condition better.
+ 4. Always encourage the patient to see a real doctor if symptoms persist, worsen, or are serious.
+ 5. Provide clear, warm, and empathetic advice.
+ 6. Make your response structured and easy to understand.
+ 7. Even if the patient only mentions a symptom, always include diet, fluids, protein, and care suggestions automatically.
+ """
+
+ # =============================
+ # Stopping Criteria
+ # =============================
+ class StopOnTokens(StoppingCriteria):
+     def __init__(self, stop_ids):
+         self.stop_ids = stop_ids
+
+     def __call__(self, input_ids, scores, **kwargs):
+         for stop_id_seq in self.stop_ids:
+             if len(stop_id_seq) == 1:
+                 if input_ids[0][-1] == stop_id_seq[0]:
+                     return True
+             else:
+                 if len(input_ids[0]) >= len(stop_id_seq):
+                     if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
+                         return True
+         return False
+
+ # =============================
+ # Chat History (Global)
+ # =============================
+ conversation_history = []  # currently unused; Gradio passes the history explicitly
+
+ # =============================
+ # Get Response Function
+ # =============================
+ def get_response(user_input, history_context):
+     """Generate a response from the ChatDoctor model."""
+     # Build conversation from history
+     history_text = []
+     for human, assistant in history_context:
+         if human:
+             history_text.append("Patient: " + human)
+         if assistant:
+             history_text.append("ChatDoctor: " + assistant)
+
+     # Add current user input
+     history_text.append("Patient: " + user_input)
+
+     # Build full prompt including system instructions
+     prompt = SYSTEM_PROMPT + "\n\nConversation so far:\n" + "\n".join(history_text) + "\nChatDoctor:"
+     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+     # Define stop words and their token IDs
+     stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
+     stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
+     stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
+
+     # Generate model response
+     with torch.no_grad():
+         output_ids = generator(
+             input_ids,
+             max_new_tokens=MAX_NEW_TOKENS,
+             do_sample=True,
+             temperature=TEMPERATURE,
+             top_k=TOP_K,
+             repetition_penalty=REPETITION_PENALTY,
+             stopping_criteria=stopping_criteria,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+     # Decode and clean response
+     full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     response = full_output[len(prompt):].strip()
+
+     # Remove any "Patient:" that might have slipped through
+     for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
+         if stop_word in response:
+             response = response.split(stop_word)[0].strip()
+             break
+
+     # Free memory (only touch the CUDA cache when a GPU is present)
+     del input_ids, output_ids
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return response
+
+ # =============================
+ # Gradio Chat Function
+ # =============================
+ def chat_function(message, history):
+     """Gradio chat interface function."""
+     if not message.strip():
+         return ""
+
+     try:
+         response = get_response(message, history)
+         return response
+     except Exception as e:
+         return f"Error: {str(e)}"
+
+ # =============================
+ # Custom CSS
+ # =============================
+ custom_css = """
+ #header {
+     text-align: center;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     padding: 20px;
+     border-radius: 10px;
+     margin-bottom: 20px;
+ }
+
+ #header h1 {
+     margin: 0;
+     font-size: 2.5em;
+ }
+
+ #header p {
+     margin: 10px 0 0 0;
+     font-size: 1.1em;
+     opacity: 0.9;
+ }
+
+ .disclaimer {
+     background-color: #fff3cd;
+     border: 1px solid #ffc107;
+     border-radius: 8px;
+     padding: 15px;
+     margin: 20px 0;
+     color: #856404;
+ }
+
+ .disclaimer h3 {
+     margin-top: 0;
+     color: #856404;
+ }
+
+ footer {
+     text-align: center;
+     margin-top: 30px;
+     color: #666;
+     font-size: 0.9em;
+ }
+ """
+
+ # =============================
+ # Gradio Interface
+ # =============================
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+     # Header
+     gr.HTML("""
+     <div id="header">
+         <h1>🩺 ChatDoctor AI Assistant</h1>
+         <p>Your AI-powered medical conversation partner</p>
+     </div>
+     """)
+
+     # Disclaimer
+     gr.HTML("""
+     <div class="disclaimer">
+         <h3>⚠️ Medical Disclaimer</h3>
+         <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
+         It is NOT a substitute for professional medical advice, diagnosis, or treatment.
+         Always seek the advice of your physician or other qualified health provider with any questions
+         you may have regarding a medical condition. Never disregard professional medical advice or
+         delay in seeking it because of something you have read here.</p>
+     </div>
+     """)
+
+     # Chatbot Interface
+     chatbot = gr.Chatbot(
+         height=500,
+         placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. How can I assist you today?</p></div>",
+         show_label=False,
+         avatar_images=(None, "🤖"),
+     )
+
+     with gr.Row():
+         msg = gr.Textbox(
+             placeholder="Type your message here... (e.g., 'I have a headache')",
+             show_label=False,
+             scale=9,
+             container=False
+         )
+         submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
+
+     with gr.Row():
+         clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
+         retry_btn = gr.Button("🔄 Retry", scale=1)
+
+     # Examples
+     gr.Examples(
+         examples=[
+             "I have a persistent headache for 3 days. What should I do?",
+             "What are the symptoms of diabetes?",
+             "How can I improve my sleep quality?",
+             "I have a fever and sore throat. Should I be concerned?",
+             "What are some natural ways to reduce stress?",
+         ],
+         inputs=msg,
+         label="💡 Example Questions"
+     )
+
+     # Settings (collapsed by default)
+     with gr.Accordion("⚙️ Advanced Settings", open=False):
+         temperature_slider = gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=TEMPERATURE,
+             step=0.1,
+             label="Temperature (Creativity)",
+             info="Higher values make responses more creative but less focused"
+         )
+         max_tokens_slider = gr.Slider(
+             minimum=50,
+             maximum=500,
+             value=MAX_NEW_TOKENS,
+             step=50,
+             label="Max Response Length",
+             info="Maximum number of tokens in response"
+         )
+         top_k_slider = gr.Slider(
+             minimum=1,
+             maximum=100,
+             value=TOP_K,
+             step=1,
+             label="Top K",
+             info="Limits vocabulary selection"
+         )
+
+     # Footer
+     gr.HTML(f"""
+     <footer>
+         <p>Powered by ChatDoctor Model | Built with Gradio</p>
+         <p>Device: {device.upper()} | Model: LLaMA-based Medical AI</p>
+     </footer>
+     """)
+
+     # Event handlers
+     def user_message(user_msg, history):
+         return "", history + [[user_msg, None]]
+
+     def bot_response(history, temp, max_tok, top_k_val):
+         global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
+         TEMPERATURE = temp
+         MAX_NEW_TOKENS = int(max_tok)
+         TOP_K = int(top_k_val)
+
+         user_msg = history[-1][0]
+         bot_msg = chat_function(user_msg, history[:-1])
+         history[-1][1] = bot_msg
+         return history
+
+     # Connect events
+     msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
+     )
+
+     submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
+     )
+
+     clear_btn.click(lambda: None, None, chatbot, queue=False)
+
+     # Retry currently just clears the chat; regenerating the last answer is left as a TODO
+     def retry_last():
+         return None
+
+     retry_btn.click(retry_last, None, chatbot, queue=False)
+
+ # =============================
+ # Launch Interface
+ # =============================
+ if __name__ == "__main__":
+     print("\n🚀 Launching ChatDoctor Gradio Interface...")
+     demo.queue()
+     demo.launch(
+         server_name="0.0.0.0",  # Accessible from network
+         server_port=7860,
+         share=False,  # Set to True to create public link
+         show_error=True
+     )
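Note: chat_function is a plain Python function, so the model can be exercised without the UI. A quick smoke test (a sketch, run in the same interpreter after the model has loaded; the symptom text is illustrative):

# history is a list of [user, assistant] pairs; an empty list starts a fresh conversation
reply = chat_function("I have had a mild fever since yesterday.", [])
print(reply)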
train.py ADDED
@@ -0,0 +1,231 @@
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Optional, Dict, Sequence
+
+ import torch
+ import transformers
+ from torch.utils.data import Dataset
+ from transformers import Trainer
+
+ import utils
+
+ IGNORE_INDEX = -100
+ DEFAULT_PAD_TOKEN = "[PAD]"
+ DEFAULT_EOS_TOKEN = "</s>"
+ # BOS and UNK must differ from EOS for LLaMA tokenizers; the original upload
+ # set all three to "</s>".
+ DEFAULT_BOS_TOKEN = "<s>"
+ DEFAULT_UNK_TOKEN = "<unk>"
+ PROMPT_DICT = {
+     "prompt_input": (
+         "Below is an instruction that describes a task, paired with an input that provides further context. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+     ),
+     "prompt_no_input": (
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Response:"
+     ),
+ }
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+
+
+ @dataclass
+ class DataArguments:
+     data_path: str = field(default=None, metadata={"help": "Path to the training data."})
+
+
+ @dataclass
+ class TrainingArguments(transformers.TrainingArguments):
+     cache_dir: Optional[str] = field(default=None)
+     optim: str = field(default="adamw_torch")
+     model_max_length: int = field(
+         default=512,
+         metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+     )
+
+
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
+     """Collect the state dict and dump it to disk."""
+     state_dict = trainer.model.state_dict()
+     if trainer.args.should_save:
+         cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
+         del state_dict
+         trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
+
+
+ def smart_tokenizer_and_embedding_resize(
+     special_tokens_dict: Dict,
+     tokenizer: transformers.PreTrainedTokenizer,
+     model: transformers.PreTrainedModel,
+ ):
+     """Resize tokenizer and embedding.
+
+     Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+     """
+     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+     model.resize_token_embeddings(len(tokenizer))
+
+     if num_new_tokens > 0:
+         input_embeddings = model.get_input_embeddings().weight.data
+         output_embeddings = model.get_output_embeddings().weight.data
+
+         input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+         output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+         input_embeddings[-num_new_tokens:] = input_embeddings_avg
+         output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+     """Tokenize a list of strings."""
+     tokenized_list = [
+         tokenizer(
+             text,
+             return_tensors="pt",
+             padding="longest",
+             max_length=tokenizer.model_max_length,
+             truncation=True,
+         )
+         for text in strings
+     ]
+     input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+     input_ids_lens = labels_lens = [
+         tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+     ]
+     return dict(
+         input_ids=input_ids,
+         labels=labels,
+         input_ids_lens=input_ids_lens,
+         labels_lens=labels_lens,
+     )
+
+
+ def preprocess(
+     sources: Sequence[str],
+     targets: Sequence[str],
+     tokenizer: transformers.PreTrainedTokenizer,
+ ) -> Dict:
+     """Preprocess the data by tokenizing."""
+     examples = [s + t for s, t in zip(sources, targets)]
+     examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
+     input_ids = examples_tokenized["input_ids"]
+     labels = copy.deepcopy(input_ids)
+     for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+         # Mask out the prompt tokens so the loss is computed on the response only
+         label[:source_len] = IGNORE_INDEX
+     return dict(input_ids=input_ids, labels=labels)
+
+
+ class SupervisedDataset(Dataset):
+     """Dataset for supervised fine-tuning."""
+
+     def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+         super(SupervisedDataset, self).__init__()
+         logging.warning("Loading data...")
+         list_data_dict = utils.jload(data_path)
+
+         logging.warning("Formatting inputs...")
+         prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+         sources = [
+             prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
+             for example in list_data_dict
+         ]
+         targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+         logging.warning("Tokenizing inputs... This may take some time...")
+         data_dict = preprocess(sources, targets, tokenizer)
+
+         self.input_ids = data_dict["input_ids"]
+         self.labels = data_dict["labels"]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+         return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+ @dataclass
+ class DataCollatorForSupervisedDataset(object):
+     """Collate examples for supervised fine-tuning."""
+
+     tokenizer: transformers.PreTrainedTokenizer
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+         input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+         input_ids = torch.nn.utils.rnn.pad_sequence(
+             input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+         )
+         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+         return dict(
+             input_ids=input_ids,
+             labels=labels,
+             attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+         )
+
+
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+     """Make dataset and collator for supervised fine-tuning."""
+     train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
+     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+     return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
+
+
+ def train():
+     parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     model = transformers.AutoModelForCausalLM.from_pretrained(
+         model_args.model_name_or_path,
+         cache_dir=training_args.cache_dir,
+     )
+
+     tokenizer = transformers.AutoTokenizer.from_pretrained(
+         model_args.model_name_or_path,
+         cache_dir=training_args.cache_dir,
+         model_max_length=training_args.model_max_length,
+         padding_side="right",
+         use_fast=False,
+     )
+     if tokenizer.pad_token is None:
+         smart_tokenizer_and_embedding_resize(
+             special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+             tokenizer=tokenizer,
+             model=model,
+         )
+     if "llama" in model_args.model_name_or_path:
+         tokenizer.add_special_tokens(
+             {
+                 "eos_token": DEFAULT_EOS_TOKEN,
+                 "bos_token": DEFAULT_BOS_TOKEN,
+                 "unk_token": DEFAULT_UNK_TOKEN,
+             }
+         )
+
+     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+     trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
+     trainer.train()
+     trainer.save_state()
+     safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+
+
+ if __name__ == "__main__":
+     train()
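Note: for reference, this is how PROMPT_DICT renders one training record (the record below is illustrative; alpaca_data.json and chatdoctor5k.json follow this instruction/input/output schema):

example = {
    "instruction": "If you are a doctor, please answer the medical questions based on the patient's description.",
    "input": "I have a sore throat and a mild fever.",
    "output": "...",
}
prompt = PROMPT_DICT["prompt_input"].format_map(example)
# SupervisedDataset then appends example["output"] + tokenizer.eos_token as the target,
# and preprocess() masks the prompt tokens with IGNORE_INDEX so loss covers only the response.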
train_lora.py ADDED
@@ -0,0 +1,321 @@
+ import os
+ import sys
+ from typing import List, Union
+
+ import fire
+ import torch
+ import transformers
+ from datasets import load_dataset
+
+ """
+ Unused imports:
+ import torch.nn as nn
+ import bitsandbytes as bnb
+ """
+ # Note: BottleneckConfig and the bottleneck adapter path are not part of
+ # mainline peft; these imports assume a fork such as the one shipped with
+ # LLM-Adapters.
+ from peft import (  # noqa: E402
+     LoraConfig,
+     BottleneckConfig,
+     get_peft_model,
+     get_peft_model_state_dict,
+     prepare_model_for_int8_training,
+     set_peft_model_state_dict,
+ )
+ from transformers import AutoModelForCausalLM, AutoTokenizer, LLaMATokenizer  # noqa: F402
+
+
+ def train(
+     # model/data params
+     base_model: str = "",  # the only required argument
+     data_path: str = "yahma/alpaca-cleaned",
+     output_dir: str = "./lora-alpaca",
+     adapter_name: str = "lora",
+     # training hyperparams
+     batch_size: int = 128,
+     micro_batch_size: int = 4,
+     num_epochs: int = 3,
+     learning_rate: float = 3e-4,
+     cutoff_len: int = 256,
+     val_set_size: int = 2000,
+     use_gradient_checkpointing: bool = False,
+     eval_step: int = 200,
+     save_step: int = 200,
+     # lora hyperparams
+     lora_r: int = 8,
+     lora_alpha: int = 16,
+     lora_dropout: float = 0.05,
+     lora_target_modules: List[str] = None,
+     # bottleneck adapter hyperparams
+     bottleneck_size: int = 256,
+     non_linearity: str = "tanh",
+     adapter_dropout: float = 0.0,
+     use_parallel_adapter: bool = False,
+     use_adapterp: bool = False,
+     target_modules: List[str] = None,
+     scaling: Union[float, str] = 1.0,
+     # llm hyperparams
+     train_on_inputs: bool = True,  # if False, masks out inputs in loss
+     group_by_length: bool = False,  # faster, but produces an odd training loss curve
+     # wandb params
+     wandb_project: str = "",
+     wandb_run_name: str = "",
+     wandb_watch: str = "",  # options: false | gradients | all
+     wandb_log_model: str = "",  # options: false | true
+     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
+ ):
+     print(
+         f"Finetuning model with params:\n"
+         f"base_model: {base_model}\n"
+         f"data_path: {data_path}\n"
+         f"output_dir: {output_dir}\n"
+         f"batch_size: {batch_size}\n"
+         f"micro_batch_size: {micro_batch_size}\n"
+         f"num_epochs: {num_epochs}\n"
+         f"learning_rate: {learning_rate}\n"
+         f"cutoff_len: {cutoff_len}\n"
+         f"val_set_size: {val_set_size}\n"
+         f"use_gradient_checkpointing: {use_gradient_checkpointing}\n"
+         f"lora_r: {lora_r}\n"
+         f"lora_alpha: {lora_alpha}\n"
+         f"lora_dropout: {lora_dropout}\n"
+         f"lora_target_modules: {lora_target_modules}\n"
+         f"bottleneck_size: {bottleneck_size}\n"
+         f"non_linearity: {non_linearity}\n"
+         f"adapter_dropout: {adapter_dropout}\n"
+         f"use_parallel_adapter: {use_parallel_adapter}\n"
+         f"use_adapterp: {use_adapterp}\n"
+         f"train_on_inputs: {train_on_inputs}\n"
+         f"scaling: {scaling}\n"
+         f"adapter_name: {adapter_name}\n"
+         f"target_modules: {target_modules}\n"
+         f"group_by_length: {group_by_length}\n"
+         f"wandb_project: {wandb_project}\n"
+         f"wandb_run_name: {wandb_run_name}\n"
+         f"wandb_watch: {wandb_watch}\n"
+         f"wandb_log_model: {wandb_log_model}\n"
+         f"resume_from_checkpoint: {resume_from_checkpoint}\n"
+     )
+     assert (
+         base_model
+     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/LLaMA-7b-hf'"
+     gradient_accumulation_steps = batch_size // micro_batch_size
+
+     device_map = "auto"
+     world_size = int(os.environ.get("WORLD_SIZE", 1))
+     ddp = world_size != 1
+     if ddp:
+         device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+         gradient_accumulation_steps = gradient_accumulation_steps // world_size
+
+     # Check if parameter passed or if set within environ
+     use_wandb = len(wandb_project) > 0 or (
+         "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+     )
+     # Only overwrite environ if wandb param passed
+     if len(wandb_project) > 0:
+         os.environ["WANDB_PROJECT"] = wandb_project
+     if len(wandb_watch) > 0:
+         os.environ["WANDB_WATCH"] = wandb_watch
+     if len(wandb_log_model) > 0:
+         os.environ["WANDB_LOG_MODEL"] = wandb_log_model
+
+     model = AutoModelForCausalLM.from_pretrained(
+         base_model,
+         load_in_8bit=True,
+         torch_dtype=torch.float16,
+         device_map=device_map,
+     )
+
+     if model.config.model_type == "LLaMA":
+         # Due to the name of transformers' LLaMATokenizer, we have to do this
+         # (released transformers versions report model_type as lowercase "llama")
+         tokenizer = LLaMATokenizer.from_pretrained(base_model)
+     else:
+         tokenizer = AutoTokenizer.from_pretrained(base_model)
+
+     tokenizer.pad_token_id = (
+         0  # unk. we want this to be different from the eos token
+     )
+     tokenizer.padding_side = "left"  # Allow batched inference
+
+     def tokenize(prompt, add_eos_token=True):
+         # there's probably a way to do this with the tokenizer settings
+         # but again, gotta move fast
+         result = tokenizer(
+             prompt,
+             truncation=True,
+             max_length=cutoff_len,
+             padding=False,
+             return_tensors=None,
+         )
+         if (
+             result["input_ids"][-1] != tokenizer.eos_token_id
+             and len(result["input_ids"]) < cutoff_len
+             and add_eos_token
+         ):
+             result["input_ids"].append(tokenizer.eos_token_id)
+             result["attention_mask"].append(1)
+
+         result["labels"] = result["input_ids"].copy()
+
+         return result
+
+     def generate_and_tokenize_prompt(data_point):
+         full_prompt = generate_prompt(data_point)
+         tokenized_full_prompt = tokenize(full_prompt)
+         if not train_on_inputs:
+             user_prompt = generate_prompt({**data_point, "output": ""})
+             tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
+             user_prompt_len = len(tokenized_user_prompt["input_ids"])
+
+             tokenized_full_prompt["labels"] = [
+                 -100
+             ] * user_prompt_len + tokenized_full_prompt["labels"][
+                 user_prompt_len:
+             ]  # could be sped up, probably
+         return tokenized_full_prompt
+
+     model = prepare_model_for_int8_training(model, use_gradient_checkpointing=use_gradient_checkpointing)
+     if adapter_name == "lora":
+         config = LoraConfig(
+             r=lora_r,
+             lora_alpha=lora_alpha,
+             target_modules=lora_target_modules,
+             lora_dropout=lora_dropout,
+             bias="none",
+             task_type="CAUSAL_LM",
+         )
+     elif adapter_name == "bottleneck":
+         config = BottleneckConfig(
+             bottleneck_size=bottleneck_size,
+             non_linearity=non_linearity,
+             adapter_dropout=adapter_dropout,
+             use_parallel_adapter=use_parallel_adapter,
+             use_adapterp=use_adapterp,
+             target_modules=target_modules,
+             scaling=scaling,
+             bias="none",
+             task_type="CAUSAL_LM",
+         )
+     model = get_peft_model(model, config)
+
+     if data_path.endswith(".json"):  # todo: support jsonl
+         data = load_dataset("json", data_files=data_path)
+     else:
+         data = load_dataset(data_path)
+
+     if resume_from_checkpoint:
+         # Check the available weights and load them
+         checkpoint_name = os.path.join(
+             resume_from_checkpoint, "pytorch_model.bin"
+         )  # Full checkpoint
+         if not os.path.exists(checkpoint_name):
+             checkpoint_name = os.path.join(
+                 resume_from_checkpoint, "adapter_model.bin"
+             )  # only LoRA model - LoRA config above has to fit
+             resume_from_checkpoint = (
+                 False  # So the trainer won't try loading its state
+             )
+         # The two files above have a different name depending on how they were saved, but are actually the same.
+         if os.path.exists(checkpoint_name):
+             print(f"Restarting from {checkpoint_name}")
+             adapters_weights = torch.load(checkpoint_name)
+             model = set_peft_model_state_dict(model, adapters_weights)
+         else:
+             print(f"Checkpoint {checkpoint_name} not found")
+
+     model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
+
+     if val_set_size > 0:
+         train_val = data["train"].train_test_split(
+             test_size=val_set_size, shuffle=True, seed=42
+         )
+         train_data = (
+             train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+         )
+         val_data = (
+             train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+         )
+     else:
+         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
+         val_data = None
+
+     if not ddp and torch.cuda.device_count() > 1:
+         # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
+         model.is_parallelizable = True
+         model.model_parallel = True
+
+     trainer = transformers.Trainer(
+         model=model,
+         train_dataset=train_data,
+         eval_dataset=val_data,
+         args=transformers.TrainingArguments(
+             per_device_train_batch_size=micro_batch_size,
+             gradient_accumulation_steps=gradient_accumulation_steps,
+             warmup_steps=100,
+             num_train_epochs=num_epochs,
+             learning_rate=learning_rate,
+             fp16=True,
+             logging_steps=10,
+             optim="adamw_torch",
+             evaluation_strategy="steps" if val_set_size > 0 else "no",
+             save_strategy="steps",
+             eval_steps=eval_step if val_set_size > 0 else None,
+             save_steps=save_step,
+             output_dir=output_dir,
+             save_total_limit=3,
+             load_best_model_at_end=True if val_set_size > 0 else False,
+             ddp_find_unused_parameters=False if ddp else None,
+             group_by_length=group_by_length,
+             report_to="wandb" if use_wandb else None,
+             run_name=wandb_run_name if use_wandb else None,
+         ),
+         data_collator=transformers.DataCollatorForSeq2Seq(
+             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+         ),
+     )
+     model.config.use_cache = False
+
+     old_state_dict = model.state_dict
+     model.state_dict = (
+         lambda self, *_, **__: get_peft_model_state_dict(
+             self, old_state_dict()
+         )
+     ).__get__(model, type(model))
+
+     if torch.__version__ >= "2" and sys.platform != "win32":
+         model = torch.compile(model)
+
+     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+
+     model.save_pretrained(output_dir)
+
+     print(
+         "\n If there's a warning about missing keys above, please disregard :)"
+     )
+
+
+ def generate_prompt(data_point):
+     # sorry about the formatting disaster gotta move fast
+     if data_point["input"]:
+         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {data_point["instruction"]}
+
+ ### Input:
+ {data_point["input"]}
+
+ ### Response:
+ {data_point["output"]}"""  # noqa: E501
+     else:
+         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+ ### Instruction:
+ {data_point["instruction"]}
+
+ ### Response:
+ {data_point["output"]}"""  # noqa: E501
+
+
+ if __name__ == "__main__":
+     fire.Fire(train)
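Note: an example invocation (a sketch; the data file and target modules are illustrative choices, not defaults of the script, and the base model name is taken from the assert message above):

from train_lora import train

train(
    base_model="decapoda-research/LLaMA-7b-hf",
    data_path="chatdoctor5k.json",
    output_dir="./lora-chatdoctor",
    adapter_name="lora",
    lora_target_modules=["q_proj", "v_proj"],  # a common choice for LLaMA; the script sets no default
)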
utils.py ADDED
@@ -0,0 +1,174 @@
+ import dataclasses
+ import logging
+ import math
+ import os
+ import io
+ import sys
+ import time
+ import json
+ import copy
+ from typing import Optional, Sequence, Union
+
+ import openai
+ import tqdm
+ from openai import openai_object
+
+ StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]
+
+ # Read the API key from the environment rather than hard-coding it here.
+ openai.api_key = os.getenv("OPENAI_API_KEY", "")
+ openai_org = os.getenv("OPENAI_ORG")
+ if openai_org is not None:
+     openai.organization = openai_org
+     logging.warning(f"Switching to organization: {openai_org} for OAI API key.")
+
+
+ @dataclasses.dataclass
+ class OpenAIDecodingArguments(object):
+     max_tokens: int = 1800
+     temperature: float = 0.2
+     top_p: float = 1.0
+     n: int = 1
+     stream: bool = False
+     stop: Optional[Sequence[str]] = None
+     presence_penalty: float = 0.0
+     frequency_penalty: float = 0.0
+     suffix: Optional[str] = None
+     logprobs: Optional[int] = None
+     echo: bool = False
+
+
+ def openai_completion(
+     prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
+     decoding_args: OpenAIDecodingArguments,
+     model_name="text-davinci-003",
+     sleep_time=2,
+     batch_size=1,
+     max_instances=sys.maxsize,
+     max_batches=sys.maxsize,
+     return_text=False,
+     **decoding_kwargs,
+ ) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
+     """Decode with OpenAI API.
+
+     Args:
+         prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
+             as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
+             it can also be a dictionary (or list thereof) as explained here:
+             https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
+         decoding_args: Decoding arguments.
+         model_name: Model name. Can be either in the format of "org/model" or just "model".
+         sleep_time: Time to sleep once the rate-limit is hit.
+         batch_size: Number of prompts to send in a single request. Only for non chat model.
+         max_instances: Maximum number of prompts to decode.
+         max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
+         return_text: If True, return text instead of full completion object (which contains things like logprob).
+         decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.
+
+     Returns:
+         A completion or a list of completions.
+         Depending on return_text and decoding_args.n, the completion type can be one of
+             - a string (if return_text is True)
+             - an openai_object.OpenAIObject object (if return_text is False)
+             - a list of objects of the above types (if decoding_args.n > 1)
+     """
+     is_single_prompt = isinstance(prompts, (str, dict))
+     if is_single_prompt:
+         prompts = [prompts]
+
+     if max_batches < sys.maxsize:
+         logging.warning(
+             "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
+             "Setting `max_instances` to `max_batches * batch_size` for now."
+         )
+         max_instances = max_batches * batch_size
+
+     prompts = prompts[:max_instances]
+     num_prompts = len(prompts)
+     prompt_batches = [
+         prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
+         for batch_id in range(int(math.ceil(num_prompts / batch_size)))
+     ]
+
+     completions = []
+     for batch_id, prompt_batch in tqdm.tqdm(
+         enumerate(prompt_batches),
+         desc="prompt_batches",
+         total=len(prompt_batches),
+     ):
+         batch_decoding_args = copy.deepcopy(decoding_args)  # cloning the decoding_args
+
+         while True:
+             try:
+                 shared_kwargs = dict(
+                     model=model_name,
+                     **batch_decoding_args.__dict__,
+                     **decoding_kwargs,
+                 )
+                 completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
+                 choices = completion_batch.choices
+
+                 for choice in choices:
+                     choice["total_tokens"] = completion_batch.usage.total_tokens
+                 completions.extend(choices)
+                 break
+             except openai.error.OpenAIError as e:
+                 logging.warning(f"OpenAIError: {e}.")
+                 if "Please reduce your prompt" in str(e):
+                     batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
+                     logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
+                 else:
+                     logging.warning("Hit request rate limit; retrying...")
+                     time.sleep(sleep_time)  # Annoying rate limit on requests.
+
+     if return_text:
+         completions = [completion.text for completion in completions]
+     if decoding_args.n > 1:
+         # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
+         completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
+     if is_single_prompt:
+         # Return non-tuple if only 1 input and 1 generation.
+         (completions,) = completions
+     return completions
+
+
+ def _make_w_io_base(f, mode: str):
+     if not isinstance(f, io.IOBase):
+         f_dirname = os.path.dirname(f)
+         if f_dirname != "":
+             os.makedirs(f_dirname, exist_ok=True)
+         f = open(f, mode=mode)
+     return f
+
+
+ def _make_r_io_base(f, mode: str):
+     if not isinstance(f, io.IOBase):
+         f = open(f, mode=mode)
+     return f
+
+
+ def jdump(obj, f, mode="w", indent=4, default=str):
+     """Dump a str or dictionary to a file in json format.
+
+     Args:
+         obj: An object to be written.
+         f: A string path to the location on disk.
+         mode: Mode for opening the file.
+         indent: Indent for storing json dictionaries.
+         default: A function to handle non-serializable entries; defaults to `str`.
+     """
+     f = _make_w_io_base(f, mode)
+     if isinstance(obj, (dict, list)):
+         json.dump(obj, f, indent=indent, default=default)
+     elif isinstance(obj, str):
+         f.write(obj)
+     else:
+         raise ValueError(f"Unexpected type: {type(obj)}")
+     f.close()
+
+
+ def jload(f, mode="r"):
+     """Load a .json file into a dictionary (or whatever object the file contains)."""
+     f = _make_r_io_base(f, mode)
+     jdict = json.load(f)
+     f.close()
+     return jdict
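Note: a round-trip example for the JSON helpers at the bottom of utils.py (paths and record contents are illustrative):

import utils

record = {"instruction": "Summarize the symptoms.", "input": "fever, cough", "output": "..."}
utils.jdump([record], "data/sample.json")  # jdump creates the data/ directory if needed
loaded = utils.jload("data/sample.json")
assert loaded[0]["instruction"] == record["instruction"]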