File size: 7,599 Bytes
3530638
6dffeff
3530638
 
 
 
 
1aa1e51
 
 
3530638
 
 
 
 
 
 
 
 
1aa1e51
 
 
 
3530638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dffeff
3530638
 
 
 
 
 
6dffeff
 
 
 
 
 
 
 
 
 
 
 
3530638
6dffeff
3530638
 
 
 
 
 
1aa1e51
 
 
 
 
 
 
 
 
 
 
 
b9db63a
3530638
 
 
 
b9db63a
3530638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1aa1e51
3530638
 
1aa1e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3530638
 
 
 
 
 
 
 
 
6dffeff
3530638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1aa1e51
 
 
 
 
 
 
 
 
 
3530638
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import streamlit as st
import Rag
from openai import OpenAI
from together import Together
import time
import os
from dotenv import load_dotenv
from google_sheets_uploader import upload_to_google_sheets
import pandas as pd
import json

load_dotenv()

# Seed every session-state key with its default on the first run only;
# Streamlit reruns this script top-to-bottom on each interaction, so the
# membership check keeps existing state intact across reruns.
_SESSION_DEFAULTS = {
    "embedder_loaded": False,        # whether an embedding model is loaded
    "current_embedder_name": None,   # name of the currently loaded embedder
    "last_sources": [],              # retrieval results from the last query
    "session_data": [],              # one row per Q/A turn, for upload
    "uploaded_rows_count": 0,        # how many session_data rows were uploaded
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value

# Basic page chrome; the "wide" layout leaves room for the side-by-side
# chat / sources columns created further down.
st.set_page_config(
    page_title="Bipolar Assistant Chatbot",
    page_icon=":robot_face:",
    layout="wide",
    initial_sidebar_state="collapsed"
)

# Embedding models offered in the sidebar selector. "Other" switches the
# widget to a free-form text input for an arbitrary model name.
model_options = [
    "Qwen/Qwen3-Embedding-0.6B",
    "jinaai/jina-embeddings-v3",
    "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
    "Other"
]

st.sidebar.title("Settings")
with st.sidebar:
    st.subheader("Model Selection")
    embedder_name = st.selectbox("Select embedder model", model_options, index=0)

    # "Other" replaces the dropdown choice with a free-form model name;
    # st.text_input returns "" until the user actually types something.
    if embedder_name == "Other":
        embedder_name = st.text_input('Enter the embedder model name')

    if not embedder_name:
        # BUG FIX: previously an empty name (user picked "Other" but has not
        # typed anything yet) fell through and triggered a load attempt with
        # embedder_name="" on every rerun. Prompt instead of loading.
        st.warning("Please enter an embedder model name.")
    elif (not st.session_state.embedder_loaded or
            st.session_state.current_embedder_name != embedder_name):
        # Load (or switch to) the embedding model only when the selection
        # actually changed, to avoid reloading on every rerun.
        with st.spinner(f"Loading embedding model: {embedder_name}..."):
            Rag.launch_depression_assistant(embedder_name=embedder_name)
            st.session_state.embedder_loaded = True
            st.session_state.current_embedder_name = embedder_name
            st.success(f"โœ… Embedding model {embedder_name} loaded successfully!")
    else:
        st.info(f"๐Ÿ“‹ Current embedding model: {st.session_state.current_embedder_name}")

    # Offer generation models appropriate to whichever LLM client Rag set up.
    if isinstance(Rag.llm_client, OpenAI):
        # NVIDIA client
        model_list = ["openai/gpt-oss-20b"]
    elif isinstance(Rag.llm_client, Together):
        # Together client
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                      "deepseek-ai/deepseek-r1",
                      "meta/llama-3.3-70b-instruct"]
    else:
        # Default or unknown client
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"]

    selected_model = st.selectbox('Choose a model for generation',
                                  model_list,
                                  key='selected_model')

    # Sampling / length controls, passed straight through to the generator.
    temperature = st.slider('temperature', min_value=0.01, max_value=1.0, value=0.05, step=0.01)
    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    max_length = st.slider('max_length', min_value=100, max_value=1000, value=500, step=10)

    # Upload only the rows added since the last successful upload, so
    # repeated clicks do not duplicate data in the sheet.
    if st.button("Save and Upload to Google Sheets"):
        try:
            new_data = st.session_state.session_data[st.session_state.uploaded_rows_count:]
            if new_data:
                upload_to_google_sheets(new_data)
                st.session_state.uploaded_rows_count = len(st.session_state.session_data)
                st.success("Successfully uploaded to Google Sheets!")
            else:
                st.info("No new data to upload.")
        except Exception as e:
            st.error(f"An error occurred: {e}")

st.markdown("## ๐Ÿ’ฌ Bipolar Assistant Chatbot")

