File size: 3,814 Bytes
16cc298
f248256
 
 
9b06b3d
f248256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1903659
16cc298
f248256
 
 
 
 
9b06b3d
fce8d80
f248256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fe6c25
f248256
 
 
 
 
 
8858343
f248256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import streamlit as st
from together import Together
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.chains import ConversationalRetrievalChain

# --- Configuration ---
# The Together API key is read from the environment; without it the app
# shows an error and halts this script run.
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    st.error("Missing TOGETHER_API_KEY environment variable.")
    st.stop()

# Initialize TogetherAI client
client = Together(api_key=TOGETHER_API_KEY)

# Embeddings setup
EMBED_MODEL_NAME = "BAAI/bge-base-en"


@st.cache_resource
def _load_embeddings(model_name: str) -> HuggingFaceBgeEmbeddings:
    """Build the BGE embedding model for *model_name*, once per process.

    Streamlit re-executes this whole script on every user interaction; without
    ``st.cache_resource`` the embedding model would be reconstructed (and its
    weights reloaded) on every chat turn and sidebar change.

    Embeddings are L2-normalized (``normalize_embeddings=True``).
    """
    return HuggingFaceBgeEmbeddings(
        model_name=model_name,
        encode_kwargs={"normalize_embeddings": True},
    )


embeddings = _load_embeddings(EMBED_MODEL_NAME)

# Sidebar: select collection.
# Each display name maps to the Chroma persist directory (``dirs``) and the
# collection name inside it (``cols``). The selectbox options are derived from
# ``dirs`` so the option list cannot drift out of sync with the mappings
# (dicts preserve insertion order, so the display order is unchanged).
dirs = {
    'General Medicine': './oxfordmedbookdir/',
    'RespiratoryFishman': './respfishmandbcud/',
    'RespiratoryMurray': './respmurray/',
    'MedMRCP2': './medmrcp2store/',
    'OldMedicine': './mrcpchromadb/'
}
cols = {
    'General Medicine': 'oxfordmed',
    'RespiratoryFishman': 'fishmannotescud',
    'RespiratoryMurray': 'respmurraynotes',
    'MedMRCP2': 'medmrcp2notes',
    'OldMedicine': 'mrcppassmednotes'
}

st.sidebar.title("DocChatter RAG")
collection = st.sidebar.selectbox(
    "Choose a document collection:",
    list(dirs),
)

persist_directory = dirs.get(collection)
collection_name = cols.get(collection)

# Opening a non-existent path would give an empty store, and every answer
# would silently lack context. Stop with a visible error instead.
if not persist_directory or not os.path.isdir(persist_directory):
    st.error(f"Vector store directory not found for '{collection}': {persist_directory}")
    st.stop()


@st.cache_resource
def _load_vectorstore(name, directory, _embedding_function):
    """Open the persisted Chroma collection *name* stored in *directory*.

    Cached with ``st.cache_resource`` so the store is opened once per
    (collection, directory) pair instead of on every script rerun. The leading
    underscore on ``_embedding_function`` tells Streamlit to exclude that
    argument from the cache key (the embeddings object is not hashable).
    """
    return Chroma(
        collection_name=name,
        persist_directory=directory,
        embedding_function=_embedding_function,
    )


# Load Chroma vector store; the retriever returns the top 5 chunks per query.
vectorstore = _load_vectorstore(collection_name, persist_directory, embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# Template for the system message sent to the model. The ``{context}``
# placeholder is filled with the retrieved document chunks before each call.
SYSTEM_PROMPT = (
    "You are a helpful assistant for medical professionals. "
    "Use the following context from medical documents to answer the question. "
    "If you don't know, say you don't know.\n"
    "\n"
    "Context:\n"
    "{context}\n"
)

st.title("🩺 DocChatter RAG (Streaming)")

# The conversation lives in session state so it survives Streamlit reruns.
# Each entry is a dict with "role" and "content" keys.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Tabs: "Chat" hosts the conversation, "Clear History" resets it.
chat_tab, clear_tab = st.tabs(["Chat", "Clear History"])
with chat_tab:
    # Streamlit reruns the whole script on every interaction, so replay the
    # stored conversation before handling any new input.
    for msg in st.session_state.chat_history:
        st.chat_message(msg['role']).write(msg['content'])

    # User input
    if prompt := st.chat_input("Ask anything about your docs..."):
        st.chat_message("user").write(prompt)
        st.session_state.chat_history.append({"role": "user", "content": prompt})

        # Retrieve chunks relevant to the new question only.
        docs = retriever.get_relevant_documents(prompt)
        context = "\n---\n".join(d.page_content for d in docs)

        # The system message carries the retrieved context. The prior
        # conversation follows it; it already ends with the new user turn
        # appended above. Fresh role/content dicts are sent so nothing else
        # stored in history leaks into the API payload.
        system_msg = {"role": "system", "content": SYSTEM_PROMPT.format(context=context)}
        messages = [system_msg] + [
            {"role": m['role'], "content": m['content']}
            for m in st.session_state.chat_history
            if m['role'] in ('user', 'assistant')
        ]

        # Stream the reply into a single placeholder that is rewritten as
        # tokens arrive.
        placeholder = st.chat_message("assistant").empty()
        answer = ""
        stream = client.chat.completions.create(
            model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
            messages=messages,
            stream=True,
        )
        for chunk in stream:
            # A chunk may have no choices, and a delta's content may be None
            # (streaming APIs commonly send such chunks at the end). The old
            # code did ``answer += None`` / ``choices[0]`` unguarded and crashed.
            choices = getattr(chunk, 'choices', None)
            if not choices:
                continue
            piece = getattr(choices[0].delta, 'content', None)
            if piece:
                answer += piece
                placeholder.write(answer)
        # Save assistant message
        st.session_state.chat_history.append({"role": "assistant", "content": answer})

with clear_tab:
    if st.button("🗑️ Clear chat history"):
        st.session_state.chat_history = []
        # ``st.experimental_rerun`` was superseded by ``st.rerun`` and removed
        # in later Streamlit releases; use the new API when present.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()