# app.py — ChatDoctor: a Gradio medical Q&A assistant backed by the
# zl111/ChatDoctor causal language model (Hugging Face Spaces app).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr
import os
# Cache downloaded model weights under /tmp (writable on HF Spaces).
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in
# favor of HF_HOME, so set both for forward/backward compatibility.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
os.environ.setdefault('HF_HOME', '/tmp/transformers_cache')
class CustomChatDoctor:
    """Wraps the zl111/ChatDoctor causal LM behind a simple chat API.

    Loads the model once (8-bit quantized on GPU, light settings on CPU)
    and keeps a rolling conversation history used to build each prompt.
    """

    def __init__(self):
        print("Initializing ChatDoctor...")
        # Prefer GPU when available; this also selects the loading strategy.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        model_name = "zl111/ChatDoctor"
        # Tokenizer loading is identical on both paths, so do it once.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.device == "cuda":
            # 8-bit quantization keeps GPU memory usage manageable.
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=True,
            )
        else:
            # CPU path: no quantization, but keep peak RAM low while loading.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
        print("Model loaded successfully!")
        # System prompt prepended to every generation request.
        self.system_prompt = """You are a helpful, respectful and honest medical assistant.
Always answer as helpfully as possible, while being safe.
Your answers should be based on verified medical information.
If a question doesn't make any sense, or is not factually coherent, explain why instead of answering something incorrect.
If you don't know the answer to a question, respond with "I don't have enough information to provide a reliable answer."
Always include a disclaimer that you are an AI assistant and not a licensed medical professional.
"""
        # Start with an empty conversation history.
        self.reset_conversation()

    def generate_response(self, user_input):
        """Generate one assistant reply to *user_input*, updating history.

        Returns the decoded model response. If generation raises, a canned
        apology string is returned and the just-added user turn is rolled
        back so the history never contains an unanswered message (the
        original code left it in, unbalancing every later prompt).
        """
        # Record the user turn before building the prompt.
        self.conversation_history.append(f"User: {user_input}")
        try:
            # Keep only the 10 most recent turns to bound prompt length.
            prompt = self.system_prompt + "\n\n"
            prompt += "\n".join(self.conversation_history[-10:])
            prompt += "\nAI Assistant: "
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            with torch.no_grad():  # inference only; no gradients needed
                generation_config = {
                    "max_new_tokens": 512,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "do_sample": True,
                    # LLaMA-style tokenizers have no pad token; reuse EOS.
                    "pad_token_id": self.tokenizer.eos_token_id,
                }
                outputs = self.model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    **generation_config,
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.tokenizer.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True,
            )
            self.conversation_history.append(f"AI Assistant: {response}")
            return response
        except Exception as e:
            # Roll back the unanswered user turn so history stays balanced.
            self.conversation_history.pop()
            error_message = f"An error occurred: {str(e)}"
            print(error_message)
            return ("I'm sorry, I encountered an error processing your "
                    "question. Please try again or ask a different question.")

    def reset_conversation(self):
        """Clear the conversation history; returns None."""
        self.conversation_history = []
        return None
# Module-level cache: the model is heavy, so construct it at most once
# and share the single instance across all Gradio sessions.
chat_doctor = None


def get_chat_doctor():
    """Lazily construct and return the shared CustomChatDoctor instance."""
    global chat_doctor
    if chat_doctor is not None:
        return chat_doctor
    chat_doctor = CustomChatDoctor()
    return chat_doctor
# Clickable starter questions shown below the input box to help
# first-time users get going.
examples = [
    "What are the symptoms of diabetes?",
    "How can I manage my hypertension?",
    "What should I do for a persistent headache?",
    "Can you explain what asthma is?",
    "What are the side effects of ibuprofen?",
]
# CSS injected into the Gradio page: sets the base font and styles the
# highlighted disclaimer box rendered at the bottom of the UI.
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.disclaimer {
margin-top: 20px;
padding: 10px;
background-color: #f8f9fa;
border-left: 3px solid #f0ad4e;
font-size: 14px;
}
"""
def welcome():
    """Return the greeting message shown when the app starts."""
    return (
        "Welcome to ChatDoctor! I'm an AI assistant trained to provide "
        "medical information. How can I help you today?"
    )
# Build the Gradio interface.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Your Custom ChatDoctor")
    gr.Markdown("Ask medical questions and get AI-powered responses.")
    chatbot = gr.Chatbot(height=600, type="messages")
    msg = gr.Textbox(placeholder="Type your medical question here...", lines=2)
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear Conversation")
    # Display example queries that users can click on.
    gr.Examples(
        examples=examples,
        inputs=msg
    )
    with gr.Accordion("About this AI", open=False):
        gr.Markdown("""
**ChatDoctor** is a medical conversation model designed to provide general health information.
This AI uses language models to generate responses based on patterns learned from medical texts and conversations.
**Important Notes:**
- This system is for informational purposes only
- Not a substitute for professional medical advice
- In emergencies, contact emergency services immediately
""")
    # Disclaimer at the bottom, styled via the .disclaimer CSS class.
    gr.HTML("""
<div class="disclaimer">
<strong>Disclaimer:</strong> This AI assistant provides information for educational purposes only.
Always consult with a qualified healthcare provider for medical advice, diagnosis, or treatment.
This tool is not intended to replace professional medical consultation.
</div>
""")

    def respond(message, chat_history):
        """Handle one user turn: generate a reply, append both turns."""
        # Ignore empty submissions BEFORE touching the model, so an empty
        # send never triggers the (expensive) lazy model load.
        if not message.strip():
            return "", chat_history
        doctor = get_chat_doctor()  # lazy-load model on first real request
        bot_message = doctor.generate_response(message)
        # type="messages" chatbots take openai-style role/content dicts.
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": bot_message})
        return "", chat_history

    def clear_history():
        """Reset the model-side history and empty the visible chat."""
        doctor = get_chat_doctor()
        doctor.reset_conversation()
        return []

    def show_welcome():
        """Initial chatbot value in the messages format.

        The previous js-based loader returned the legacy [[user, bot]]
        tuple format, which a type="messages" Chatbot cannot render.
        """
        return [{
            "role": "assistant",
            "content": ("Welcome to ChatDoctor! I'm an AI assistant trained "
                        "to provide medical information. How can I help you "
                        "today?"),
        }]

    # Show welcome message when the app starts.
    demo.load(show_welcome, None, chatbot)
    submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear_btn.click(clear_history, None, chatbot)

# Launch the app.
demo.launch()