|
|
import os |
|
|
import sys |
|
|
import html |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pull the TogetherAI key from the 'pilotikval' environment variable and
# expose it under the name the `together` SDK expects. Abort early with a
# clear message when it is absent.
API_KEY = os.environ.get("pilotikval")

if API_KEY:
    # The Together() client reads TOGETHER_API_KEY implicitly.
    os.environ["TOGETHER_API_KEY"] = API_KEY
else:
    print("β Missing 'pilotikval' environment variable. Please set your TogetherAI API key.")
    sys.exit(1)
|
|
|
|
|
|
|
|
import streamlit as st |
|
|
import streamlit.components.v1 |
|
|
from together import Together |
|
|
from langchain_community.vectorstores import Chroma |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# HuggingFace dataset repo that hosts the pre-built Chroma vector stores.
DATASET_REPO = "Sbnos/vstoryies"

# UI specialty label -> Chroma collection name and on-disk persist directory.
# Directories are relative to the working directory, matching the layout
# produced by snapshot_download(DATASET_REPO, local_dir=".").
VECTOR_STORES = {
    "General Medicine": {
        "collection_name": "oxfordmed",
        "persist_directory": "oxfordmedbookdir"
    },
    "Paediatrics": {
        "collection_name": "paedia",
        "persist_directory": "nelsonpaedia"
    },
    "Respiratory": {
        "collection_name": "respmurraynotes",
        "persist_directory": "respmurray"
    },
    "Dermatology": {
        "collection_name": "derma",
        "persist_directory": "rookderma"
    },
    "Endocrine": {
        "collection_name": "endocrine",
        "persist_directory": "williamsendocrine"
    },
    "Gastroenterology": {
        "collection_name": "gastro",
        "persist_directory": "yamadagastro"
    },
    "Surgery": {
        "collection_name": "gensurgery",
        "persist_directory": "baileysurgery"
    },
    "Neurology": {
        "collection_name": "neuro",
        "persist_directory": "bradleyneuro"
    },
    "Cardiology": {
        "collection_name": "cardiobraun",
        "persist_directory": "braunwaldcardiofin"
    },
    "Nephrology": {
        "collection_name": "nephro",
        "persist_directory": "brennernephro"
    },
    "Orthopedics": {
        "collection_name": "oportho",
        "persist_directory": "campbellorthop"
    },
    "Rheumatology": {
        "collection_name": "rheumatology",
        "persist_directory": "firesteinrheumatology"
    }
}

