symptom_gpt2 / app.py
Branis333's picture
Update app.py
8a209f1 verified
import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
# Load your fine-tuned model from Hugging Face
MODEL_NAME = "Branis333/symptom-gpt2-chatbot"
print("Loading model and tokenizer...")
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
print(f"✓ Model loaded successfully on {device}")
def respond(
message,
history,
system_message,
max_tokens,
temperature,
top_p,
):
"""
Generate response from the medical chatbot
"""
# Format the prompt with conversation history
conversation = system_message + "\n\n"
# Add chat history
for entry in history:
if isinstance(entry, dict):
# New format: messages with 'role' and 'content'
if entry.get('role') == 'user':
conversation += f"User: {entry['content']}\n"
elif entry.get('role') == 'assistant':
conversation += f"Bot: {entry['content']}\n"
else:
# Old format: tuples
user_msg, bot_msg = entry
conversation += f"User: {user_msg}\nBot: {bot_msg}\n"
# Add current message
conversation += f"User: {message}\nBot:"
# Tokenize and generate
inputs = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=1024).to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
# Decode the response
full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the bot's response
if "Bot:" in full_response:
response = full_response.split("Bot:")[-1].strip()
# Remove any trailing "User:" if present
if "User:" in response:
response = response.split("User:")[0].strip()
else:
response = full_response
return response
# Add CSS to make Send button more visible
custom_css = """
button[type="submit"] {
min-width: 100px !important;
padding: 10px 20px !important;
font-size: 14px !important;
font-weight: 600 !important;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: 2px solid #667eea !important;
color: white !important;
border-radius: 6px !important;
cursor: pointer !important;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
transition: all 0.2s ease !important;
}
button[type="submit"]:hover {
background: linear-gradient(135deg, #764ba2 0%, #667eea 100%) !important;
box-shadow: 0 6px 16px rgba(102, 126, 234, 0.6) !important;
transform: translateY(-2px) !important;
}
"""
# Create Gradio ChatInterface
chatbot = gr.ChatInterface(
respond,
type="messages", # Use new message format
chatbot=gr.Chatbot(
height=500,
),
textbox=gr.Textbox(
placeholder="Ask me about your symptoms or medical questions...",
container=True, # keep the input container visible so the send button renders
scale=7
),
title="🏥 Medical Symptom Chatbot",
description="Ask questions about symptoms, diseases, and medical conditions. This bot is trained on medical Q&A data. For informational purposes only - always consult healthcare professionals.",
theme="soft",
examples=[
["I have fever and cough. What could this be?"],
["What are the symptoms of diabetes?"],
["What is hypertension?"],
["I have a headache and nausea. What should I do?"],
["What are the precautions for common cold?"],
],
cache_examples=False,
additional_inputs=[
gr.Textbox(
value="You are a helpful medical chatbot that provides information about symptoms and diseases. Always recommend consulting a healthcare professional for serious conditions.",
label="System Message",
lines=3
),
gr.Slider(
minimum=50,
maximum=300,
value=150,
step=10,
label="Max Tokens",
info="Maximum length of the response"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.7,
step=0.1,
label="Temperature",
info="Higher = more creative, Lower = more focused"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="Top-p (Nucleus Sampling)",
info="Controls diversity of responses"
),
],
submit_btn="Send",
css=custom_css,
)
# Launch the app
if __name__ == "__main__":
chatbot.launch(
share=False, # Set to True to create a public link
server_name="0.0.0.0", # Makes it accessible externally
server_port=7860,
)