File size: 2,327 Bytes
3b95c05
5f5c56a
 
3b95c05
4087561
 
 
 
 
 
 
 
3b95c05
4087561
 
 
5f5c56a
 
3b95c05
5f5c56a
 
 
4087561
5f5c56a
3b95c05
5f5c56a
 
4087561
 
 
5f5c56a
 
4087561
5f5c56a
 
 
 
 
 
 
 
 
4087561
5f5c56a
 
 
4087561
 
 
5f5c56a
 
4087561
 
5f5c56a
 
4087561
 
 
5f5c56a
 
 
 
5e0ecd2
 
4087561
 
5f5c56a
5e0ecd2
 
5f5c56a
5e0ecd2
 
 
4087561
5e0ecd2
4087561
 
 
 
 
 
 
 
 
 
 
5e0ecd2
4087561
 
5f5c56a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# -------------------------------
# Page config
# -------------------------------
# Browser tab title, icon, and page layout.
st.set_page_config(page_title="DocMed Demo", page_icon="🩺", layout="centered")

# -------------------------------
# Header
# -------------------------------
st.title("🩺 DocMed")
st.subheader("Medical Study Assistant (Educational Use Only)")

# The two spaces before each "\n" are markdown hard line breaks; they are
# spelled out here so an editor stripping trailing whitespace cannot drop them.
disclaimer_md = (
    "\n"
    "⚠️ **Disclaimer**  \n"
    "DocMed is an **educational AI model only**.  \n"
    "It must **NOT** be used for diagnosis, treatment, or clinical decision-making.\n"
)
st.markdown(disclaimer_md)

# -------------------------------
# Load model (cached)
# -------------------------------
@st.cache_resource
def load_model():
    """Load the tokenizer and causal language model for the DocMed repo.

    The function is wrapped in ``st.cache_resource`` so Streamlit caches the
    returned objects instead of rebuilding them on each call.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is loaded with
        ``device_map="auto"`` and has been switched to evaluation mode.
    """
    repo = "jip7e/DocMed"  # <-- your HF model repo

    tok = AutoTokenizer.from_pretrained(repo)
    lm = AutoModelForCausalLM.from_pretrained(repo, device_map="auto")
    lm.eval()  # evaluation mode: disables training-only layers such as dropout
    return tok, lm


# Keep a spinner on screen while load_model() runs, then unpack its result
# into the module-level names used by the inference section below.
with st.spinner("Loading DocMed model..."):
    _loaded = load_model()
tokenizer, model = _loaded

# -------------------------------
# User input
# -------------------------------
# Multi-line input box; `question` holds whatever text is currently entered.
question = st.text_area(
    label="Ask a medical question (student level):",
    height=120,
    placeholder="e.g. What is hydronephrosis?",
)

# -------------------------------
# Inference
# -------------------------------
# Run generation only on the script run triggered by clicking the button.
if st.button("Ask DocMed"):
    if not question.strip():
        # Empty or whitespace-only input: prompt the user instead of generating.
        st.warning("Please enter a question.")
    else:
        # Instruction-style prompt; the model continues the text after "Answer:".
        prompt = f"""You are DocMed, a medical study assistant.
Explain the following clearly and concisely for a medical student.
Use short sentences and simple language.
You may use emojis if helpful.

Question:
{question}

Answer:
"""

        # Tokenize and move the tensors to the device the model lives on.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=120,
                temperature=0.6,
                top_p=0.9,
                repetition_penalty=1.2,
                do_sample=True
            )

        # For a causal LM, generate() returns the prompt tokens followed by the
        # newly generated ones. Drop the prompt by token count instead of
        # splitting the decoded text on "Answer:": a text split keeps only what
        # follows the LAST occurrence of the marker, which truncates the reply
        # whenever the question or the generated text itself contains "Answer:".
        prompt_length = inputs["input_ids"].shape[-1]
        generated_tokens = output[0][prompt_length:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        st.markdown("### 🧠 DocMed says:")
        st.write(answer)