# Source: Hugging Face Hub file page - uploaded by Muhammadidrees
# (commit 373f237, "Upload 15 files", verified; file size 10.8 kB).
import gc
import os
import re
import traceback

import gradio as gr
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
# =============================
# Configuration
# =============================
# Local checkpoint directory. The original Windows path is kept as the
# default; set CHATDOCTOR_MODEL_PATH to use a checkpoint stored elsewhere
# without editing the source.
MODEL_PATH = os.environ.get(
    "CHATDOCTOR_MODEL_PATH",
    r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained",
)
# Default generation settings read by get_response(). The UI's sliders
# rebind MAX_NEW_TOKENS, TEMPERATURE and TOP_K at request time;
# REPETITION_PENALTY is fixed.
MAX_NEW_TOKENS = 200
TEMPERATURE = 0.5
TOP_K = 50
REPETITION_PENALTY = 1.1
# Detect device: prefer the GPU when PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")
# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
# float16 halves GPU memory, but many PyTorch CPU builds lack half-precision
# kernels for ops used during generation (errors such as "... not
# implemented for 'Half'"), so load in float32 when running on CPU.
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    low_cpu_mem_usage=True,
)
# Bound generate() method; get_response() calls this to produce replies.
generator = model.generate
print("✅ ChatDoctor model loaded successfully!\n")
# =============================
# System Prompt
# =============================
# Instructions placed at the top of every prompt that get_response() builds.
# This text is sent to the model verbatim, so it must be clean UTF-8 (the
# em dash on the first line was previously stored as mis-encoded bytes).
SYSTEM_PROMPT = """
You are ChatDoctor — a friendly, professional, and caring virtual doctor.
Whenever a patient describes their symptoms:
1. Always include a recommendation for diet, fluids, and proteins appropriate for recovery.
- Fruits: citrus (orange, lemon), kiwi, papaya
- Vegetables: leafy greens, carrots, spinach
- Fluids: warm soups, herbal teas, coconut water
- Proteins: boiled eggs, lentils, fish, chicken soup
- Extras: garlic, ginger, turmeric
2. Recommend safe over-the-counter medicines if applicable (e.g., paracetamol for fever).
3. Ask follow-up questions if needed to understand the patient's condition better.
4. Always encourage the patient to see a real doctor if symptoms persist, worsen, or are serious.
5. Provide clear, warm, and empathetic advice.
6. Make your response structured and easy to understand.
7. Even if the patient only mentions a symptom, always include diet, fluids, protein, and care suggestions automatically.
"""
# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the first sequence ends with a stop pattern.

    ``stop_ids`` is a list of token-id lists, one per stop word. Only row 0
    of the batch is inspected.
    """

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        generated = input_ids[0]
        for pattern in self.stop_ids:
            size = len(pattern)
            if size == 1:
                # Single-token pattern: compare only the newest token.
                matched = generated[-1] == pattern[0]
            else:
                # Multi-token pattern: compare the trailing window, if long enough.
                matched = len(generated) >= size and generated[-size:].tolist() == pattern
            if matched:
                return True
        return False
# =============================
# Chat History (Global)
# =============================
# NOTE: not referenced anywhere else in this script. The conversation is held
# by the Gradio Chatbot component and passed to get_response() as
# `history_context`. Kept only so any external importer of the name still works.
conversation_history = []
# =============================
# Get Response Function
# =============================
# Strings marking the start of a new patient turn in generated text.
_STOP_WORDS = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
# Where to cut a reply: any "Patient:" / "Patient :" role marker, or a bare
# "Patient" left at the very end on its own line (generation can halt right
# after the "\n\nPatient" stop word). Ordinary prose such as
# "Patients with fever ..." is left intact.
_ROLE_MARKER_RE = re.compile(r"Patient\s*:|(?:^|\n)\s*Patient\s*$")


def _build_prompt(user_input, history_context):
    """Render system prompt, prior turns and the new message into one prompt string."""
    lines = []
    for human, assistant in history_context:
        if human:
            lines.append("Patient: " + human)
        if assistant:
            lines.append("ChatDoctor: " + assistant)
    lines.append("Patient: " + user_input)
    return SYSTEM_PROMPT + "\n\nConversation so far:\n" + "\n".join(lines) + "\nChatDoctor:"


def _stop_token_sequences(words):
    """Return the token-id sequences the stopping criterion should watch for.

    Each word contributes its standalone encoding plus, when it can be
    derived, its encoding as it appears after preceding text. SentencePiece
    tokenizers such as LLaMA's may prepend a word-boundary piece when a
    string is encoded on its own, so the standalone ids of a newline-prefixed
    stop word need not match what the model emits mid-reply. Encoding
    "x" + word and dropping the ids of "x" recovers the in-context form.
    """
    anchor = tokenizer.encode("x", add_special_tokens=False)
    sequences = []
    for word in words:
        candidates = [tokenizer.encode(word, add_special_tokens=False)]
        combined = tokenizer.encode("x" + word, add_special_tokens=False)
        # Only trust the in-context form if "x" tokenized to the same ids on
        # its own (i.e. it did not merge with the word's first character).
        if combined[:len(anchor)] == anchor:
            candidates.append(combined[len(anchor):])
        for ids in candidates:
            if ids and ids not in sequences:
                sequences.append(ids)
    return sequences


def _clean_response(text):
    """Cut *text* at the first patient role marker and trim surrounding whitespace."""
    return _ROLE_MARKER_RE.split(text.strip(), maxsplit=1)[0].strip()


def get_response(user_input, history_context):
    """Generate ChatDoctor's reply to *user_input*.

    Args:
        user_input: The patient's new message.
        history_context: Earlier turns as ``(patient_text, doctor_text)``
            pairs; falsy entries in either slot are skipped.

    Returns:
        The model's reply, truncated before any patient turn it started to write.
    """
    prompt = _build_prompt(user_input, history_context)
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    # Pass the mask explicitly: pad_token_id is set to eos below, so
    # generate() cannot infer the mask reliably on its own.
    attention_mask = encoded.attention_mask.to(device)
    stopping_criteria = StoppingCriteriaList(
        [StopOnTokens(_stop_token_sequences(_STOP_WORDS))]
    )
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens. Slicing the decoded full text at
    # len(prompt) assumed decode(encode(prompt)) == prompt, which tokenizer
    # whitespace/cleanup normalization does not guarantee.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = _clean_response(tokenizer.decode(new_tokens, skip_special_tokens=True))
    # Free memory between requests.
    del encoded, input_ids, attention_mask, output_ids, new_tokens
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return response
# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history):
    """Return ChatDoctor's reply to *message*, or a user-facing error string.

    Blank or missing messages short-circuit to ``""`` without running the
    model. Any exception raised during generation is shown to the user as
    ``"Error: ..."``, and its traceback is printed to stderr so the failure
    can be diagnosed from the server console.
    """
    if not message or not message.strip():
        return ""
    try:
        return get_response(message, history)
    except Exception as e:  # UI boundary: report instead of crashing the handler
        traceback.print_exc()
        return f"Error: {str(e)}"
# =============================
# Custom CSS
# =============================
# Stylesheet passed to gr.Blocks(css=...). It styles the header banner
# (#header), the medical-disclaimer box (.disclaimer) and the <footer> markup
# injected through gr.HTML in the interface layout.
# NOTE(review): the bare `footer` selector matches every <footer> element on
# the page, which may include Gradio's own footer - confirm that is intended.
custom_css = """
#header {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
#header h1 {
margin: 0;
font-size: 2.5em;
}
#header p {
margin: 10px 0 0 0;
font-size: 1.1em;
opacity: 0.9;
}
.disclaimer {
background-color: #fff3cd;
border: 1px solid #ffc107;
border-radius: 8px;
padding: 15px;
margin: 20px 0;
color: #856404;
}
.disclaimer h3 {
margin-top: 0;
color: #856404;
}
footer {
text-align: center;
margin-top: 30px;
color: #666;
font-size: 0.9em;
}
"""
# =============================
# Gradio Interface
# =============================
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML("""
<div id="header">
<h1>🩺 ChatDoctor AI Assistant</h1>
<p>Your AI-powered medical conversation partner</p>
</div>
""")
    # Disclaimer
    gr.HTML("""
