# Hugging Face Spaces page header captured with the app source:
# Spaces: Sleeping
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# -------------------------------
# Page config
# -------------------------------
st.set_page_config(
    page_title="DocMed Demo",
    page_icon="🩺",
    layout="centered",
)

# -------------------------------
# Header
# -------------------------------
st.title("🩺 DocMed")
st.subheader("Medical Study Assistant (Educational Use Only)")
st.markdown(
    """
⚠️ **Disclaimer**
DocMed is an **educational AI model only**.
It must **NOT** be used for diagnosis, treatment, or clinical decision-making.
"""
)
# -------------------------------
# Load model (cached)
# -------------------------------
@st.cache_resource(show_spinner=False)  # cache across reruns; spinner is shown by the caller
def load_model():
    """Load the DocMed tokenizer and model once per server process.

    Without caching, Streamlit re-executes this on every widget
    interaction and reloads the full model each time; the original
    comment promised caching but no decorator was present.

    Returns:
        tuple: (tokenizer, model) — the HF tokenizer and the causal-LM
        model placed via `device_map="auto"` and switched to eval mode.
    """
    model_id = "jip7e/DocMed"  # <-- your HF model repo
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
    )
    model.eval()
    return tokenizer, model
# Load the model once up front so the text area appears only when ready.
with st.spinner("Loading DocMed model..."):
    tokenizer, model = load_model()

# -------------------------------
# User input
# -------------------------------
question = st.text_area(
    "Ask a medical question (student level):",
    placeholder="e.g. What is hydronephrosis?",
    height=120,
)
# -------------------------------
# Inference
# -------------------------------
if st.button("Ask DocMed"):
    if not question.strip():
        st.warning("Please enter a question.")
    else:
        # Instruction-style prompt; the model continues after "Answer:".
        prompt = f"""You are DocMed, a medical study assistant.
Explain the following clearly and concisely for a medical student.
Use short sentences and simple language.
You may use emojis if helpful.
Question:
{question}
Answer:
"""
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=120,
                temperature=0.6,
                top_p=0.9,
                repetition_penalty=1.2,
                do_sample=True,
            )
        # Decode only the newly generated tokens. Slicing by the prompt's
        # token length is robust, unlike splitting the decoded text on
        # "Answer:", which breaks if the model emits that marker again.
        prompt_len = inputs["input_ids"].shape[1]
        answer = tokenizer.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()
        st.markdown("### 🧠 DocMed says:")
        st.write(answer)