# Seed the conversation with a single assistant welcome message; the check
# means this happens only once per browser session, not on every rerun.
if "messages" not in st.session_state:
    welcome_text = "Welcome to a prototype of the open-source and open-weight CANMAT/ISBD 2018 Bipolar Guideline chatbot. Please try asking it questions that can be answered by the guidelines. Improvements are ongoing - the visual aspect will change substantially soon. Please let John-Jose know any feedback at johnjose.nunez@ubc.ca. Thanks!"
    st.session_state.messages = [{"role": "assistant", "content": welcome_text}]

chat_col, sources_col = st.columns([1, 1])

with sources_col:
    st.markdown("### Sources")
    # A placeholder so the panel can be refreshed in place after a new query.
    sources_placeholder = st.empty()

    with sources_placeholder.container():
        retrieved = st.session_state.last_sources
        if not retrieved:
            st.markdown("*Sources will appear here after you ask a question.*")
        else:
            # Render each retrieved passage with its similarity score.
            for rank, hit in enumerate(retrieved, start=1):
                st.markdown(f"**Source {rank}** | Similarity: {hit.get('similarity', 'N/A')}")
                st.markdown(f"- **Section:** {hit['section']}")
                st.markdown(f"> {hit['text']}")
                st.markdown("---")

with chat_col:
    # Replay the conversation. Assistant replies (other than the welcome
    # message at index 0) get feedback radios whose values are written back
    # into the matching session_data row.
    for i, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            if message["role"] == "assistant" and i > 0:
                # messages are [welcome, user, assistant, user, assistant, ...]
                # so the assistant at index i maps to session_data row i//2 - 1.
                data_idx = i // 2 - 1
                feedback_options = ["Good", "Bad", "Neutral"]
                answer_feedback = st.radio(
                    "Rate your answer:",
                    options=feedback_options,
                    index=2,  # Default to Neutral
                    key=f"answer_feedback_{i}",
                    horizontal=True,
                )
                source_feedback = st.radio(
                    "Rate your sources:",
                    options=feedback_options,
                    index=2,  # Default to Neutral
                    key=f"source_feedback_{i}",
                    horizontal=True,
                )
                # BUG FIX: the original indexed session_data unconditionally;
                # if messages and session_data ever desynchronize (e.g. an
                # exception mid-generation leaves an unpaired message), that
                # raised IndexError and broke the whole page.
                if 0 <= data_idx < len(st.session_state.session_data):
                    st.session_state.session_data[data_idx]["feedback"] = answer_feedback
                    st.session_state.session_data[data_idx]["source_feedback"] = source_feedback

    if user_input := st.chat_input("Ask me questions about the CANMAT bipolar guideline!"):
        # Echo the question immediately and record it in the conversation.
        st.chat_message("user").markdown(user_input)
        st.session_state.messages.append({"role": "user", "content": user_input})

        # Up to the last four prior messages, excluding the new question.
        history = st.session_state.messages[:-1][-4:]

        start_time = time.perf_counter()
        results, response = Rag.depression_assistant(user_input, model_name=selected_model, max_tokens=max_length,
                                                 temperature=temperature, top_p=top_p, stream_flag=True,
                                                 chat_history=history)

        st.session_state.last_sources = results if results else []

        # Refresh the sources panel in place with the newly retrieved passages.
        with sources_placeholder.container():
            retrieved = st.session_state.last_sources
            if not retrieved:
                st.markdown("*Sources will appear here after you ask a question.*")
            else:
                for rank, hit in enumerate(retrieved, start=1):
                    st.markdown(f"**Source {rank}** | Similarity: {hit.get('similarity', 'N/A')}")
                    st.markdown(f"- **Section:** {hit['section']}")
                    st.markdown(f"> {hit['text']}")
                    st.markdown("---")

        # Stream the answer chunk-by-chunk into a single placeholder.
        answer_text = ""
        answer_placeholder = st.empty()
        for chunk in response:
            answer_text += chunk
            answer_placeholder.markdown(answer_text)

        elapsed = time.perf_counter() - start_time
        print(f"[Time] Retriever + Generator takes: {elapsed:.2f} seconds in total.")
        print(f"============== Finish R-A-Generation for Current Query {user_input} ==============")

        st.session_state.messages.append({"role": "assistant", "content": answer_text})

        # Record the full turn for later upload; feedback defaults to Neutral
        # and is overwritten by the radio widgets on subsequent reruns.
        st.session_state.session_data.append(
            {
                "query": user_input,
                "response": answer_text,
                "sources": json.dumps(st.session_state.last_sources, indent=4),
                "feedback": "Neutral",
                "source_feedback": "Neutral",
            }
        )

        # Rerun so the streamed answer is re-rendered through the normal
        # history loop (with its feedback widgets).
        st.rerun()