<div class="disclaimer">
<h3>⚠️ Medical Disclaimer</h3>
<p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
It is NOT a substitute for professional medical advice, diagnosis, or treatment.
Always seek the advice of your physician or other qualified health provider with any questions
you may have regarding a medical condition. Never disregard professional medical advice or
delay in seeking it because of something you have read here.</p>
</div>
""")
    # Chatbot Interface.
    # avatar_images expects image file paths or URLs; the previous value was a
    # garbled emoji string (neither), so no bot avatar is configured.
    chatbot = gr.Chatbot(
        height=500,
        placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. How can I assist you today?</p></div>",
        show_label=False,
    )
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type your message here... (e.g., 'I have a headache')",
            show_label=False,
            scale=9,
            container=False,
        )
        submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
        retry_btn = gr.Button("🔄 Retry", scale=1)
    # Examples
    gr.Examples(
        examples=[
            "I have a persistent headache for 3 days. What should I do?",
            "What are the symptoms of diabetes?",
            "How can I improve my sleep quality?",
            "I have a fever and sore throat. Should I be concerned?",
            "What are some natural ways to reduce stress?",
        ],
        inputs=msg,
        label="💡 Example Questions",
    )
    # Settings (collapsed by default)
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher values make responses more creative but less focused",
        )
        max_tokens_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=50,
            label="Max Response Length",
            info="Maximum number of tokens in response",
        )
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=TOP_K,
            step=1,
            label="Top K",
            info="Limits vocabulary selection",
        )
    # Footer
    gr.HTML(f"""
