# chat.py
import os
import gc
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
# =============================
# Configuration
# =============================
MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1
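# Sampling knobs passed to model.generate():
#   TEMPERATURE rescales the next-token distribution (lower = more deterministic),
#   TOP_K restricts sampling to the k most likely tokens,
#   REPETITION_PENALTY discourages repeating tokens already generated.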
# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")
# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",           # automatically dispatch weights to the available GPU(s)
    torch_dtype=torch.float16,   # half precision for faster inference
    low_cpu_mem_usage=True       # reduce peak CPU memory while loading the checkpoint
)
# DO NOT call model.to(device) when using device_map="auto"
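# With device_map="auto", Accelerate has already placed each weight shard on the
# appropriate device, so an extra model.to(device) would conflict with that placement.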
generator = model.generate
print("✅ Model loaded successfully!\n")
# =============================
# Chat History
# =============================
systemprompt = ("""You are ChatDoctor — an intelligent, empathetic medical AI assistant.
Your role is to carefully gather medical information, reason clinically,
and provide safe, evidence-based guidance.
Follow these instructions strictly:
1. When a patient describes their illness, DO NOT diagnose immediately.
2. Ask relevant, targeted questions to collect all necessary details
such as symptoms, duration, severity, lifestyle habits, medical history,
medications, and any recent tests or changes.
3. Once you have enough information for a preliminary diagnosis, clearly
explain your reasoning and possible causes in simple medical language.
4. Then, provide a clear and structured response that includes:
- **Diagnosis:** probable or confirmed condition(s)
- **Dietary Advice:** foods to include and avoid
- **Lifestyle Advice:** exercise, sleep, stress, and other habits
5. Be concise, empathetic, and professional at all times.
6. Never switch roles or generate “Patient:” responses. Always remain as ChatDoctor.
7. If symptoms suggest a serious or emergency condition, advise the patient
to seek immediate medical attention.""")
history = [systemprompt, "ChatDoctor: I am ChatDoctor, what medical questions do you have?"]
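# Each turn is stored as a plain string ("Patient: ..." / "ChatDoctor: ...");
# get_response() joins the whole history with newlines to rebuild the prompt
# before every generation call.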
# =============================
# Response Function
# =============================
def get_response(user_input):
    global history
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Append user input to the conversation history
    history.append(human_invitation + user_input)

    # Build the prompt from the full history, ending with the doctor's turn
    prompt = "\n".join(history) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Generate response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY
        )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is fragile because detokenization may not reproduce the
    # prompt character-for-character.
    new_token_ids = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    # Clean up if the model repeats the patient prompt or starts a new patient turn
    if response.startswith("Patient:"):
        response = response[len("Patient:"):].strip()
    response = response.split("\nPatient:")[0].strip()

    # Append model response to history
    history.append(doctor_invitation + response)

    # Free memory
    del input_ids, output_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return response
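# Example (illustrative, non-interactive usage):
#   reply = get_response("I've had a mild headache and nausea for two days.")
#   print(reply)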
# =============================
# CLI Chat
# =============================
if __name__ == "__main__":
    print("\n=== ChatDoctor is ready! Type your questions. ===\n")
    while True:
        try:
            user_input = input("Patient: ").strip()
            if user_input.lower() in ["exit", "quit"]:
                print("Exiting ChatDoctor. Goodbye!")
                break
            response = get_response(user_input)
            print("ChatDoctor: " + response + "\n")
        except (KeyboardInterrupt, EOFError):
            print("\nExiting ChatDoctor. Goodbye!")
            break