File size: 13,967 Bytes
16cc298
4da6393
7160ad6
4da6393
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f248256
7160ad6
f248256
d3e32a1
6fd5317
7160ad6
f248256
bcf3e78
 
 
 
7160ad6
 
 
bcf3e78
 
7160ad6
 
 
 
bcf3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4da6393
7160ad6
bcf3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e781bf
4da6393
 
 
 
 
16cc298
7160ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcf3e78
 
 
 
 
 
 
fce8d80
bcf3e78
f248256
bcf3e78
 
 
f248256
bcf3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2325571
bcf3e78
 
 
 
 
 
2325571
6fd5317
d3e32a1
bcf3e78
 
 
548e5c0
bcf3e78
 
12ec640
bcf3e78
 
548e5c0
 
 
 
 
bcf3e78
 
f248256
bcf3e78
 
 
35e79b2
bcf3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109cf19
bcf3e78
 
 
f248256
bcf3e78
 
109cf19
bcf3e78
 
 
 
 
 
 
f248256
bcf3e78
 
 
 
 
7160ad6
bcf3e78
7160ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f248256
bcf3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b30d539
bcf3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7160ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcf3e78
 
 
 
 
 
 
 
 
 
 
 
 
 
6e781bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
import os
import sys
import html

# ============================================================================
# CRITICAL: Set API key BEFORE importing Together
# ============================================================================

# Read the TogetherAI key from the custom 'pilotikval' variable and abort
# immediately when it is unset or empty.
API_KEY = os.getenv("pilotikval")
if not API_KEY:
    print("❌ Missing 'pilotikval' environment variable. Please set your TogetherAI API key.")
    raise SystemExit(1)

# Expose the key under TOGETHER_API_KEY as well, for the Together client.
os.environ["TOGETHER_API_KEY"] = API_KEY

# NOW import Together and other dependencies
import streamlit as st
import streamlit.components.v1
from together import Together
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import snapshot_download

# ============================================================================
# CONFIGURATION
# ============================================================================

# HuggingFace dataset repository that holds every vector store
DATASET_REPO = "Sbnos/vstoryies"

# One row per specialty: (label shown in the UI, collection name, persist directory).
# Row order is the order in which the specialties appear in VECTOR_STORES.
_STORE_TABLE = (
    ("General Medicine", "oxfordmed", "oxfordmedbookdir"),
    ("Paediatrics", "paedia", "nelsonpaedia"),
    ("Respiratory", "respmurraynotes", "respmurray"),
    ("Dermatology", "derma", "rookderma"),
    ("Endocrine", "endocrine", "williamsendocrine"),
    ("Gastroenterology", "gastro", "yamadagastro"),
    ("Surgery", "gensurgery", "baileysurgery"),
    ("Neurology", "neuro", "bradleyneuro"),
    ("Cardiology", "cardiobraun", "braunwaldcardiofin"),
    ("Nephrology", "nephro", "brennernephro"),
    ("Orthopedics", "oportho", "campbellorthop"),
    ("Rheumatology", "rheumatology", "firesteinrheumatology"),
)

# Vector store configurations, keyed by specialty label
VECTOR_STORES = {
    label: {"collection_name": collection, "persist_directory": directory}
    for label, collection, directory in _STORE_TABLE
}

# Model configurations
EMBED_MODEL = "BAAI/bge-base-en"                 # embedding model name
LLM_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct"   # chat model name
RETRIEVAL_K = 26                                 # value used for the retriever's "k" search argument

# ============================================================================
# PAGE CONFIG
# ============================================================================

st.set_page_config(layout="wide", page_icon="🩺", page_title="DocChatter Medical RAG")

# ============================================================================
# INITIALIZATION
# ============================================================================

# Create the TogetherAI client; if construction fails, show the error on the
# page and stop the script.
try:
    client = Together()
except Exception as init_error:
    st.error(f"❌ Failed to initialize Together client: {init_error}")
    st.stop()

# Download all vector stores from HuggingFace dataset on first run
@st.cache_resource
def download_all_vectorstores():
    """Download the vector stores from the HuggingFace dataset repository.

    The snapshot is fetched whenever at least one configured persist
    directory is missing, so a partial or interrupted earlier download gets
    completed instead of being mistaken for a finished setup. When every
    directory already exists, nothing is downloaded.

    On failure the error is shown on the page and the script is stopped.
    """
    missing_dirs = [
        config["persist_directory"]
        for config in VECTOR_STORES.values()
        if not os.path.exists(config["persist_directory"])
    ]
    if not missing_dirs:
        return

    with st.spinner("📥 Downloading vector stores from HuggingFace (one-time setup)..."):
        try:
            # The whole dataset snapshot is pulled into the working directory,
            # which is where the persist directories are looked up.
            snapshot_download(
                repo_id=DATASET_REPO,
                repo_type="dataset",
                local_dir=".",
                allow_patterns=["*"]
            )
        except Exception as e:
            st.error(f"❌ Failed to download vector stores: {e}")
            st.stop()
        else:
            st.success("✅ Vector stores downloaded successfully!")