<footer>
<p>Powered by ChatDoctor Model | Built with Gradio</p>
<p>Device: {device.upper()} | Model: LLaMA-based Medical AI</p>
</footer>
""")

    # Event handlers
    def user_message(user_msg, history):
        """Clear the textbox and append the message as a pending turn.

        Blank messages are ignored, so they neither add an empty chat bubble
        nor trigger a generation.
        """
        if not user_msg or not user_msg.strip():
            return "", history
        return "", (history or []) + [[user_msg, None]]

    def bot_response(history, temp, max_tok, top_k_val):
        """Fill in the reply for the pending (last) turn using the slider values."""
        # get_response() reads these module globals, so rebinding them applies
        # the sliders to this generation. Being globals, they are shared by
        # every session served by this process.
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)
        # Nothing to answer: the chat is empty, or the last turn already has a
        # reply (e.g. user_message() ignored a blank submission).
        if not history or history[-1][1] is not None:
            return history
        user_msg = history[-1][0]
        history[-1][1] = chat_function(user_msg, history[:-1])
        return history

    def retry_last(history):
        """Drop the last bot reply so bot_response() regenerates it.

        Returning None here (as before) cleared the entire conversation.
        """
        if history:
            history[-1][1] = None
        return history

    # Connect events
    generation_inputs = [chatbot, temperature_slider, max_tokens_slider, top_k_slider]
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, generation_inputs, chatbot
    )
    submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, generation_inputs, chatbot
    )
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    retry_btn.click(retry_last, chatbot, chatbot, queue=False).then(
        bot_response, generation_inputs, chatbot
    )
# =============================
# Launch Interface
# =============================
if __name__ == "__main__":
    print("\n🚀 Launching ChatDoctor Gradio Interface...")
    demo.queue()
    demo.launch(
        # 0.0.0.0 binds every network interface: other machines on the network
        # can reach this app, not just localhost.
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Set to True to create public link
        show_error=True,
    )