# JAYConverstionalAI / withoutfrontend.py
# Muhammadidrees's picture
# Rename app.py to withoutfrontend.py
# f5f43e6 verified
import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the token stored in the
# repository secrets (needed to download gated/private model weights).
# NOTE(review): if HUGGINGFACE_TOKEN is unset this passes token=None, which
# falls back to cached/interactive credentials — confirm that is intended.
login(token=os.getenv("HUGGINGFACE_TOKEN"))

# =============================
# Configuration
# =============================
MODEL_PATH = r"zl111/ChatDoctor"  # Hugging Face Hub model repo id
MAX_NEW_TOKENS = 200              # cap on generated tokens per reply
TEMPERATURE = 0.5                 # sampling temperature (lower = more focused)
TOP_K = 50                        # sample only from the top-k candidate tokens
REPETITION_PENALTY = 1.1          # >1.0 discourages repeated phrases

# Prefer GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")
# =============================
# Load Tokenizer and Model
# =============================
# Downloads (or loads from the local cache) the ChatDoctor LLaMA checkpoint.
# device_map="auto" lets accelerate place layers on available devices;
# float16 + low_cpu_mem_usage reduce the memory footprint during loading.
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
# Alias to the model's bound generate method; called from get_response().
generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")
# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Halt generation once the sequence ends with any configured stop sequence.

    ``stop_ids`` is a list of token-id lists; generation stops when the tail
    of the generated ids (batch element 0) matches one of them.
    """

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        # Compare against the tail of the sequence as a plain Python list.
        tail = input_ids[0].tolist()
        return any(
            len(tail) >= len(seq) and tail[-len(seq):] == seq
            for seq in self.stop_ids
        )
# =============================
# Chat History
# =============================
# Seed the transcript with the assistant's greeting. Each turn appends a
# "Patient: ..." or "ChatDoctor: ..." line; the joined list forms the prompt.
history = ["ChatDoctor: I am ChatDoctor, your AI medical assistant. How can I help you today?"]
# =============================
# Get Response Function
# =============================
def get_response(user_input):
    """Append the user's message to the conversation and return the model's reply.

    Builds a prompt from the full transcript, generates a continuation with
    the global model, strips any spill-over "Patient:" turn, and returns the
    cleaned reply text.

    Side effects: mutates the module-level ``history`` (adds both the patient
    turn and the generated doctor turn) and releases generation tensors.
    """
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Add user input to history
    history.append(human_invitation + user_input)

    # Build the conversation prompt from the whole transcript.
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Stop generating as soon as the model starts a new "Patient:" turn.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate the model response (no grad state needed for inference).
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # BUGFIX: decode only the newly generated tokens instead of slicing the
    # decoded string by len(prompt). Tokenizer decoding does not round-trip
    # the prompt exactly (special tokens, whitespace normalization), so a
    # character-offset slice could chop the reply or leak prompt text.
    new_tokens = output_ids[0][input_ids.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Remove any "Patient:" fragment that slipped past the stopping criteria.
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    # Remove any leading/trailing punctuation artifacts
    response = response.strip()
    history.append(doctor_invitation + response)

    # Free memory held by the generation tensors (cache flush is a no-op
    # when CUDA was never initialized).
    del input_ids, output_ids
    gc.collect()
    torch.cuda.empty_cache()
    return response
# =============================
# Chat Loop
# =============================
if __name__ == "__main__":
    # Interactive REPL: read a patient message, print the model's reply,
    # until the user quits or interrupts.
    print("\n=== ChatDoctor is ready! ===")
    print("You (the human) = Patient ")
    print("AI = ChatDoctor")
    print("Type 'exit' or 'quit' to end the chat.\n")
    print("ChatDoctor: Hi there! How can I help you today?\n")

    quit_commands = {"exit", "quit"}
    running = True
    while running:
        # The try spans the whole turn so Ctrl-C during generation is
        # handled the same way as Ctrl-C at the prompt.
        try:
            message = input("Patient: ").strip()
            if message.lower() in quit_commands:
                print("ChatDoctor: Take care! Goodbye ")
                running = False
            elif message:
                print("ChatDoctor:", get_response(message), "\n")
        except KeyboardInterrupt:
            print("\nChatDoctor: Take care! Goodbye")
            running = False
        except Exception as e:
            # Top-level boundary: report and keep the chat session alive.
            print(f"Error: {e}")
            print("Please try again.\n")