# Download vector stores if needed
download_all_vectorstores()

# Initialize embeddings
@st.cache_resource
def get_embeddings():
    """Create the HuggingFace embedding model for EMBED_MODEL.

    Embeddings are requested with ``normalize_embeddings`` enabled. The
    function is wrapped in ``st.cache_resource``.
    """
    encode_options = {"normalize_embeddings": True}
    embedder = HuggingFaceEmbeddings(
        encode_kwargs=encode_options,
        model_name=EMBED_MODEL,
    )
    return embedder

embeddings = get_embeddings()

# ============================================================================
# SESSION STATE
# ============================================================================

# Seed the per-session defaults only when they are not present yet, so
# existing values are kept.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

if "selected_collection" not in st.session_state:
    # Default to the first configured specialty.
    st.session_state["selected_collection"] = next(iter(VECTOR_STORES))

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

@st.cache_resource
def load_vectorstore(_embeddings, collection_name, persist_directory):
    """Open a Chroma collection and return a retriever over it.

    Args:
        _embeddings: Embedding function handed to Chroma. The leading
            underscore keeps Streamlit from hashing this argument for the
            cache key.
        collection_name: Name of the Chroma collection to open.
        persist_directory: Directory the collection is persisted in.

    Returns:
        A retriever configured to fetch RETRIEVAL_K documents per query.
    """
    search_options = {"k": RETRIEVAL_K}
    store = Chroma(
        embedding_function=_embeddings,
        persist_directory=persist_directory,
        collection_name=collection_name,
    )
    retriever = store.as_retriever(search_kwargs=search_options)
    return retriever

def build_system_prompt(context: str) -> dict:
    """Assemble the system message that carries the retrieved passages.

    Args:
        context: Text of the retrieved documents; inserted verbatim under the
            "Retrieved Context:" heading.

    Returns:
        A chat message dict with ``role`` set to ``"system"`` and the complete
        prompt text under ``content``.
    """
    role_section = "\n".join([
        "You are an expert medical assistant with access to authoritative medical literature.",
        "",
        "Your role:",
        "- Provide accurate, evidence-based medical information",
        "- Answer questions clearly and comprehensively",
        "- Ask clarifying questions if needed",
        "- Use the context below to support your answers",
        "- Be empathetic and professional",
        "- Remember previous messages in the conversation",
    ])
    instruction_section = "\n".join([
        "Instructions:",
        "- Base your answers on the provided context",
        "- If the context doesn't contain relevant information, acknowledge this",
        "- Structure complex answers with clear organization",
        "- Cite specific information when referencing the context",
    ])
    content = f"{role_section}\n\nRetrieved Context:\n{context}\n\n{instruction_section}\n"
    return {"role": "system", "content": content}

def stream_llm_response(messages):
    """Stream a chat completion from TogetherAI.

    Args:
        messages: Chat messages passed unchanged to the completion request.

    Yields:
        The full reply text received so far, once for every streamed chunk
        that carries non-empty content (cumulative, not the individual delta).
    """
    completion_stream = client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=24096,
        temperature=0.1,
        stream=True,
    )

    accumulated = ""
    for event in completion_stream:
        # Skip events that carry no choices at all.
        if not hasattr(event, 'choices') or len(event.choices) == 0:
            continue

        piece = getattr(event.choices[0].delta, 'content', None)
        if not piece:
            # No text in this delta (missing attribute, None or empty string).
            continue

        accumulated += piece
        yield accumulated

# ============================================================================
# SIDEBAR
# ============================================================================

with st.sidebar:
    st.title("🩺 DocChatter Medical RAG")
    st.markdown("---")

    # Collection selector, pre-selected to the collection stored in session state
    st.subheader("📚 Select Medical Specialty")
    collection_names = list(VECTOR_STORES)
    selected = st.selectbox(
        "Choose a collection:",
        options=collection_names,
        index=collection_names.index(st.session_state.selected_collection),
        key="collection_selector"
    )

    # Persist a changed selection and rerun so the rest of the page is built
    # for the newly chosen collection.
    if selected != st.session_state.selected_collection:
        st.session_state.selected_collection = selected
        st.rerun()

    st.markdown("---")

    # Stats
    st.subheader("📊 Session Info")
    st.metric("Messages", len(st.session_state.chat_history))
    st.metric("Current Collection", selected)

    st.markdown("---")

    # Clear button: empties the conversation and reruns to redraw the page
    if st.button("🗑️ Clear Chat History", use_container_width=True):
        st.session_state.chat_history = []
        st.rerun()

    st.markdown("---")
    st.caption("Powered by TogetherAI & LangChain")

