# medchat2 / app.py
# Streamlit medical RAG chat application (HuggingFace Space: Sbnos/medchat2)
# Revision: 7160ad6 (verified)
import os
import sys
import html

# ============================================================================
# CRITICAL: Set API key BEFORE importing Together
# ============================================================================
# The Together SDK reads TOGETHER_API_KEY from the environment when the client
# is constructed, so the key must be exported before `from together import
# Together` below. This deployment stores the secret under the custom name
# 'pilotikval'.
API_KEY = os.environ.get("pilotikval")
if not API_KEY:
    # Fail fast with a console message: the app cannot function without a key.
    print("❌ Missing 'pilotikval' environment variable. Please set your TogetherAI API key.")
    sys.exit(1)

# Mirror the secret into the variable name the Together client expects.
os.environ["TOGETHER_API_KEY"] = API_KEY

# NOW import Together and other dependencies
import streamlit as st
import streamlit.components.v1
from together import Together
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import snapshot_download
# ============================================================================
# CONFIGURATION
# ============================================================================
# Your HuggingFace dataset repository containing all vector stores
DATASET_REPO = "Sbnos/vstoryies"

# Vector store configurations: maps each medical specialty (as shown in the
# sidebar selector) to its Chroma collection name and the on-disk directory
# (downloaded from DATASET_REPO) that persists the index.
VECTOR_STORES = {
    "General Medicine": {
        "collection_name": "oxfordmed",
        "persist_directory": "oxfordmedbookdir"
    },
    "Paediatrics": {
        "collection_name": "paedia",
        "persist_directory": "nelsonpaedia"
    },
    "Respiratory": {
        "collection_name": "respmurraynotes",
        "persist_directory": "respmurray"
    },
    "Dermatology": {
        "collection_name": "derma",
        "persist_directory": "rookderma"
    },
    "Endocrine": {
        "collection_name": "endocrine",
        "persist_directory": "williamsendocrine"
    },
    "Gastroenterology": {
        "collection_name": "gastro",
        "persist_directory": "yamadagastro"
    },
    "Surgery": {
        "collection_name": "gensurgery",
        "persist_directory": "baileysurgery"
    },
    "Neurology": {
        "collection_name": "neuro",
        "persist_directory": "bradleyneuro"
    },
    "Cardiology": {
        "collection_name": "cardiobraun",
        "persist_directory": "braunwaldcardiofin"
    },
    "Nephrology": {
        "collection_name": "nephro",
        "persist_directory": "brennernephro"
    },
    "Orthopedics": {
        "collection_name": "oportho",
        "persist_directory": "campbellorthop"
    },
    "Rheumatology": {
        "collection_name": "rheumatology",
        "persist_directory": "firesteinrheumatology"
    }
}

# Model configurations
EMBED_MODEL = "BAAI/bge-base-en"                 # embedding model for retrieval
LLM_MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct"   # chat model served by TogetherAI
RETRIEVAL_K = 26                                 # document chunks fetched per query
# ============================================================================
# PAGE CONFIG
# ============================================================================
st.set_page_config(
    page_title="DocChatter Medical RAG",
    page_icon="🩺",
    layout="wide"
)

# ============================================================================
# INITIALIZATION
# ============================================================================
# Initialize TogetherAI client. Together() reads TOGETHER_API_KEY from the
# environment (exported above); on failure, halt the app with a visible error.
try:
    client = Together()
except Exception as e:
    st.error(f"❌ Failed to initialize Together client: {e}")
    st.stop()
# Download all vector stores from HuggingFace dataset on first run
@st.cache_resource
def download_all_vectorstores():
    """Download every Chroma vector store from the HuggingFace dataset repo.

    Wrapped in st.cache_resource so the check/download runs at most once per
    server process. On download failure the app is halted with an error banner.
    """
    # BUGFIX: this previously used `not any(...)`, which skipped the download
    # whenever even ONE store directory existed, leaving the others missing.
    # Download whenever at least one configured directory is absent.
    if not all(os.path.exists(config["persist_directory"]) for config in VECTOR_STORES.values()):
        with st.spinner("πŸ“₯ Downloading vector stores from HuggingFace (one-time setup)..."):
            try:
                # Pull the whole dataset snapshot into the working directory;
                # each persist_directory is a top-level folder in the repo.
                snapshot_download(
                    repo_id=DATASET_REPO,
                    repo_type="dataset",
                    local_dir=".",
                    allow_patterns=["*"]
                )
                st.success("βœ… Vector stores downloaded successfully!")
            except Exception as e:
                st.error(f"❌ Failed to download vector stores: {e}")
                st.stop()

