|
|
import os, json, itertools, bisect, gc
|
|
|
from transformers import LlamaTokenizer, LlamaForCausalLM
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
|
|
|
import transformers
|
|
|
import torch
|
|
|
from accelerate import Accelerator
|
|
|
import accelerate
|
|
|
import time
|
|
|
|
|
|
# Module-level state: load_model() fills these in and go() reads them.
model = tokenizer = generator = None

# Expose only the first GPU to PyTorch. This only has an effect if CUDA has
# not been initialised yet in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
|
|
|
|
def load_model(model_name, eight_bit=0, device_map="auto"):
    """Load a LLaMA tokenizer and causal-LM into the module-level globals.

    Sets the globals ``tokenizer``, ``model`` and ``generator`` (the bound
    ``model.generate``), which ``go()`` relies on.

    Args:
        model_name: model id or local directory passed to ``from_pretrained``.
        eight_bit: when truthy, load the weights in 8-bit (``load_in_8bit``)
            and let accelerate place them according to ``device_map``.
            When falsy (the default), the full model is loaded in fp16 on CUDA
            or fp32 on CPU and moved to that single device.
        device_map: accelerate device map, only used for 8-bit loading.
            ``"zero"`` is accepted as an alias for ``"balanced_low_0"``.
    """
    global model, tokenizer, generator

    print("Loading "+model_name+"...")

    if device_map == "zero":
        device_map = "balanced_low_0"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    print('gpu_count', gpu_count)

    tokenizer = LlamaTokenizer.from_pretrained(model_name)

    load_kwargs = dict(
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
        load_in_8bit=bool(eight_bit),
        cache_dir="cache",
    )

    if eight_bit:
        # Quantized 8-bit weights are placed by accelerate at load time and
        # cannot be relocated afterwards with .to(), so pass the device map in.
        model = LlamaForCausalLM.from_pretrained(
            model_name,
            device_map=device_map,
            **load_kwargs,
        )
    else:
        # Full-precision path: identical to the previous behavior.
        model = LlamaForCausalLM.from_pretrained(
            model_name,
            **load_kwargs,
        ).to(device)

    generator = model.generate
|
|
|
|
|
|
# Location (local directory or model id) of the ChatDoctor weights. Can be
# overridden with the CHATDOCTOR_MODEL_PATH environment variable; the default
# is the original hard-coded local path.
load_model(
    os.environ.get(
        "CHATDOCTOR_MODEL_PATH",
        r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained",
    )
)

# Greeting printed once at start-up; it also seeds the conversation history
# that go() joins into every prompt.
First_chat = "ChatDoctor: I am ChatDoctor, what medical questions do you have?"
print(First_chat)
history = [First_chat]
|
|
|
|
|
|
def go():
    """Run one Patient -> ChatDoctor exchange of the console chat.

    Reads one line from stdin and appends it to the module-level ``history``.
    It then builds a prompt from the whole history, samples a reply with the
    module-level ``generator`` (``model.generate``), prints the reply and
    appends it to ``history``.

    Requires ``load_model`` to have populated the ``model``, ``tokenizer`` and
    ``generator`` globals. Propagates ``EOFError`` from ``input()`` when stdin
    is closed.
    """
    invitation = "ChatDoctor: "
    human_invitation = "Patient: "

    msg = input(human_invitation)
    print("")

    history.append(human_invitation + msg)

    # NOTE(review): history grows every turn and is never truncated, so very
    # long sessions will eventually exceed the model's context window.
    fulltext = "If you are a doctor, please answer the medical questions based on the patient's description. \n\n" + "\n\n".join(history) + "\n\n" + invitation

    # Put the prompt on whatever device holds the model weights.
    gen_in = tokenizer(fulltext, return_tensors="pt").input_ids.to(model.device)
    # input_ids has shape (1, seq_len): the prompt length is the last
    # dimension (len() would return the batch size).
    in_tokens = gen_in.shape[-1]

    with torch.no_grad():
        generated_ids = generator(
            gen_in,
            max_new_tokens=200,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            do_sample=True,
            repetition_penalty=1.1,
            temperature=0.5,
            top_k=50,
            top_p=1.0,
            early_stopping=True,
        )

    # generate() returns prompt + continuation. Decode only the continuation.
    # Cutting the decoded full text at len(fulltext) would assume the tokenizer
    # round-trips the prompt character-for-character, which is not guaranteed.
    generated_text = tokenizer.decode(
        generated_ids[0][in_tokens:], skip_special_tokens=True
    )

    # The model may keep going and write the next patient turn itself; keep
    # only the text before it, without surrounding whitespace/newlines.
    response = generated_text.split(human_invitation)[0].strip()

    print(invitation + response)
    print("")

    history.append(invitation + response)
|
|
|
|
|
|
# Chat until the user closes stdin (Ctrl-D / Ctrl-Z+Enter) or presses Ctrl-C.
while True:
    try:
        go()
    except (EOFError, KeyboardInterrupt):
        # input() raises EOFError on closed stdin; Ctrl-C raises
        # KeyboardInterrupt. End the session cleanly instead of with a traceback.
        print("")
        break
|
|
|
|