# ============================================================================
# MAIN CHAT INTERFACE
# ============================================================================

st.title("💬 Medical Document Chat")
st.caption(f"Currently using: **{st.session_state.selected_collection}** collection")

# Load retriever for selected collection
config = VECTOR_STORES[st.session_state.selected_collection]
retriever = load_vectorstore(
    embeddings,
    config["collection_name"],
    config["persist_directory"]
)

def _render_copy_button(text, suffix):
    """Render a one-click copy-to-clipboard button for *text*.

    The same markup serves replayed history messages and the freshly streamed
    reply; *suffix* only distinguishes the element ids and the JavaScript
    function name (``copyText_<suffix>``).

    Args:
        text: Message content to copy. It is HTML-escaped before being
            embedded in the hidden textarea.
        suffix: Value interpolated into the ids and function name (the
            history index, or ``"new"`` for the reply being streamed).
    """
    escaped_content = html.escape(text)

    copy_html = f"""
        <div style="margin-top: 10px;">
            <button onclick="copyText_{suffix}()" style="
                background: transparent;
                border: 1px solid rgba(250, 250, 250, 0.2);
                border-radius: 4px;
                cursor: pointer;
                font-size: 1.2rem;
                padding: 0.25rem 0.5rem;
                opacity: 0.7;
                transition: opacity 0.2s;
            " onmouseover="this.style.opacity='1'" onmouseout="this.style.opacity='0.7'" title="Copy to clipboard">
                📋
            </button>
            <span id="status_{suffix}" style="margin-left: 10px; color: green; display: none;">✓ Copied!</span>
            <textarea id="text_{suffix}" style="position: absolute; left: -9999px;">{escaped_content}</textarea>
        </div>

        <script>
        function copyText_{suffix}() {{
            const textarea = document.getElementById('text_{suffix}');
            textarea.select();
            document.execCommand('copy');

            // Also write through the modern clipboard API when it is available
            if (navigator.clipboard) {{
                navigator.clipboard.writeText(textarea.value);
            }}

            const status = document.getElementById('status_{suffix}');
            status.style.display = 'inline';
            setTimeout(() => status.style.display = 'none', 2000);
        }}
        </script>
    """
    st.components.v1.html(copy_html, height=50)


# Display chat history
for i, message in enumerate(st.session_state.chat_history):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

        # Add one-click copy button for assistant messages
        if message["role"] == "assistant":
            _render_copy_button(message["content"], i)

# Chat input
user_input = st.chat_input("Ask me anything about medical topics...")

if user_input:
    # Add user message
    st.session_state.chat_history.append({
        "role": "user",
        "content": user_input
    })

    # Display user message
    with st.chat_message("user"):
        st.markdown(user_input)

    # Retrieve relevant documents
    with st.spinner("🔍 Searching medical literature..."):
        try:
            docs = retriever.invoke(user_input)
        except AttributeError:
            # Fallback for retrievers that do not provide `invoke`
            # (presumably older LangChain releases). Any other retrieval
            # error propagates instead of being hidden by a second attempt.
            docs = retriever.get_relevant_documents(user_input)

        context = "\n\n---\n\n".join(doc.page_content for doc in docs)

    # Build messages for LLM: system prompt first, then the whole conversation
    # (which already ends with the question just asked)
    messages = [build_system_prompt(context)]
    messages.extend(
        {"role": msg["role"], "content": msg["content"]}
        for msg in st.session_state.chat_history
    )

    # Stream assistant response
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        full_response = ""

        # Each yielded value is the cumulative text so far; redraw it with a
        # trailing cursor while the stream is still running.
        for response_chunk in stream_llm_response(messages):
            full_response = response_chunk
            response_placeholder.markdown(full_response + "▌")

        response_placeholder.markdown(full_response)

        # Add one-click copy button for the new response
        _render_copy_button(full_response, "new")

    # Save assistant response
    st.session_state.chat_history.append({
        "role": "assistant",
        "content": full_response
    })

    # Rerun so the page is redrawn from the saved history
    st.rerun()

# ============================================================================
# FOOTER
# ============================================================================

st.markdown("---")
st.caption("⚠️ This is an AI assistant. Always consult qualified healthcare professionals for medical advice.")