# Download vector stores if needed
download_all_vectorstores()
# Initialize embeddings (built once per process via st.cache_resource)
@st.cache_resource
def get_embeddings():
    """Return the shared sentence-embedding model, normalized for cosine use."""
    encode_opts = {"normalize_embeddings": True}
    model = HuggingFaceEmbeddings(model_name=EMBED_MODEL, encode_kwargs=encode_opts)
    return model

embeddings = get_embeddings()
# ============================================================================
# SESSION STATE
# ============================================================================
# Seed per-session defaults exactly once; later reruns keep existing values.
for _key, _default in (
    ("chat_history", []),
    ("selected_collection", next(iter(VECTOR_STORES))),
):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
@st.cache_resource
def load_vectorstore(_embeddings, collection_name, persist_directory):
    """Open an on-disk Chroma collection and wrap it as a top-k retriever.

    The leading underscore on `_embeddings` tells Streamlit's cache not to
    hash that argument; caching keys on the two string parameters instead.
    """
    store = Chroma(
        collection_name=collection_name,
        persist_directory=persist_directory,
        embedding_function=_embeddings,
    )
    return store.as_retriever(search_kwargs={"k": RETRIEVAL_K})
def build_system_prompt(context: str) -> dict:
    """Compose the system message, embedding the retrieved passages.

    Returns a chat-message dict of the form {"role": "system", "content": ...}.
    """
    content = f"""You are an expert medical assistant with access to authoritative medical literature.
Your role:
- Provide accurate, evidence-based medical information
- Answer questions clearly and comprehensively
- Ask clarifying questions if needed
- Use the context below to support your answers
- Be empathetic and professional
- Remember previous messages in the conversation
Retrieved Context:
{context}
Instructions:
- Base your answers on the provided context
- If the context doesn't contain relevant information, acknowledge this
- Structure complex answers with clear organization
- Cite specific information when referencing the context
"""
    return {"role": "system", "content": content}
def stream_llm_response(messages):
    """Yield the progressively accumulated assistant reply while streaming.

    Generator: each yielded value is the full text received so far, so the
    caller can simply overwrite its placeholder with the latest value.
    """
    accumulated = ""
    completion = client.chat.completions.create(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=24096,
        temperature=0.1,
        stream=True,
    )
    for chunk in completion:
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        piece = getattr(choices[0].delta, "content", None)
        if piece:
            accumulated += piece
            yield accumulated
# ============================================================================
# SIDEBAR
# ============================================================================
with st.sidebar:
    st.title("🩺 DocChatter Medical RAG")
    st.markdown("---")

    # Collection selector — index pins the widget to the persisted choice so
    # it survives reruns.
    st.subheader("πŸ“š Select Medical Specialty")
    selected = st.selectbox(
        "Choose a collection:",
        options=list(VECTOR_STORES.keys()),
        index=list(VECTOR_STORES.keys()).index(st.session_state.selected_collection),
        key="collection_selector"
    )
    # Persist a changed choice and rerun so the main pane loads the matching
    # retriever immediately.
    if selected != st.session_state.selected_collection:
        st.session_state.selected_collection = selected
        st.rerun()

    st.markdown("---")

    # Stats
    st.subheader("πŸ“Š Session Info")
    st.metric("Messages", len(st.session_state.chat_history))
    st.metric("Current Collection", selected)

    st.markdown("---")

    # Clear button — wipes only the conversation; cached resources stay loaded.
    if st.button("πŸ—‘οΈ Clear Chat History", use_container_width=True):
        st.session_state.chat_history = []
        st.rerun()

    st.markdown("---")
    st.caption("Powered by TogetherAI & LangChain")
