# DocMed-Demo / src/streamlit_app.py
# Jip7e's picture
# Update src/streamlit_app.py
# 4087561 verified
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# ---------------------------------------------------------------
# Page configuration: browser-tab title and icon, plus page layout
# ---------------------------------------------------------------
_PAGE_CONFIG = {
    "page_title": "DocMed Demo",
    "page_icon": "🩺",
    "layout": "centered",
}
st.set_page_config(**_PAGE_CONFIG)
# -------------------------------
# Header
# -------------------------------
st.title("🩺 DocMed")
st.subheader("Medical Study Assistant (Educational Use Only)")
# Markdown joins consecutive non-blank lines into a single paragraph, so the
# blank line keeps the bold "Disclaimer" label on its own line instead of
# running straight into the sentences that follow it.
st.markdown(
    """
⚠️ **Disclaimer**

DocMed is an **educational AI model only**.
It must **NOT** be used for diagnosis, treatment, or clinical decision-making.
"""
)
# -------------------------------
# Load model (cached)
# -------------------------------
# Hugging Face Hub repository holding the DocMed tokenizer and weights.
DEFAULT_MODEL_ID = "jip7e/DocMed"


@st.cache_resource
def load_model(model_id: str = DEFAULT_MODEL_ID):
    """Load the DocMed tokenizer and causal language model.

    Wrapped in ``st.cache_resource`` so the slow download/initialisation runs
    once per ``model_id`` and the result is reused across Streamlit reruns
    instead of being reloaded on every interaction.

    Args:
        model_id: Hugging Face Hub repository ID (or local directory) to load.
            Defaults to the DocMed repository, so existing ``load_model()``
            calls behave exactly as before.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to evaluation
        mode before being returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" delegates weight placement to `accelerate` (which must
    # be installed), putting the model on a GPU when one is available.
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    model.eval()  # evaluation mode: disables dropout for inference
    return tokenizer, model
# Obtain the shared tokenizer/model pair. Because load_model is cached with
# st.cache_resource, later reruns return almost immediately and the spinner
# is effectively only visible during the first load.
with st.spinner("Loading DocMed model..."):
    loaded = load_model()
tokenizer, model = loaded
# -------------------------------
# Question input
# -------------------------------
_QUESTION_LABEL = "Ask a medical question (student level):"
question = st.text_area(
    _QUESTION_LABEL,
    height=120,
    placeholder="e.g. What is hydronephrosis?",
)
# -------------------------------
# Inference
# -------------------------------
def _build_prompt(user_question: str) -> str:
    """Wrap the student's question in DocMed's instruction prompt.

    The prompt ends with an ``Answer:`` line so the model continues from
    there with the explanation.
    """
    return (
        "You are DocMed, a medical study assistant.\n"
        "Explain the following clearly and concisely for a medical student.\n"
        "Use short sentences and simple language.\n"
        "You may use emojis if helpful.\n"
        "Question:\n"
        f"{user_question}\n"
        "Answer:\n"
    )


if st.button("Ask DocMed"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        prompt = _build_prompt(question.strip())
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Many causal-LM tokenizers define no pad token; fall back to EOS
        # explicitly so generate() does not have to guess (and warn) each call.
        pad_token_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )
        with st.spinner("DocMed is thinking..."), torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=120,
                temperature=0.6,
                top_p=0.9,
                repetition_penalty=1.2,
                do_sample=True,
                pad_token_id=pad_token_id,
            )
        # generate() returns prompt + continuation. Decode only the new tokens:
        # splitting the full text on "Answer:" would cut the reply short
        # whenever the model itself writes "Answer:" again.
        prompt_length = inputs["input_ids"].shape[-1]
        answer = tokenizer.decode(
            output[0][prompt_length:], skip_special_tokens=True
        ).strip()
        st.markdown("### 🧠 DocMed says:")
        if answer:
            st.write(answer)
        else:
            st.info("DocMed did not generate an answer. Please try again.")