|
|
import os |
|
|
import streamlit as st |
|
|
from together import Together |
|
|
from langchain.vectorstores import Chroma |
|
|
from langchain.embeddings import HuggingFaceBgeEmbeddings |
|
|
from langchain.chains import ConversationalRetrievalChain |
|
|
|
|
|
|
|
|
# --- Together AI credentials -------------------------------------------------
# The key must come from the environment; fail fast with a visible UI error
# instead of crashing later on the first model call.
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")

if not TOGETHER_API_KEY:
    st.error("Missing TOGETHER_API_KEY environment variable.")
    st.stop()

# Shared Together client used for all streaming chat completions below.
client = Together(api_key=TOGETHER_API_KEY)
|
|
|
|
|
|
|
|
# Embedding model backing the vector store. Embeddings are normalized so that
# Chroma's similarity scores behave like cosine similarity.
EMBED_MODEL_NAME = "BAAI/bge-base-en"

_encode_opts = {"normalize_embeddings": True}
embeddings = HuggingFaceBgeEmbeddings(
    model_name=EMBED_MODEL_NAME,
    encode_kwargs=_encode_opts,
)
|
|
|
|
|
|
|
|
st.sidebar.title("DocChatter RAG")

# Storage configuration per document collection:
#   display name -> Chroma persist directory / Chroma collection name.
dirs = {
    'General Medicine': './oxfordmedbookdir/',
    'RespiratoryFishman': './respfishmandbcud/',
    'RespiratoryMurray': './respmurray/',
    'MedMRCP2': './medmrcp2store/',
    'OldMedicine': './mrcpchromadb/'
}
cols = {
    'General Medicine': 'oxfordmed',
    'RespiratoryFishman': 'fishmannotescud',
    'RespiratoryMurray': 'respmurraynotes',
    'MedMRCP2': 'medmrcp2notes',
    'OldMedicine': 'mrcppassmednotes'
}

# Derive the selectbox options from the mapping keys (insertion order is
# preserved, so the UI order is unchanged). Previously the option list was a
# separate hard-coded literal that could silently drift out of sync with the
# lookup tables above.
collection = st.sidebar.selectbox(
    "Choose a document collection:",
    list(dirs),
)

persist_directory = dirs.get(collection)
collection_name = cols.get(collection)
|
|
|
|
|
|
|
|
# Open the persisted Chroma collection selected in the sidebar and wrap it in
# a retriever returning the top-5 most similar chunks per query.
vectorstore = Chroma(
    collection_name=collection_name,
    embedding_function=embeddings,
    persist_directory=persist_directory,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
|
|
|
|
|
|
|
# System prompt template; {context} is replaced with the retrieved document
# chunks before every model call.
SYSTEM_PROMPT = (
    "You are a helpful assistant for medical professionals. "
    "Use the following context from medical documents to answer the question. "
    "If you don't know, say you don't know."
    "\n\nContext:\n{context}\n"
)
|
|
|
|
|
# Page header. The original title string was mojibake ("๐ฉบ": the UTF-8 bytes
# of the stethoscope emoji decoded as TIS-620); restore the intended emoji.
st.title("🩺 DocChatter RAG (Streaming)")

# Streamlit reruns this script on every interaction, so the transcript must
# live in st.session_state to survive across reruns.
if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []
|
|
|
|
|
|
|
|
chat_tab, clear_tab = st.tabs(["Chat", "Clear History"])

with chat_tab:
    # Replay the stored transcript so the conversation survives Streamlit's
    # script reruns.
    for msg in st.session_state.chat_history:
        st.chat_message(msg['role']).write(msg['content'])

    if prompt := st.chat_input("Ask anything about your docs..."):
        st.chat_message("user").write(prompt)
        st.session_state.chat_history.append({"role": "user", "content": prompt})

        # Retrieve the top-k chunks for this question and splice them into the
        # system prompt as grounding context.
        docs = retriever.get_relevant_documents(prompt)
        context = "\n---\n".join(d.page_content for d in docs)

        # Rebuild the message list each turn: fresh system message with the
        # new context, followed by the full user/assistant history (the just-
        # appended user prompt is already in chat_history).
        system_msg = {"role": "system", "content": SYSTEM_PROMPT.format(context=context)}
        messages = [system_msg]
        for msg in st.session_state.chat_history:
            if msg['role'] in ('user', 'assistant'):
                messages.append(msg)

        response_container = st.chat_message("assistant")
        placeholder = response_container.empty()
        answer = ""

        for token in client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=messages,
            stream=True
        ):
            # Some stream events carry an empty `choices` list, and the final
            # delta's `content` is None; either previously crashed
            # `answer += delta` with a TypeError / IndexError. Skip them.
            if not getattr(token, 'choices', None):
                continue
            delta = token.choices[0].delta.content
            if delta:
                answer += delta
                placeholder.write(answer)

        st.session_state.chat_history.append({"role": "assistant", "content": answer})
|
|
|
|
|
with clear_tab:
    # Button label restored from mojibake ("๐๏ธ" -> wastebasket emoji).
    if st.button("🗑️ Clear chat history"):
        st.session_state.chat_history = []
        # st.experimental_rerun() was removed in Streamlit 1.30; st.rerun()
        # is the supported replacement and exists in every version that also
        # provides st.chat_input/st.chat_message used above.
        st.rerun()
|
|
|