# ============================================================================
# MAIN CHAT INTERFACE
# ============================================================================
st.title("πŸ’¬ Medical Document Chat")
st.caption(f"Currently using: **{st.session_state.selected_collection}** collection")

# Load retriever for selected collection (cached per collection name/dir pair)
config = VECTOR_STORES[st.session_state.selected_collection]
retriever = load_vectorstore(
    embeddings,
    config["collection_name"],
    config["persist_directory"]
)

# Display chat history
for i, message in enumerate(st.session_state.chat_history):
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        # Add one-click copy button for assistant messages
        if message["role"] == "assistant":
            # Escape so arbitrary reply text is safe inside the hidden <textarea>.
            escaped_content = html.escape(message["content"])
            # Ids are suffixed with the message index so each rendered copy
            # widget (own iframe via components.html) stays independent.
            copy_html = f"""
<div style="margin-top: 10px;">
<button onclick="copyText_{i}()" style="
background: transparent;
border: 1px solid rgba(250, 250, 250, 0.2);
border-radius: 4px;
cursor: pointer;
font-size: 1.2rem;
padding: 0.25rem 0.5rem;
opacity: 0.7;
transition: opacity 0.2s;
" onmouseover="this.style.opacity='1'" onmouseout="this.style.opacity='0.7'" title="Copy to clipboard">
πŸ“‹
</button>
<span id="status_{i}" style="margin-left: 10px; color: green; display: none;">βœ“ Copied!</span>
<textarea id="text_{i}" style="position: absolute; left: -9999px;">{escaped_content}</textarea>
</div>
<script>
function copyText_{i}() {{
const textarea = document.getElementById('text_{i}');
textarea.select();
document.execCommand('copy');
// Try modern API as fallback
if (navigator.clipboard) {{
navigator.clipboard.writeText(textarea.value);
}}
const status = document.getElementById('status_{i}');
status.style.display = 'inline';
setTimeout(() => status.style.display = 'none', 2000);
}}
</script>
"""
            st.components.v1.html(copy_html, height=50)
# Chat input — main request/response cycle for one user turn.
user_input = st.chat_input("Ask me anything about medical topics...")

if user_input:
    # Record and echo the user's message.
    st.session_state.chat_history.append({
        "role": "user",
        "content": user_input
    })
    with st.chat_message("user"):
        st.markdown(user_input)

    # Retrieve relevant documents for this query.
    with st.spinner("πŸ” Searching medical literature..."):
        try:
            docs = retriever.invoke(user_input)
        except AttributeError:
            # Older LangChain retrievers predate `.invoke`; fall back to the
            # legacy API. (This replaces a bare `except:` that also swallowed
            # KeyboardInterrupt/SystemExit and masked genuine retrieval errors.)
            docs = retriever.get_relevant_documents(user_input)
    context = "\n\n---\n\n".join(doc.page_content for doc in docs)

    # Build messages for the LLM: system prompt with retrieved context, then
    # the full conversation (the just-appended user turn is already included).
    messages = [build_system_prompt(context)]
    for msg in st.session_state.chat_history:
        messages.append({
            "role": msg["role"],
            "content": msg["content"]
        })

    # Stream the assistant response, updating the placeholder as text arrives.
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        full_response = ""
        for response_chunk in stream_llm_response(messages):
            full_response = response_chunk
            response_placeholder.markdown(full_response + "β–Œ")
        response_placeholder.markdown(full_response)

    # Save the assistant response, then rerun. The rerun re-executes the script
    # so the new message is rendered by the history loop above, which attaches
    # the copy-to-clipboard widget — the duplicated inline copy-button HTML
    # that used to live here was immediately discarded by st.rerun() and has
    # been removed.
    st.session_state.chat_history.append({
        "role": "assistant",
        "content": full_response
    })
    st.rerun()
# ============================================================================
# FOOTER
# ============================================================================
st.markdown("---")
# Safety disclaimer rendered at the bottom of every page run.
st.caption("⚠️ This is an AI assistant. Always consult qualified healthcare professionals for medical advice.")