# Query-time embedding model — presumably the same model the stores were
# built with (TODO confirm; a mismatch silently degrades retrieval).
EMBED_MODEL = "BAAI/bge-base-en"
# TogetherAI chat model used to generate answers.
LLM_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct"
# Number of chunks retrieved per query.
RETRIEVAL_K = 26
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page chrome; must be the first Streamlit call in the script.
st.set_page_config(
    page_title="DocChatter Medical RAG",
    page_icon="π©Ί",
    layout="wide"
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Initialize the TogetherAI client (reads TOGETHER_API_KEY from the
# environment, set above). Halt the app with a visible error on failure.
try:
    client = Together()
except Exception as e:
    st.error(f"β Failed to initialize Together client: {e}")
    st.stop()
|
|
|
|
|
|
|
|
@st.cache_resource
def download_all_vectorstores():
    """Download all vector stores from the HuggingFace dataset repository.

    One-time setup: if none of the configured persist directories exist
    locally, snapshot the entire dataset repo into the working directory.
    Cached with st.cache_resource so reruns skip the filesystem check.
    Stops the app with an error message if the download fails.
    """
    if not any(os.path.exists(config["persist_directory"]) for config in VECTOR_STORES.values()):
        with st.spinner("π₯ Downloading vector stores from HuggingFace (one-time setup)..."):
            try:
                snapshot_download(
                    repo_id=DATASETREPO if False else DATASET_REPO,  # noqa: keep single source of truth
                    repo_type="dataset",
                    local_dir=".",
                    allow_patterns=["*"]
                )
                # Fix: this success message was previously split across two
                # physical lines inside the string literal (a SyntaxError);
                # rejoined onto one line.
                st.success("β Vector stores downloaded successfully!")
            except Exception as e:
                st.error(f"β Failed to download vector stores: {e}")
                st.stop()


download_all_vectorstores()
|
|
|
|
|
|
|
|
@st.cache_resource
def get_embeddings():
    """Create (and cache) the HuggingFace embedding model used for retrieval.

    Embeddings are L2-normalized so similarity search behaves like cosine
    similarity.
    """
    encode_opts = {"normalize_embeddings": True}
    return HuggingFaceEmbeddings(model_name=EMBED_MODEL, encode_kwargs=encode_opts)


embeddings = get_embeddings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Seed per-session state on first run: an empty conversation and the first
# specialty in VECTOR_STORES as the default collection.
_session_defaults = {
    "chat_history": [],
    "selected_collection": next(iter(VECTOR_STORES)),
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def load_vectorstore(_embeddings, collection_name, persist_directory):
    """Open a persisted Chroma collection and wrap it as a retriever.

    The leading underscore on `_embeddings` tells st.cache_resource not to
    hash that argument; caching keys on the collection/directory strings.
    Returns a retriever configured to fetch RETRIEVAL_K chunks per query.
    """
    store = Chroma(
        collection_name=collection_name,
        embedding_function=_embeddings,
        persist_directory=persist_directory,
    )
    return store.as_retriever(search_kwargs={"k": RETRIEVAL_K})
|
|
|
|
|
def build_system_prompt(context: str) -> dict:
    """Return a system-role chat message embedding the retrieved context.

    The returned dict has the OpenAI-style shape {"role", "content"} and is
    placed first in the message list sent to the LLM.
    """
    return {
        "role": "system",
        "content": f"""You are an expert medical assistant with access to authoritative medical literature.

Your role:
- Provide accurate, evidence-based medical information
- Answer questions clearly and comprehensively
- Ask clarifying questions if needed
- Use the context below to support your answers
- Be empathetic and professional
- Remember previous messages in the conversation

Retrieved Context:
{context}

Instructions:
- Base your answers on the provided context
- If the context doesn't contain relevant information, acknowledge this
- Structure complex answers with clear organization
- Cite specific information when referencing the context
""",
    }
|
|
|
|
|
def stream_llm_response(messages):
    """Stream a completion from TogetherAI, yielding the accumulated text.

    Each yield is the full reply so far (not a delta), so the caller can
    simply re-render the placeholder with the latest value.
    """
    stream = client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=24096,
        temperature=0.1,
        stream=True,
    )

    accumulated = ""
    for chunk in stream:
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        piece = getattr(choices[0].delta, "content", None)
        if piece:
            accumulated += piece
            yield accumulated
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------- Sidebar: specialty picker, session stats, chat controls ----------
with st.sidebar:
    st.title("π©Ί DocChatter Medical RAG")
    st.markdown("---")

    # Specialty selector; the default index mirrors the persisted selection
    # so the widget stays in sync across reruns.
    st.subheader("π Select Medical Specialty")
    selected = st.selectbox(
        "Choose a collection:",
        options=list(VECTOR_STORES.keys()),
        index=list(VECTOR_STORES.keys()).index(st.session_state.selected_collection),
        key="collection_selector"
    )

    # Persist a changed selection and rerun so the main pane loads the
    # matching vector store.
    if selected != st.session_state.selected_collection:
        st.session_state.selected_collection = selected
        st.rerun()

    st.markdown("---")

    # Lightweight session metrics.
    st.subheader("π Session Info")
    st.metric("Messages", len(st.session_state.chat_history))
    st.metric("Current Collection", selected)

    st.markdown("---")

    # Reset the conversation and redraw immediately.
    if st.button("ποΈ Clear Chat History", use_container_width=True):
        st.session_state.chat_history = []
        st.rerun()

    st.markdown("---")
    st.caption("Powered by TogetherAI & LangChain")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------- Main chat pane ----------
st.title("π¬ Medical Document Chat")
st.caption(f"Currently using: **{st.session_state.selected_collection}** collection")

# Resolve the retriever for the active specialty (cached per collection
# by load_vectorstore's st.cache_resource).
config = VECTOR_STORES[st.session_state.selected_collection]
retriever = load_vectorstore(
    embeddings,
    config["collection_name"],
    config["persist_directory"]
)
|
|
|
|
|
|
|
|
# Replay the stored conversation. Assistant messages additionally get a
# per-message copy-to-clipboard button rendered as a small HTML component;
# the message index `i` keeps the generated element ids/functions unique.
for i, message in enumerate(st.session_state.chat_history):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

        if message["role"] == "assistant":
            # Escape the text so it survives embedding in the hidden
            # <textarea> used as the clipboard source.
            escaped_content = html.escape(message["content"])

            copy_html = f"""
            <div style="margin-top: 10px;">
                <button onclick="copyText_{i}()" style="
                    background: transparent;
                    border: 1px solid rgba(250, 250, 250, 0.2);
                    border-radius: 4px;
                    cursor: pointer;
                    font-size: 1.2rem;
                    padding: 0.25rem 0.5rem;
                    opacity: 0.7;
                    transition: opacity 0.2s;
                " onmouseover="this.style.opacity='1'" onmouseout="this.style.opacity='0.7'" title="Copy to clipboard">
                    π
                </button>
                <span id="status_{i}" style="margin-left: 10px; color: green; display: none;">β Copied!</span>
                <textarea id="text_{i}" style="position: absolute; left: -9999px;">{escaped_content}</textarea>
            </div>

            <script>
            function copyText_{i}() {{
                const textarea = document.getElementById('text_{i}');
                textarea.select();
                document.execCommand('copy');

                // Try modern API as fallback
                if (navigator.clipboard) {{
                    navigator.clipboard.writeText(textarea.value);
                }}

                const status = document.getElementById('status_{i}');
                status.style.display = 'inline';
                setTimeout(() => status.style.display = 'none', 2000);
            }}
            </script>
            """
            st.components.v1.html(copy_html, height=50)
|
|
|
|
|
|
|
|
user_input = st.chat_input("Ask me anything about medical topics...")

# Handle a new question: record it, retrieve context, stream the answer,
# attach a copy button, persist the reply, and rerun to redraw the history.
if user_input:
    # Record and echo the user's message.
    st.session_state.chat_history.append({
        "role": "user",
        "content": user_input
    })

    with st.chat_message("user"):
        st.markdown(user_input)

    # Retrieve supporting passages for this question.
    with st.spinner("π Searching medical literature..."):
        try:
            docs = retriever.invoke(user_input)
        except AttributeError:
            # Fix: was a bare `except:` (which also swallows SystemExit /
            # KeyboardInterrupt). Older LangChain retrievers expose
            # get_relevant_documents() instead of invoke(); AttributeError
            # is the actual compatibility case.
            docs = retriever.get_relevant_documents(user_input)

    context = "\n\n---\n\n".join([doc.page_content for doc in docs])

    # System prompt with context first, then the full conversation so the
    # model retains multi-turn memory.
    messages = [build_system_prompt(context)]
    for msg in st.session_state.chat_history:
        messages.append({
            "role": msg["role"],
            "content": msg["content"]
        })

    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        full_response = ""

        # Each chunk is the accumulated reply so far; re-render with a
        # trailing cursor glyph while streaming.
        for response_chunk in stream_llm_response(messages):
            full_response = response_chunk
            response_placeholder.markdown(full_response + "β")

        response_placeholder.markdown(full_response)

        # Copy-to-clipboard widget for the freshly streamed reply; ids use
        # the "_new" suffix to avoid clashing with the history loop's ids.
        escaped_content_new = html.escape(full_response)

        copy_html_new = f"""
        <div style="margin-top: 10px;">
            <button onclick="copyText_new()" style="
                background: transparent;
                border: 1px solid rgba(250, 250, 250, 0.2);
                border-radius: 4px;
                cursor: pointer;
                font-size: 1.2rem;
                padding: 0.25rem 0.5rem;
                opacity: 0.7;
                transition: opacity 0.2s;
            " onmouseover="this.style.opacity='1'" onmouseout="this.style.opacity='0.7'" title="Copy to clipboard">
                π
            </button>
            <span id="status_new" style="margin-left: 10px; color: green; display: none;">β Copied!</span>
            <textarea id="text_new" style="position: absolute; left: -9999px;">{escaped_content_new}</textarea>
        </div>

        <script>
        function copyText_new() {{
            const textarea = document.getElementById('text_new');
            textarea.select();
            document.execCommand('copy');

            // Try modern API as fallback
            if (navigator.clipboard) {{
                navigator.clipboard.writeText(textarea.value);
            }}

            const status = document.getElementById('status_new');
            status.style.display = 'inline';
            setTimeout(() => status.style.display = 'none', 2000);
        }}
        </script>
        """
        st.components.v1.html(copy_html_new, height=50)

    # Persist the assistant reply, then rerun so the history loop renders
    # it (with its indexed copy button) on the next pass.
    st.session_state.chat_history.append({
        "role": "assistant",
        "content": full_response
    })

    st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Footer disclaimer, shown on every rerun.
st.markdown("---")
st.caption("β οΈ This is an AI assistant. Always consult qualified healthcare professionals for medical advice.")