Upload 15 files

- .gitattributes +1 -0
- Coreectcodewithoutfronted.py +141 -0
- LICENSE +201 -0
- README.md +182 -12
- alpaca_data.json +3 -0
- chat.py +100 -0
- chatdoctor5k.json +0 -0
- format_dataset.csv +0 -0
- frontend.py +313 -0
- frontend_VOic.py +459 -0
- requirements.txt +10 -0
- teak.py +103 -0
- test.py +328 -0
- train.py +231 -0
- train_lora.py +321 -0
- utils.py +174 -0
.gitattributes
CHANGED

```diff
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+alpaca_data.json filter=lfs diff=lfs merge=lfs -text
```
Coreectcodewithoutfronted.py
ADDED
@@ -0,0 +1,141 @@

```python
import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")

# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            else:
                if len(input_ids[0]) >= len(stop_id_seq):
                    if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                        return True
        return False

# =============================
# Chat History
# =============================
history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]

# =============================
# Get Response Function
# =============================
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Add user input to history
    history.append(human_invitation + user_input)

    # Build conversation prompt
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Define stop words and their token IDs
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode and clean response
    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = full_output[len(prompt):].strip()

    # Remove any "Patient:" that might have slipped through
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    # Remove any leading/trailing punctuation artifacts
    response = response.strip()

    history.append(doctor_invitation + response)

    # Free memory
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response

# =============================
# Chat Loop
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")

    print("ChatDoctor: Hi there! How can I help you today?\n")

    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("ChatDoctor: Take care! Goodbye")
                break

            if not user_input:
                continue

            response = get_response(user_input)
            print("ChatDoctor:", response, "\n")

        except KeyboardInterrupt:
            print("\nChatDoctor: Take care! Goodbye")
            break
        except Exception as e:
            print(f"Error: {e}")
            print("Please try again.\n")
```
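The script encodes several surface forms of the stop word because LLaMA's SentencePiece tokenizer produces different token-id sequences for "Patient:" depending on what precedes it (start of text, a newline, a space). A purely illustrative check, reusing the `tokenizer` loaded above:

```python
# Illustrative only (not part of the repo): the same stop word maps to
# different token-id sequences depending on surrounding context, which is
# why StopOnTokens is given several variants to match against.
for word in ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]:
    print(repr(word), tokenizer.encode(word, add_special_tokens=False))
```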
LICENSE
ADDED
@@ -0,0 +1,201 @@

```
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```
README.md
CHANGED
@@ -1,12 +1,182 @@

<p align="center" width="80%">
<img src="fig/logo.png" style="width: 40%; min-width: 300px; display: block; margin: auto;">
</p>

# [ChatDoctor: A Medical Chat Model Fine-Tuned on a Large Language Model Meta-AI (LLaMA) Using Medical Domain Knowledge](https://www.cureus.com/articles/152858-chatdoctor-a-medical-chat-model-fine-tuned-on-a-large-language-model-meta-ai-llama-using-medical-domain-knowledge#!/)
Yunxiang Li<sup>1</sup>, Zihan Li<sup>2</sup>, Kai Zhang<sup>3</sup>, Ruilong Dan<sup>4</sup>, Steve Jiang<sup>1</sup>, You Zhang<sup>1</sup>
<h5>1 UT Southwestern Medical Center, USA</h5>
<h5>2 University of Illinois at Urbana-Champaign, USA</h5>
<h5>3 Ohio State University, USA</h5>
<h5>4 Hangzhou Dianzi University, China</h5>

[](https://github.com/HUANGLIZI/ChatDoctor/blob/main/LICENSE)
[](https://www.python.org/downloads/release/python-390/)
[](https://www.yunxiangli.top/ChatDoctor/)

## Resources List
Autonomous ChatDoctor with Disease Database: [Demo](https://huggingface.co/spaces/kenton-li/chatdoctor_csv).

100k real conversations between patients and doctors from HealthCareMagic.com: [HealthCareMagic-100k](https://drive.google.com/file/d/1lyfqIwlLSClhgrCutWuEe_IACNq6XNUt/view?usp=sharing).

Real conversations between patients and doctors from icliniq.com: [icliniq-10k](https://drive.google.com/file/d/1ZKbqgYqWc7DJHs3N9TQYQVPdDQmZaClA/view?usp=sharing).

Checkpoints of ChatDoctor: [link](https://drive.google.com/drive/folders/11-qPzz9ZdHD6pc47wBSOUSU61MaDPyRh?usp=sharing).

Stanford Alpaca data for basic conversational capabilities: [Alpaca link](https://github.com/Kent0n-Li/ChatDoctor/blob/main/alpaca_data.json).


<p align="center" width="100%">
<img src="fig/overview.PNG" style="width: 70%; min-width: 300px; display: block; margin: auto;">
</p>

<p align="center" width="100%">
<img src="fig/wiki.PNG" style="width: 70%; min-width: 300px; display: block; margin: auto;">
</p>


## Setup:
In a conda environment with PyTorch available, run:
```
pip install -r requirements.txt
```

## Interactive Demo Page:
Demo page: https://huggingface.co/spaces/kenton-li/chatdoctor_csv
Note that our model has not yet achieved 100% accurate output, so please do not apply it to real clinical scenarios.

To try the online demo, please register for Hugging Face and fill out this form: [link](https://forms.office.com/Pages/ResponsePage.aspx?id=lYZBnaxxMUy1ssGWyOw8ij06Cb8qnDJKvu2bVpV1-ANURUU0TllBWVVHUjQ1MDJUNldGTTZWV1c5UC4u).

## Data and model:
### 1. ChatDoctor Dataset:
You can download the following training datasets:

100k real conversations between patients and doctors from HealthCareMagic.com: [HealthCareMagic-100k](https://drive.google.com/file/d/1lyfqIwlLSClhgrCutWuEe_IACNq6XNUt/view?usp=sharing).

10k real conversations between patients and doctors from icliniq.com: [icliniq-10k](https://drive.google.com/file/d/1ZKbqgYqWc7DJHs3N9TQYQVPdDQmZaClA/view?usp=sharing).

5k conversations between patients and physicians generated by ChatGPT: [GenMedGPT-5k](https://drive.google.com/file/d/1nDTKZ3wZbZWTkFMBkxlamrzbNz0frugg/view?usp=sharing) and the [disease database](https://github.com/Kent0n-Li/ChatDoctor/blob/main/format_dataset.csv).

Our model was first fine-tuned on Stanford Alpaca data to acquire basic conversational capabilities: [Alpaca link](https://github.com/Kent0n-Li/ChatDoctor/blob/main/alpaca_data.json)

### 2. Model Weights:

Place the model weights file in the ./pretrained folder.
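For orientation, a checkpoint converted with the standard Hugging Face LLaMA conversion typically contains files like the sketch below; the exact shard names vary with the conversion, so treat this layout as an assumption rather than a requirement of this repo.

```
pretrained/
├── config.json
├── generation_config.json
├── pytorch_model-00001-of-00002.bin   # shard names vary by conversion
├── pytorch_model-00002-of-00002.bin
├── pytorch_model.bin.index.json
├── special_tokens_map.json
├── tokenizer_config.json
└── tokenizer.model
```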
## How to fine-tune

```bash
torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
   --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
   --data_path ./HealthCareMagic-100k.json \
   --bf16 True \
   --output_dir pretrained \
   --num_train_epochs 1 \
   --per_device_train_batch_size 4 \
   --per_device_eval_batch_size 4 \
   --gradient_accumulation_steps 8 \
   --evaluation_strategy "no" \
   --save_strategy "steps" \
   --save_steps 2000 \
   --save_total_limit 1 \
   --learning_rate 2e-6 \
   --weight_decay 0. \
   --warmup_ratio 0.03 \
   --lr_scheduler_type "cosine" \
   --logging_steps 1 \
   --fsdp "full_shard auto_wrap" \
   --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
   --tf32 True
```
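With these flags the effective global batch size is nproc_per_node × per_device_train_batch_size × gradient_accumulation_steps = 4 × 4 × 8 = 128 sequences per optimizer step.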
Fine-tuning with LoRA:
```bash
WORLD_SIZE=6 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --nproc_per_node=6 --master_port=4567 train_lora.py \
  --base_model './weights-alpaca/' \
  --data_path 'HealthCareMagic-100k.json' \
  --output_dir './lora_models/' \
  --batch_size 32 \
  --micro_batch_size 4 \
  --num_epochs 1 \
  --learning_rate 3e-5 \
  --cutoff_len 256 \
  --val_set_size 120 \
  --adapter_name lora
```
## How to run inference
You can build a ChatDoctor model on your own machine and communicate with it:
```bash
python chat.py
```

```bash
cd Autonomous_ChatDoctor_csv
python chat_csv.py
```

```bash
cd Autonomous_ChatDoctor_Wikipedia
python chat_wiki.py
```

## Examples:
We fine-tuned our model on HealthCareMagic-100k and tested it on iCliniq-10k.

```
Patient question from icliniq: Hi doctor,I am getting one-sided headache. I do not get them often. They are behind my right eye and temple area. Now this time I am also fighting with heavy cold and cough. It came on late morning today and it hurts more when I cough or shake my head or bend over. Prior to this, I had the same headache about four months ago. At that time it came on midafternoon and lasted till next morning. Prior to that, I had the same headache exactly on the right side and lasted for a day. I am on CCB and Protonix. I am scared of brain tumor. Please help

Real doctor answer from icliniq: Hi. Based on what you have told, this is more likely to be due to sinusitis. Sinuses are small air filled cavities in your skull bone. They can get infected and when they are infected, you will get headaches behind eyes. Such headaches will be more common during cold and cough. Brain tumor headache will be on both sides, more in forehead and you will also have vomiting in that case. Based on what you have told, this is more likely due to simple sinus infection instead of brain tumor. CT brain and CT paranasal sinuses.

Answer from ChatGPT: Hello! I'm sorry to hear you're experiencing these headaches. Based on the location of the pain, it sounds like it may be a tension headache or a migraine. These types of headaches can sometimes be triggered by colds or other illnesses. However, it's always best to get checked out by your doctor to rule out any other underlying conditions. In terms of a brain tumor, that is a very rare cause of headaches. Most people who experience headaches do not have a brain tumor. That being said, if you have any concerning symptoms such as vision changes, weakness, or numbness, it is important to seek medical attention right away. In the meantime, you can try taking over-the-counter pain relievers such as ibuprofen or acetaminophen to help with the headache. It's also important to get plenty of rest and stay hydrated. If the headaches persist or become more frequent, please make an appointment with your doctor. Take care!

Answer from ChatDoctor: Hi, Thank you for posting your query. First, I would like to reassure you that there is no need to worry about brain tumor. Also, we should keep migraine apart as well. The most likely cause of your headaches could be sinusitis. This can be confirmed by doing an X-ray of face. Treatment would consist of using decongestant tablets/ nasal sprays. Best wishes, Chat Doctor.
```
ChatDoctor accurately mentions sinusitis, which is also present in the real conversation from icliniq.


## Overview
ChatDoctor is a next-generation AI doctor model based on the [LLaMA](https://github.com/facebookresearch/llama) model. The goal of this project is to provide patients with an intelligent and reliable healthcare companion that can answer their medical queries and give them personalized medical advice.

ChatDoctor is an advanced language model designed specifically for medical applications. It has been trained on a large corpus of medical literature and has a deep understanding of medical terminology, procedures, and diagnoses. This foundation enables ChatDoctor to analyze patients' symptoms and medical history, provide accurate diagnoses, and suggest appropriate treatment options.

The ChatDoctor model is designed to simulate a conversation between a doctor and a patient, using natural language processing (NLP) and machine learning techniques. Patients can interact with the model through a chat interface, asking questions about their health, symptoms, or medical conditions. The model then analyzes the input and provides a response tailored to the patient's unique situation.

One of the key features of the ChatDoctor model is its ability to learn and adapt over time. As more patients interact with the model, it continues to refine its responses and improve its accuracy, so patients can expect increasingly personalized and accurate medical advice over time.


## Patient-physician Conversation Dataset
The first step in fine-tuning is to collect a dataset of patient-physician conversations. In such conversations, the patient's descriptions of disease symptoms are often colloquial and cursory. Manually constructing a synthesized patient-physician conversation dataset tends to yield insufficient diversity and over-specialized descriptions that are removed from real scenarios, so collecting real patient-physician conversations is a better solution. We therefore collected about 100k real doctor-patient conversations from the online medical consultation website HealthCareMagic (www.healthcaremagic.com). We filtered these data both manually and automatically, removed the identity information of the doctors and patients, and used language tools to correct grammatical errors; we named this dataset HealthCareMagic-100k. In addition, we collected approximately 10k patient-physician conversations from the online medical consultation website iCliniq to evaluate the performance of our model.


## Autonomous ChatDoctor based on Knowledge Brain
Equipped with an external knowledge brain, i.e., Wikipedia or our constructed database encompassing over 700 diseases, ChatDoctor can retrieve the corresponding knowledge and reliable sources to answer patients' inquiries more accurately. After constructing the external knowledge brain, we need to let ChatDoctor retrieve the knowledge it needs autonomously, which can generally be achieved in a large language model by constructing appropriate prompts. To automate this process, we design keyword-mining prompts for ChatDoctor to extract key terms for relevant knowledge seeking. The top-ranked relevant passages are then retrieved from the knowledge brain with a term-matching retrieval system. As for the disease database, since the model cannot read all the data at once, we first let the model read the data in batches and select for itself the entries that might help answer the patient's question. Finally, all the entries selected by the model are given to the model for a final answer. This approach better ensures that patients receive well-informed and precise responses backed by credible references; a minimal sketch of this loop is shown below.
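The following is an illustrative sketch of the batched read-and-select loop described above, not the repository's actual implementation; `llm` (a text-in/text-out wrapper around the chat model) and `disease_db` (the rows of format_dataset.csv) are assumed helpers.

```python
# Sketch only. Assumptions (not from this repo): `llm(prompt)` returns the
# model's text reply, and `disease_db` is an iterable of database rows.
def answer_with_knowledge(question, disease_db, llm, batch_size=20):
    # 1. Keyword mining: ask the model itself which terms to search for.
    keywords = llm(f"Extract the key medical terms from: {question}").split(",")

    # 2. Term-matching retrieval: score each entry by keyword overlap.
    def score(entry):
        text = " ".join(str(field) for field in entry).lower()
        return sum(k.strip().lower() in text for k in keywords)

    ranked = sorted(disease_db, key=score, reverse=True)

    # 3. Read the database in batches; the model selects useful entries.
    selected = []
    for i in range(0, len(ranked), batch_size):
        batch = ranked[i:i + batch_size]
        reply = llm(
            "Which of these entries help answer the question?\n"
            f"Question: {question}\nEntries: {batch}\n"
            "Reply with the helpful entries, or 'none'."
        )
        if reply.strip().lower() != "none":
            selected.append(reply)

    # 4. Final answer conditioned on everything the model selected.
    return llm(f"Knowledge: {selected}\nPatient: {question}\nChatDoctor:")
```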
## Limitations
We emphasize that ChatDoctor is for academic research only; any commercial or clinical use is prohibited. Three factors drive this decision: first, ChatDoctor is based on LLaMA, which has a non-commercial license, so we necessarily inherit this restriction; second, our model is not licensed for healthcare-related purposes; third, we have not designed sufficient security measures, and the current model still does not guarantee the full correctness of medical diagnoses.


## Reference

ChatDoctor: A Medical Chat Model Fine-tuned on LLaMA Model using Medical Domain Knowledge

```
@article{li2023chatdoctor,
  title={ChatDoctor: A Medical Chat Model Fine-Tuned on a Large Language Model Meta-AI (LLaMA) Using Medical Domain Knowledge},
  author={Li, Yunxiang and Li, Zihan and Zhang, Kai and Dan, Ruilong and Jiang, Steve and Zhang, You},
  journal={Cureus},
  volume={15},
  number={6},
  year={2023},
  publisher={Cureus}
}
```
alpaca_data.json
ADDED
@@ -0,0 +1,3 @@

```
version https://git-lfs.github.com/spec/v1
oid sha256:1a0a9920c72e27b32013e5c4ad7727d9eede8eaab75c3f4b7eb62eda019561d7
size 23034003
```
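This is a Git LFS pointer file rather than the dataset itself: the repository stores only the hash and byte size here, and the actual ~23 MB JSON is fetched with `git lfs pull` after cloning.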
chat.py
ADDED
@@ -0,0 +1,100 @@

```python
# chat.py
import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)

model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",          # automatically dispatch weights to GPU
    torch_dtype=torch.float16,  # half precision for faster inference
    low_cpu_mem_usage=True      # optimize CPU memory
)

# DO NOT call model.to(device) when using device_map="auto"
generator = model.generate
print("✅ Model loaded successfully!\n")

# =============================
# Chat History
# =============================
history = ["ChatDoctor: I am ChatDoctor, what medical questions do you have?"]

# =============================
# Response Function
# =============================
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Append user input
    history.append(human_invitation + user_input)

    # Build prompt
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Generate response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY
        )

    # Decode response
    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = full_output[len(prompt):].strip()

    # Clean if the model repeats the patient prompt
    if response.startswith("Patient:"):
        response = response[len("Patient:"):].strip()

    # Append model response to history
    history.append(doctor_invitation + response)

    # Free memory
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response

# =============================
# CLI Chat
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except KeyboardInterrupt:
            print("\nExiting ChatDoctor. Goodbye!")
            break
```
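Unlike frontend.py and Coreectcodewithoutfronted.py, chat.py only trims a leading "Patient:" after generation; it never stops generation early. A sketch of how the same stopping criteria could be wired in here, assuming the StopOnTokens class is copied over from frontend.py:

```python
# Sketch only: reuses the StopOnTokens class defined in frontend.py so the
# model halts as soon as it starts producing the patient's next turn,
# instead of only trimming it after the fact.
from transformers import StoppingCriteriaList

stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
stop_ids = [tokenizer.encode(w, add_special_tokens=False) for w in stop_words]

output_ids = generator(
    input_ids,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=True,
    temperature=TEMPERATURE,
    top_k=TOP_K,
    repetition_penalty=REPETITION_PENALTY,
    stopping_criteria=StoppingCriteriaList([StopOnTokens(stop_ids)]),
)
```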
chatdoctor5k.json
ADDED

The diff for this file is too large to render. See the raw diff.
format_dataset.csv
ADDED

The diff for this file is too large to render. See the raw diff.
frontend.py
ADDED
@@ -0,0 +1,313 @@

```python
import os
import gc
import torch
import gradio as gr
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")

# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            else:
                if len(input_ids[0]) >= len(stop_id_seq):
                    if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                        return True
        return False

# =============================
# Chat History (Global)
# =============================
conversation_history = []

# =============================
# Get Response Function
# =============================
def get_response(user_input, history_context):
    """Generate response from ChatDoctor model"""
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Build conversation from history
    history_text = []
    for human, assistant in history_context:
        if human:
            history_text.append(human_invitation + human)
        if assistant:
            history_text.append(doctor_invitation + assistant)

    # Add current user input
    history_text.append(human_invitation + user_input)

    # Build conversation prompt
    prompt = "\n".join(history_text) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Define stop words and their token IDs
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode and clean response
    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = full_output[len(prompt):].strip()

    # Remove any "Patient:" that might have slipped through
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    response = response.strip()

    # Free memory
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response

# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history):
    """Gradio chat interface function"""
    if not message.strip():
        return ""

    try:
        response = get_response(message, history)
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# =============================
# Custom CSS
# =============================
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}

#header h1 {
    margin: 0;
    font-size: 2.5em;
}

#header p {
    margin: 10px 0 0 0;
    font-size: 1.1em;
    opacity: 0.9;
}

.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
    color: #856404;
}

.disclaimer h3 {
    margin-top: 0;
    color: #856404;
}

footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""

# =============================
# Gradio Interface
# =============================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML("""
    <div id="header">
        <h1>🩺 ChatDoctor AI Assistant</h1>
        <p>Your AI-powered medical conversation partner</p>
    </div>
    """)

    # Disclaimer
    gr.HTML("""
    <div class="disclaimer">
        <h3>⚠️ Medical Disclaimer</h3>
        <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
        It is NOT a substitute for professional medical advice, diagnosis, or treatment.
        Always seek the advice of your physician or other qualified health provider with any questions
        you may have regarding a medical condition. Never disregard professional medical advice or
        delay in seeking it because of something you have read here.</p>
    </div>
    """)

    # Chatbot Interface
    chatbot = gr.Chatbot(
        height=500,
        placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. How can I assist you today?</p></div>",
        show_label=False,
        avatar_images=(None, "🤖"),
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here... (e.g., 'I have a headache')",
            show_label=False,
            scale=9,
            container=False
        )
        submit_btn = gr.Button("Send 📤", scale=1, variant="primary")

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
        retry_btn = gr.Button("🔄 Retry", scale=1)

    # Examples
    gr.Examples(
        examples=[
            "I have a persistent headache for 3 days. What should I do?",
            "What are the symptoms of diabetes?",
            "How can I improve my sleep quality?",
            "I have a fever and sore throat. Should I be concerned?",
            "What are some natural ways to reduce stress?",
        ],
        inputs=msg,
        label="💡 Example Questions"
    )

    # Settings (collapsed by default)
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher values make responses more creative but less focused"
        )
        max_tokens_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=50,
            label="Max Response Length",
            info="Maximum number of tokens in response"
        )
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=TOP_K,
            step=1,
            label="Top K",
            info="Limits vocabulary selection"
        )

    # Footer
    gr.HTML("""
    <footer>
        <p>Powered by ChatDoctor Model | Built with Gradio</p>
        <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
    </footer>
    """)

    # Event handlers
    def user_message(user_msg, history):
        return "", history + [[user_msg, None]]

    def bot_response(history, temp, max_tok, top_k_val):
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)

        user_msg = history[-1][0]
        bot_msg = chat_function(user_msg, history[:-1])
        history[-1][1] = bot_msg
        return history

    # Connect events
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
    )

    submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
    )

    clear_btn.click(lambda: None, None, chatbot, queue=False)

    def retry_last():
        return None

    retry_btn.click(retry_last, None, chatbot, queue=False)

# =============================
# Launch Interface
# =============================
if __name__ == "__main__":
    print("\n🚀 Launching ChatDoctor Gradio Interface...")
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",  # Accessible from network
        server_port=7860,
        share=False,            # Set to True to create public link
        show_error=True
    )
```
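Running `python frontend.py` starts a Gradio server bound to 0.0.0.0 on port 7860, so the UI is reachable at http://localhost:7860 (or the host's LAN address). Note that as written the Retry button is effectively a stub: `retry_last()` returns None, which clears the chat rather than regenerating the last answer, while the sliders do take effect because `bot_response` writes them back into the global generation settings before each reply.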
frontend_VOic.py
ADDED
@@ -0,0 +1,459 @@
import os
import gc
import torch
import gradio as gr
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")

# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            else:
                if len(input_ids[0]) >= len(stop_id_seq):
                    if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                        return True
        return False

# =============================
# Get Response Function
# =============================
def get_response(user_input, history_context):
    """Generate response from ChatDoctor model"""
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Build conversation from history
    history_text = []
    for human, assistant in history_context:
        if human:
            history_text.append(human_invitation + human)
        if assistant:
            history_text.append(doctor_invitation + assistant)

    # Add current user input
    history_text.append(human_invitation + user_input)

    # Build conversation prompt
    prompt = "\n".join(history_text) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Define stop words and their token IDs
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode and clean response
    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = full_output[len(prompt):].strip()

    # Remove any "Patient:" that might have slipped through
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    response = response.strip()

    # Free memory
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response

# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history):
    """Gradio chat interface function"""
    if not message.strip():
        return ""

    try:
        response = get_response(message, history)
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# =============================
# Text-to-Speech Function
# =============================
def text_to_speech(text):
    """Convert text response to speech"""
    try:
        from gtts import gTTS
        import tempfile

        if not text or text.startswith("Error:"):
            return None

        # Create speech
        tts = gTTS(text=text, lang='en', slow=False)

        # Save to temporary file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
        tts.save(temp_file.name)

        return temp_file.name
    except Exception as e:
        print(f"TTS Error: {e}")
        return None

# =============================
# Custom CSS
# =============================
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}

#header h1 {
    margin: 0;
    font-size: 2.5em;
}

#header p {
    margin: 10px 0 0 0;
    font-size: 1.1em;
    opacity: 0.9;
}

.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
    color: #856404;
}

.disclaimer h3 {
    margin-top: 0;
    color: #856404;
}

.voice-section {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    padding: 20px;
    border-radius: 10px;
    margin: 20px 0;
}

footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""

# =============================
# Gradio Interface
# =============================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML("""
    <div id="header">
        <h1>🩺 ChatDoctor AI Assistant</h1>
        <p>Your AI-powered medical conversation partner with Voice Support</p>
    </div>
    """)

    # Disclaimer
    gr.HTML("""
    <div class="disclaimer">
        <h3>⚠️ Medical Disclaimer</h3>
        <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
        It is NOT a substitute for professional medical advice, diagnosis, or treatment.
        Always seek the advice of your physician or other qualified health provider with any questions
        you may have regarding a medical condition. Never disregard professional medical advice or
        delay in seeking it because of something you have read here.</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=7):
            # Chatbot Interface
            chatbot = gr.Chatbot(
                height=500,
                placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. Type or speak your question!</p></div>",
                show_label=False,
                avatar_images=(None, "🤖"),
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here... (e.g., 'I have a headache')",
                    show_label=False,
                    scale=9,
                    container=False
                )
                submit_btn = gr.Button("Send 📤", scale=1, variant="primary")

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
                retry_btn = gr.Button("🔄 Retry", scale=1)

        with gr.Column(scale=3):
            # Voice Input Section
            gr.HTML("<div class='voice-section'><h3 style='color: white; text-align: center; margin-top: 0;'>🎤 Voice Features</h3></div>")

            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="🎙️ Speak Your Question",
                show_download_button=False
            )

            transcribed_text = gr.Textbox(
                label="📝 Transcribed Text",
                placeholder="Your speech will appear here...",
                interactive=False,
                lines=3
            )

            send_voice_btn = gr.Button("Send Voice Message 🔊", variant="primary")

            gr.Markdown("---")

            # Voice Output
            tts_enabled = gr.Checkbox(
                label="🔊 Enable Text-to-Speech for responses",
                value=True,
                info="Hear the doctor's response"
            )

            audio_output = gr.Audio(
                label="🔈 AI Response Audio",
                autoplay=False,
                visible=True
            )

    # Examples
    gr.Examples(
        examples=[
            "I have a persistent headache for 3 days. What should I do?",
            "What are the symptoms of diabetes?",
            "How can I improve my sleep quality?",
            "I have a fever and sore throat. Should I be concerned?",
            "What are some natural ways to reduce stress?",
        ],
        inputs=msg,
        label="💡 Example Questions"
    )

    # Settings (collapsed by default)
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher values make responses more creative but less focused"
        )
        max_tokens_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=50,
            label="Max Response Length",
            info="Maximum number of tokens in response"
        )
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=TOP_K,
            step=1,
            label="Top K",
            info="Limits vocabulary selection"
        )

    # Footer
    gr.HTML("""
    <footer>
        <p>Powered by ChatDoctor Model | Built with Gradio | Voice-Enabled 🎤</p>
        <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
    </footer>
    """)

    # =============================
    # Event Handlers
    # =============================

    def user_message(user_msg, history):
        return "", history + [[user_msg, None]], None

    def bot_response(history, temp, max_tok, top_k_val, tts_enabled_val):
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)

        user_msg = history[-1][0]
        bot_msg = chat_function(user_msg, history[:-1])
        history[-1][1] = bot_msg

        # Generate audio if TTS is enabled
        audio_file = None
        if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
            audio_file = text_to_speech(bot_msg)

        return history, audio_file

    def transcribe_audio(audio_file):
        """Transcribe audio to text using Whisper"""
        if audio_file is None:
            return ""

        try:
            import whisper
            model = whisper.load_model("base")
            result = model.transcribe(audio_file)
            return result["text"]
        except ImportError:
            return "Error: Please install whisper: pip install openai-whisper"
        except Exception as e:
            return f"Transcription error: {str(e)}"

    def process_voice_input(audio_file, history, temp, max_tok, top_k_val, tts_enabled_val):
        """Process voice input: transcribe -> send -> get response"""
        if audio_file is None:
            return history, "", None, None

        # Transcribe
        transcribed = transcribe_audio(audio_file)

        if transcribed.startswith("Error:"):
            return history, transcribed, None, None

        # Add to chat
        history = history + [[transcribed, None]]

        # Get response
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)

        bot_msg = chat_function(transcribed, history[:-1])
        history[-1][1] = bot_msg

        # Generate audio if TTS is enabled
        audio_file = None
        if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
            audio_file = text_to_speech(bot_msg)

        return history, transcribed, None, audio_file

    # Text input events
    msg.submit(
        user_message,
        [msg, chatbot],
        [msg, chatbot, audio_output],
        queue=False
    ).then(
        bot_response,
        [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
        [chatbot, audio_output]
    )

    submit_btn.click(
        user_message,
        [msg, chatbot],
        [msg, chatbot, audio_output],
        queue=False
    ).then(
        bot_response,
        [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
        [chatbot, audio_output]
    )

    # Voice input events
    audio_input.change(
        transcribe_audio,
        [audio_input],
        [transcribed_text]
    )

    send_voice_btn.click(
        process_voice_input,
        [audio_input, chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
        [chatbot, transcribed_text, audio_input, audio_output]
    )

    # Clear and retry
    clear_btn.click(lambda: (None, None, None), None, [chatbot, audio_output, transcribed_text], queue=False)

    retry_btn.click(lambda: None, None, chatbot, queue=False)

# =============================
# Launch Interface
# =============================
if __name__ == "__main__":
    print("\n🚀 Launching ChatDoctor Gradio Interface with Voice Support...")
    print("\n📦 Required packages:")
    print("   pip install gradio gTTS openai-whisper")
    print("\nNote: Whisper will download models on first use (~100MB for base model)\n")

    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
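The voice features above reduce to two library calls: gTTS for speech synthesis and Whisper for transcription. A minimal sketch of that round-trip in isolation, assuming gTTS and openai-whisper are installed and using a hypothetical out.mp3 path (no model checkpoint required):

import whisper
from gtts import gTTS

# Text -> speech, same call pattern as text_to_speech() above
gTTS(text="Drink plenty of fluids and rest.", lang='en', slow=False).save("out.mp3")

# Speech -> text, same calls as transcribe_audio() above
asr = whisper.load_model("base")
print(asr.transcribe("out.mp3")["text"])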
requirements.txt
ADDED
@@ -0,0 +1,10 @@
numpy
rouge_score
fire
openai
git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
torch
sentencepiece
tokenizers==0.13.3
wandb
accelerate
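Note that this list covers the training and inference stack only: frontend.py, frontend_VOic.py, and test.py also import gradio, and the voice UI further needs gTTS and openai-whisper (frontend_VOic.py prints that exact pip line at startup). A plausible full setup is therefore:

pip install -r requirements.txt
pip install gradio gTTS openai-whisper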
teak.py
ADDED
@@ -0,0 +1,103 @@
import os, json, itertools, bisect, gc
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import transformers
import torch
from accelerate import Accelerator
import accelerate
import time

model = None
tokenizer = None
generator = None
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def load_model(model_name, eight_bit=0, device_map="auto"):
    global model, tokenizer, generator

    print("Loading " + model_name + "...")

    if device_map == "zero":
        device_map = "balanced_low_0"

    # config
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    print('gpu_count', gpu_count)

    tokenizer = LlamaTokenizer.from_pretrained(model_name)
    model = LlamaForCausalLM.from_pretrained(
        model_name,
        #device_map=device_map,
        #device_map="auto",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        #max_memory = {0: "14GB", 1: "14GB", 2: "14GB", 3: "14GB",4: "14GB",5: "14GB",6: "14GB",7: "14GB"},
        #load_in_8bit=eight_bit,
        #from_tf=True,
        low_cpu_mem_usage=True,
        load_in_8bit=False,
        cache_dir="cache"
    ).to(device)

    generator = model.generate

load_model(r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained")

First_chat = "ChatDoctor: I am ChatDoctor, what medical questions do you have?"
print(First_chat)
history = []
history.append(First_chat)

def go():
    invitation = "ChatDoctor: "
    human_invitation = "Patient: "

    # input
    msg = input(human_invitation)
    print("")

    history.append(human_invitation + msg)

    fulltext = "If you are a doctor, please answer the medical questions based on the patient's description. \n\n" + "\n\n".join(history) + "\n\n" + invitation
    #fulltext = "\n\n".join(history) + "\n\n" + invitation

    #print('SENDING==========')
    #print(fulltext)
    #print('==========')

    generated_text = ""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gen_in = tokenizer(fulltext, return_tensors="pt").input_ids.to(device)
    in_tokens = len(gen_in)  # note: len() of a 2-D tensor is the batch size (1), not the token count; unused below
    with torch.no_grad():
        generated_ids = generator(
            gen_in,
            max_new_tokens=200,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,  # 1.0 means 'off'. unfortunately if we penalize it it will not output Sphynx:
            temperature=0.5,  # default: 1.0
            top_k=50,  # default: 50
            top_p=1.0,  # default: 1.0
            early_stopping=True,
        )
    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]  # batch_decode returns a list with one element

    text_without_prompt = generated_text[len(fulltext):]

    response = text_without_prompt

    response = response.split(human_invitation)[0]

    response = response.strip()  # fixed: str.strip() returns a new string; the bare call discarded its result

    print(invitation + response)

    print("")

    history.append(invitation + response)

while True:
    go()
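For reference, the fulltext prompt that go() assembles on the first turn has this shape (the patient line is a hypothetical example):

If you are a doctor, please answer the medical questions based on the patient's description.

ChatDoctor: I am ChatDoctor, what medical questions do you have?

Patient: I have a sore throat and a mild fever.

ChatDoctor: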
test.py
ADDED
@@ -0,0 +1,328 @@
import os
import gc
import torch
import gradio as gr
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList

# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")

# =============================
# System Prompt
# =============================
SYSTEM_PROMPT = """
You are ChatDoctor — a friendly, professional, and caring virtual doctor.
Whenever a patient describes their symptoms:
1. Always include a recommendation for diet, fluids, and proteins appropriate for recovery.
   - Fruits: citrus (orange, lemon), kiwi, papaya
   - Vegetables: leafy greens, carrots, spinach
   - Fluids: warm soups, herbal teas, coconut water
   - Proteins: boiled eggs, lentils, fish, chicken soup
   - Extras: garlic, ginger, turmeric
2. Recommend safe over-the-counter medicines if applicable (e.g., paracetamol for fever).
3. Ask follow-up questions if needed to understand the patient's condition better.
4. Always encourage the patient to see a real doctor if symptoms persist, worsen, or are serious.
5. Provide clear, warm, and empathetic advice.
6. Make your response structured and easy to understand.
7. Even if the patient only mentions a symptom, always include diet, fluids, protein, and care suggestions automatically.
"""

# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        for stop_id_seq in self.stop_ids:
            if len(stop_id_seq) == 1:
                if input_ids[0][-1] == stop_id_seq[0]:
                    return True
            else:
                if len(input_ids[0]) >= len(stop_id_seq):
                    if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
                        return True
        return False

# =============================
# Chat History (Global)
# =============================
conversation_history = []

# =============================
# Get Response Function
# =============================
def get_response(user_input, history_context):
    """Generate response from ChatDoctor model"""
    # Build conversation from history
    history_text = []
    for human, assistant in history_context:
        if human:
            history_text.append("Patient: " + human)
        if assistant:
            history_text.append("ChatDoctor: " + assistant)

    # Add current user input
    history_text.append("Patient: " + user_input)

    # Build full prompt including system instructions
    prompt = SYSTEM_PROMPT + "\n\nConversation so far:\n" + "\n".join(history_text) + "\nChatDoctor:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Define stop words and their token IDs
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode and clean response
    full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    response = full_output[len(prompt):].strip()

    # Remove any "Patient:" that might have slipped through
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    # Free memory
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()

    return response

# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history):
    """Gradio chat interface function"""
    if not message.strip():
        return ""

    try:
        response = get_response(message, history)
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# =============================
# Custom CSS
# =============================
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}

#header h1 {
    margin: 0;
    font-size: 2.5em;
}

#header p {
    margin: 10px 0 0 0;
    font-size: 1.1em;
    opacity: 0.9;
}

.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
    color: #856404;
}

.disclaimer h3 {
    margin-top: 0;
    color: #856404;
}

footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""

# =============================
# Gradio Interface
# =============================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML("""
    <div id="header">
        <h1>🩺 ChatDoctor AI Assistant</h1>
        <p>Your AI-powered medical conversation partner</p>
    </div>
    """)

    # Disclaimer
    gr.HTML("""
    <div class="disclaimer">
        <h3>⚠️ Medical Disclaimer</h3>
        <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
        It is NOT a substitute for professional medical advice, diagnosis, or treatment.
        Always seek the advice of your physician or other qualified health provider with any questions
        you may have regarding a medical condition. Never disregard professional medical advice or
        delay in seeking it because of something you have read here.</p>
    </div>
    """)

    # Chatbot Interface
    chatbot = gr.Chatbot(
        height=500,
        placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. How can I assist you today?</p></div>",
        show_label=False,
        avatar_images=(None, "🤖"),
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here... (e.g., 'I have a headache')",
            show_label=False,
            scale=9,
            container=False
        )
        submit_btn = gr.Button("Send 📤", scale=1, variant="primary")

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
        retry_btn = gr.Button("🔄 Retry", scale=1)

    # Examples
    gr.Examples(
        examples=[
            "I have a persistent headache for 3 days. What should I do?",
            "What are the symptoms of diabetes?",
            "How can I improve my sleep quality?",
            "I have a fever and sore throat. Should I be concerned?",
            "What are some natural ways to reduce stress?",
        ],
        inputs=msg,
        label="💡 Example Questions"
    )

    # Settings (collapsed by default)
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher values make responses more creative but less focused"
        )
        max_tokens_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=50,
            label="Max Response Length",
            info="Maximum number of tokens in response"
        )
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=TOP_K,
            step=1,
            label="Top K",
            info="Limits vocabulary selection"
        )

    # Footer
    gr.HTML(f"""
    <footer>
        <p>Powered by ChatDoctor Model | Built with Gradio</p>
        <p>Device: {device.upper()} | Model: LLaMA-based Medical AI</p>
    </footer>
    """)

    # Event handlers
    def user_message(user_msg, history):
        return "", history + [[user_msg, None]]

    def bot_response(history, temp, max_tok, top_k_val):
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)

        user_msg = history[-1][0]
        bot_msg = chat_function(user_msg, history[:-1])
        history[-1][1] = bot_msg
        return history

    # Connect events
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
    )

    submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
    )

    clear_btn.click(lambda: None, None, chatbot, queue=False)

    def retry_last():
        return None

    retry_btn.click(retry_last, None, chatbot, queue=False)

# =============================
# Launch Interface
# =============================
if __name__ == "__main__":
    print("\n🚀 Launching ChatDoctor Gradio Interface...")
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",   # Accessible from network
        server_port=7860,
        share=False,             # Set to True to create public link
        show_error=True
    )
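Functionally, test.py differs from frontend.py mainly in prepending SYSTEM_PROMPT, so on the first turn the model sees a prompt of roughly this shape (patient text hypothetical, system text abbreviated):

You are ChatDoctor — a friendly, professional, and caring virtual doctor.
...
7. Even if the patient only mentions a symptom, always include diet, fluids, protein, and care suggestions automatically.

Conversation so far:
Patient: I have a sore throat.
ChatDoctor: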
train.py
ADDED
@@ -0,0 +1,231 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

import utils

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
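Since train() reads everything through HfArgumentParser, a run is configured entirely from the command line. One plausible single-node invocation against the alpaca_data.json in this commit (paths and hyperparameters here are illustrative, not mandated by the repo):

python train.py \
    --model_name_or_path ./pretrained \
    --data_path ./alpaca_data.json \
    --output_dir ./trained \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --learning_rate 2e-5 \
    --model_max_length 512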
train_lora.py
ADDED
@@ -0,0 +1,321 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import fire
|
| 6 |
+
import torch
|
| 7 |
+
import transformers
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from typing import List, Optional, Union
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
Unused imports:
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import bitsandbytes as bnb
|
| 15 |
+
"""
|
| 16 |
+
from peft import ( # noqa: E402
|
| 17 |
+
LoraConfig,
|
| 18 |
+
BottleneckConfig,
|
| 19 |
+
get_peft_model,
|
| 20 |
+
get_peft_model_state_dict,
|
| 21 |
+
prepare_model_for_int8_training,
|
| 22 |
+
set_peft_model_state_dict,
|
| 23 |
+
)
|
| 24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, LLaMATokenizer # noqa: F402
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def train(
|
| 28 |
+
# model/data params
|
| 29 |
+
base_model: str = "", # the only required argument
|
| 30 |
+
data_path: str = "yahma/alpaca-cleaned",
|
| 31 |
+
output_dir: str = "./lora-alpaca",
|
| 32 |
+
adapter_name: str = "lora",
|
| 33 |
+
# training hyperparams
|
| 34 |
+
batch_size: int = 128,
|
| 35 |
+
micro_batch_size: int = 4,
|
| 36 |
+
num_epochs: int = 3,
|
| 37 |
+
learning_rate: float = 3e-4,
|
| 38 |
+
cutoff_len: int = 256,
|
| 39 |
+
val_set_size: int = 2000,
|
| 40 |
+
use_gradient_checkpointing: bool = False,
|
| 41 |
+
eval_step: int = 200,
|
| 42 |
+
save_step: int = 200,
|
| 43 |
+
# lora hyperparams
|
| 44 |
+
lora_r: int = 8,
|
| 45 |
+
lora_alpha: int = 16,
|
| 46 |
+
lora_dropout: float = 0.05,
|
| 47 |
+
lora_target_modules: List[str] = None,
|
| 48 |
+
# bottleneck adapter hyperparams
|
| 49 |
+
bottleneck_size: int = 256,
|
| 50 |
+
non_linearity: str = "tanh",
|
| 51 |
+
adapter_dropout: float = 0.0,
|
| 52 |
+
use_parallel_adapter: bool = False,
|
| 53 |
+
use_adapterp: bool = False,
|
| 54 |
+
target_modules: List[str] = None,
|
| 55 |
+
scaling: Union[float, str] = 1.0,
|
| 56 |
+
# llm hyperparams
|
| 57 |
+
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
| 58 |
+
group_by_length: bool = False, # faster, but produces an odd training loss curve
|
| 59 |
+
# wandb params
|
| 60 |
+
wandb_project: str = "",
|
| 61 |
+
wandb_run_name: str = "",
|
| 62 |
+
wandb_watch: str = "", # options: false | gradients | all
|
| 63 |
+
wandb_log_model: str = "", # options: false | true
|
| 64 |
+
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
| 65 |
+
):
|
| 66 |
+
print(
|
| 67 |
+
f"Finetuning model with params:\n"
|
| 68 |
+
f"base_model: {base_model}\n"
|
| 69 |
+
f"data_path: {data_path}\n"
|
| 70 |
+
f"output_dir: {output_dir}\n"
|
| 71 |
+
f"batch_size: {batch_size}\n"
|
| 72 |
+
f"micro_batch_size: {micro_batch_size}\n"
|
| 73 |
+
f"num_epochs: {num_epochs}\n"
|
| 74 |
+
f"learning_rate: {learning_rate}\n"
|
| 75 |
+
f"cutoff_len: {cutoff_len}\n"
|
| 76 |
+
f"val_set_size: {val_set_size}\n"
|
| 77 |
+
f"use_gradient_checkpointing: {use_gradient_checkpointing}\n"
|
| 78 |
+
f"lora_r: {lora_r}\n"
|
| 79 |
+
f"lora_alpha: {lora_alpha}\n"
|
| 80 |
+
f"lora_dropout: {lora_dropout}\n"
|
| 81 |
+
f"lora_target_modules: {lora_target_modules}\n"
|
| 82 |
+
f"bottleneck_size: {bottleneck_size}\n"
|
| 83 |
+
f"non_linearity: {non_linearity}\n"
|
| 84 |
+
f"adapter_dropout: {adapter_dropout}\n"
|
| 85 |
+
f"use_parallel_adapter: {use_parallel_adapter}\n"
|
| 86 |
+
f"use_adapterp: {use_adapterp}\n"
|
| 87 |
+
f"train_on_inputs: {train_on_inputs}\n"
|
| 88 |
+
f"scaling: {scaling}\n"
|
| 89 |
+
f"adapter_name: {adapter_name}\n"
|
| 90 |
+
f"target_modules: {target_modules}\n"
|
| 91 |
+
f"group_by_length: {group_by_length}\n"
|
| 92 |
+
f"wandb_project: {wandb_project}\n"
|
| 93 |
+
f"wandb_run_name: {wandb_run_name}\n"
|
| 94 |
+
f"wandb_watch: {wandb_watch}\n"
|
| 95 |
+
f"wandb_log_model: {wandb_log_model}\n"
|
| 96 |
+
f"resume_from_checkpoint: {resume_from_checkpoint}\n"
|
| 97 |
+
)
|
| 98 |
+
assert (
|
| 99 |
+
base_model
|
| 100 |
+
), "Please specify a --base_model, e.g. --base_model='decapoda-research/LLaMA-7b-hf'"
|
| 101 |
+
gradient_accumulation_steps = batch_size // micro_batch_size
|
| 102 |
+
|
| 103 |
+
device_map = "auto"
|
| 104 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
| 105 |
+
ddp = world_size != 1
|
| 106 |
+
if ddp:
|
| 107 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
| 108 |
+
gradient_accumulation_steps = gradient_accumulation_steps // world_size
|
| 109 |
+
|
| 110 |
+
# Check if parameter passed or if set within environ
|
| 111 |
+
use_wandb = len(wandb_project) > 0 or (
|
| 112 |
+
"WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
|
| 113 |
+
)
|
| 114 |
+
# Only overwrite environ if wandb param passed
|
| 115 |
+
if len(wandb_project) > 0:
|
| 116 |
+
os.environ["WANDB_PROJECT"] = wandb_project
|
| 117 |
+
if len(wandb_watch) > 0:
|
| 118 |
+
os.environ["WANDB_WATCH"] = wandb_watch
|
| 119 |
+
if len(wandb_log_model) > 0:
|
| 120 |
+
os.environ["WANDB_LOG_MODEL"] = wandb_log_model
|
| 121 |
+
|
| 122 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 123 |
+
base_model,
|
| 124 |
+
load_in_8bit=True,
|
| 125 |
+
torch_dtype=torch.float16,
|
| 126 |
+
device_map=device_map,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if model.config.model_type == "LLaMA":
|
| 130 |
+
# Due to the name of transformers' LLaMATokenizer, we have to do this
|
| 131 |
+
tokenizer = LLaMATokenizer.from_pretrained(base_model)
|
| 132 |
+
else:
|
| 133 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
| 134 |
+
|
| 135 |
+
tokenizer.pad_token_id = (
|
| 136 |
+
0 # unk. we want this to be different from the eos token
|
| 137 |
+
)
|
| 138 |
+
tokenizer.padding_side = "left" # Allow batched inference
|
| 139 |
+
|
| 140 |
+
def tokenize(prompt, add_eos_token=True):
|
| 141 |
+
# there's probably a way to do this with the tokenizer settings
|
| 142 |
+
# but again, gotta move fast
|
| 143 |
+
result = tokenizer(
|
| 144 |
+
prompt,
|
| 145 |
+
truncation=True,
|
| 146 |
+
max_length=cutoff_len,
|
| 147 |
+
padding=False,
|
| 148 |
+
return_tensors=None,
|
| 149 |
+
)
|
| 150 |
+
if (
|
| 151 |
+
result["input_ids"][-1] != tokenizer.eos_token_id
|
| 152 |
+
and len(result["input_ids"]) < cutoff_len
|
| 153 |
+
and add_eos_token
|
| 154 |
+
):
|
| 155 |
+
result["input_ids"].append(tokenizer.eos_token_id)
|
| 156 |
+
result["attention_mask"].append(1)
|
| 157 |
+
|
| 158 |
+
result["labels"] = result["input_ids"].copy()
|
| 159 |
+
|
| 160 |
+
return result
|
| 161 |
+
|
    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            # Mask the prompt tokens so the loss covers the response only.
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

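    # Illustrative masking (hypothetical lengths): if the rendered user prompt
    # tokenizes to 50 ids and the full prompt to 60, labels become
    # [-100] * 50 + labels[50:], so the cross-entropy loss (which ignores
    # index -100) is computed on the 10 response tokens only.
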
    model = prepare_model_for_int8_training(model, use_gradient_checkpointing=use_gradient_checkpointing)
    if adapter_name == "lora":
        config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
    elif adapter_name == "bottleneck":
        config = BottleneckConfig(
            bottleneck_size=bottleneck_size,
            non_linearity=non_linearity,
            adapter_dropout=adapter_dropout,
            use_parallel_adapter=use_parallel_adapter,
            use_adapterp=use_adapterp,
            target_modules=target_modules,
            scaling=scaling,
            bias="none",
            task_type="CAUSAL_LM",
        )
    model = get_peft_model(model, config)

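    # Rough size check (illustrative, assuming 4096-dim projections as in
    # LLaMA-7B): a LoRA adapter with r=8 on one q_proj adds two low-rank
    # factors of 4096 * 8 weights each, i.e. 65,536 trainable parameters per
    # wrapped module, next to the ~16.8M frozen weights of that projection.
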
    if data_path.endswith(".json"):  # todo: support jsonl
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            model = set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=eval_step if val_set_size > 0 else None,
            save_steps=save_step,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

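    # Note (illustrative): the effective global batch per optimizer step is
    # per_device_train_batch_size * gradient_accumulation_steps * world_size,
    # which equals the batch_size argument when the integer divisions above
    # are exact.
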
    # Patch state_dict so that checkpoints saved by the Trainer contain only
    # the PEFT adapter weights instead of the full base model.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""  # noqa: E501
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""  # noqa: E501


if __name__ == "__main__":
    fire.Fire(train)
utils.py
ADDED
@@ -0,0 +1,174 @@
import dataclasses
import logging
import math
import os
import io
import sys
import time
import json
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object  # Note: part of the legacy openai<1.0 SDK.
import copy

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

openai.api_key = ""  # Supply your OpenAI API key here or via your environment.
openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False


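# Usage sketch (illustrative): override only the fields you need, e.g.
# decoding_args = OpenAIDecodingArguments(max_tokens=256, temperature=0.7, stop=["\n\n"])
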
def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
    """Decode with OpenAI API.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
            it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        decoding_args: Decoding arguments.
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate-limit is hit.
        batch_size: Number of prompts to send in a single request. Only for non-chat models.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of the full completion object (which contains things like logprob).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions.
        Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
            - a string (if return_text is True)
            - an openai_object.OpenAIObject object (if return_text is False)
            - a list of objects of the above types (if decoding_args.n > 1)
    """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

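    # Batching arithmetic (illustrative): 10 prompts with batch_size=4 yield
    # ceil(10 / 4) = 3 batches of sizes 4, 4, and 2.
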
    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        batch_decoding_args = copy.deepcopy(decoding_args)  # cloning the decoding_args

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    # Context-length overflow: shrink the target length and retry.
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Return non-tuple if only 1 input and 1 generation.
        (completions,) = completions
    return completions


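# Usage sketch (illustrative; needs a valid openai.api_key and the legacy
# openai<1.0 SDK this module is written against):
# texts = openai_completion(
#     prompts=["Say hello.", "Say goodbye."],
#     decoding_args=OpenAIDecodingArguments(max_tokens=32),
#     return_text=True,
# )
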
def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
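
For illustration, a round trip through the JSON helpers above (the path "example.json" is a placeholder):

# jdump creates parent directories as needed and serializes dicts/lists;
# jload parses the file back into Python objects.
record = {"instruction": "Say hi", "input": "", "output": "Hi!"}
jdump(record, "example.json")
assert jload("example.json") == record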