# Source: Hugging Face file viewer, app.py uploaded by Muhammadidrees
# Commit 8aceec0 (verified): "Create app.py" (6.24 kB)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import time
# =======================================================
# Module-level dict meant to track multi-step questions.
# It is a single shared object handed to every response call.
# =======================================================
session_answers = {}

# =======================================================
# Tokenizer and model loading (runs once at import time)
# =======================================================
model_name = "augtoma/qCammel-13"

print("Loading tokenizer and model...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to the EOS token when the tokenizer defines no pad token,
# since the prompt is later tokenized with padding=True.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Keyword options forwarded unchanged to from_pretrained().
_model_load_options = {
    "device_map": "auto",
    "torch_dtype": torch.float16,
    "trust_remote_code": True,
    "low_cpu_mem_usage": True,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_options)
model.eval()  # inference only

print("Model loaded successfully!")

# Placement / memory diagnostics.
print(f"Device map: {model.hf_device_map}")
_first_param_device = next(model.parameters()).device
print(f"Model device: {_first_param_device}")
_allocated_gib = torch.cuda.memory_allocated() / 1024**3
print(f"GPU Memory: {_allocated_gib:.2f} GB")
# =======================================================
# Generate a doctor reply, then replay it to the UI in small text chunks
# =======================================================
_EMPTY_MESSAGE_WARNING = "⚠️ Please describe your symptoms or ask a question."
_FALLBACK_REPLY = (
    "I understand your concern. Could you please provide more details about "
    "your symptoms so I can assist you better?"
)
# Footer removed from earlier turns before they are fed back to the model.
# NOTE(review): nothing in this function appends this footer any more; the
# stripping is kept so histories that still contain it stay clean.
_DISCLAIMER = (
    "⚕️ *Note: This is AI-generated information and not a substitute for "
    "professional medical advice. Please consult a healthcare provider for "
    "proper diagnosis and treatment.*"
)
_SYSTEM_PROMPT = (
    "You are an experienced doctor conducting a medical consultation. Your role is to:\n"
    "1. Ask one follow-up question at a time\n"
    "2. Provide advice or suggestions if possible\n"
    "3. Be conversational, caring, and thorough\n\n"
)
# Speaker labels only count when they start a token (see _find_label).
_SPEAKER_LABELS = ("patient:", "p:")
# Heuristic markers (lower-case) that the model started writing the patient's
# turn; matched as plain substrings.
# NOTE(review): phrases such as "thanks" can also cut a legitimate doctor reply.
_STOP_PHRASES = ("\npatient", "how are you", "i am feeling", "thanks")
_MIN_REPLY_CHARS = 10
_STREAM_CHUNK_CHARS = 4
_STREAM_DELAY_SECONDS = 0.015


def _find_label(lowered, label):
    """Return the first index where ``label`` starts a token in ``lowered``, or -1.

    A match preceded by a letter or digit is skipped, so "p:" no longer fires
    inside words such as "follow-up:" or "http:".
    """
    start = 0
    while True:
        pos = lowered.find(label, start)
        if pos == -1:
            return -1
        if pos == 0 or not lowered[pos - 1].isalnum():
            return pos
        start = pos + 1


def _build_prompt(history):
    """Assemble the text prompt from the instructions and recent history.

    Uses up to the ten messages preceding the latest one (five exchanges),
    then the latest message as the open patient turn.
    """
    parts = [_SYSTEM_PROMPT]
    # A negative start index is clamped, so short histories need no special case.
    for msg in history[-11:-1]:
        speaker = "Patient" if msg["role"] == "user" else "Doctor"
        text = msg["content"].replace(_DISCLAIMER, "").strip()
        parts.append(f"{speaker}: {text}\n")
    parts.append(f"Patient: {history[-1]['content']}\nDoctor:")
    return "".join(parts)


def _clean_response(raw):
    """Trim a decoded generation down to the doctor's own turn.

    Cuts at the earliest patient-turn marker, drops a leading "Doctor:" label,
    and substitutes a generic follow-up when fewer than ten characters remain.
    """
    lowered = raw.lower()
    hits = [_find_label(lowered, label) for label in _SPEAKER_LABELS]
    hits.extend(lowered.find(phrase) for phrase in _STOP_PHRASES)
    cut = min((pos for pos in hits if pos != -1), default=len(raw))

    reply = raw[:cut].strip()
    if reply.lower().startswith("doctor:"):
        reply = reply[len("doctor:"):].strip()
    if len(reply) < _MIN_REPLY_CHARS:
        reply = _FALLBACK_REPLY
    return reply


def generate_doctor_response(history, session_answers):
    """Yield successive chat histories while the doctor's reply is revealed.

    Args:
        history: List of ``{"role": ..., "content": ...}`` message dicts whose
            last entry is the newest user message. The list is mutated in
            place: the assistant reply is appended to it.
        session_answers: Accepted for interface compatibility; not read here.

    Yields:
        The history list (or a shallow copy of it) after each update. When the
        history is empty or the latest message is blank, a single warning
        message is appended and yielded instead of calling the model.

    Note:
        The model generates the whole reply first; the "streaming" is a replay
        of the finished text in 4-character slices with a cursor glyph.
    """
    user_message = history[-1]["content"] if history else ""
    if not user_message.strip():
        history.append({"role": "assistant", "content": _EMPTY_MESSAGE_WARNING})
        yield history
        return

    prompt = _build_prompt(history)

    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
    gen_config = GenerationConfig(
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        max_new_tokens=120,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2,
    )
    input_length = inputs["input_ids"].shape[1]

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    with torch.no_grad():
        output_ids = model.generate(**inputs, generation_config=gen_config)
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Keep only the newly generated tokens; the output starts with the prompt.
    generated_ids = output_ids[0][input_length:]
    decoded = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    response = _clean_response(decoded)

    # Placeholder assistant message that the loop below fills in.
    history.append({"role": "assistant", "content": ""})
    for end in range(_STREAM_CHUNK_CHARS, len(response) + _STREAM_CHUNK_CHARS, _STREAM_CHUNK_CHARS):
        history[-1]["content"] = response[:end] + "▌"
        yield history.copy()
        time.sleep(_STREAM_DELAY_SECONDS)

    # Final frame: the complete reply without the cursor glyph.
    history[-1]["content"] = response
    yield history
# =======================================================
# Gradio Interface
# =======================================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🩺 AI Doctor Chat Assistant")

    chatbot = gr.Chatbot(
        label="💬 Doctor Consultation",
        type='messages',
        avatar_images=(
            "https://cdn-icons-png.flaticon.com/512/706/706830.png",  # Patient
            "https://cdn-icons-png.flaticon.com/512/3774/3774299.png"  # Doctor
        ),
        height=500
    )

    with gr.Row():
        user_input = gr.Textbox(
            placeholder="Type your symptoms or question here...",
            label="🧍 Your Message",
            lines=2,
            scale=4
        )

    with gr.Row():
        send_btn = gr.Button("💬 Send", variant="primary", scale=1)
        clear_btn = gr.Button("🧹 Clear Chat", scale=1)

    gr.Examples(
        examples=[
            "I have a fever of 102°F since yesterday",
            "I've been having headaches for the past week",
            "I feel very tired all the time",
            "I have a sore throat and body aches",
        ],
        inputs=user_input,
        label="💡 Example Questions"
    )

    def respond(message, history):
        """Handle one user submission and stream the assistant's reply.

        Args:
            message: Text taken from the input box.
            history: Current chatbot messages (list of role/content dicts),
                or None when the chat has not been used yet.

        Yields:
            ``("", history)`` pairs: the empty string clears the input box and
            the history refreshes the chatbot component.
        """
        if history is None:
            history = []

        if not (message or "").strip():
            # This function is a generator, so a `return "", history` would be
            # discarded and no update emitted; yield the unchanged state instead.
            yield "", history
            return

        history.append({"role": "user", "content": message})
        for updated_history in generate_doctor_response(history, session_answers):
            yield "", updated_history

    # Event handlers: the button and the Enter key trigger the same callback.
    send_btn.click(respond, [user_input, chatbot], [user_input, chatbot])
    user_input.submit(respond, [user_input, chatbot], [user_input, chatbot])
    clear_btn.click(lambda: [], None, chatbot, queue=False)
# Entry point: only start the server when this file is run as a script.
if __name__ == "__main__":
    # Enable the request queue, then launch with share=True.
    demo.queue().launch(share=True)