|
|
import streamlit as st |
|
|
import Rag |
|
|
from openai import OpenAI |
|
|
from together import Together |
|
|
import time |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from google_sheets_uploader import upload_to_google_sheets |
|
|
import pandas as pd |
|
|
import json |
|
|
|
|
|
# Load API keys and other secrets from a local .env file.
load_dotenv()

# Seed session-state keys on first run so later code can read them freely.
_SESSION_DEFAULTS = {
    "embedder_loaded": False,        # whether an embedding model is loaded
    "current_embedder_name": None,   # name of the currently loaded embedder
    "last_sources": [],              # retrieval results from the latest query
    "session_data": [],              # one row per Q/A turn, for sheet upload
    "uploaded_rows_count": 0,        # rows already pushed to Google Sheets
}
for _key, _default in _SESSION_DEFAULTS.items():
    st.session_state.setdefault(_key, _default)
|
|
|
|
|
# Page chrome: wide layout with the settings sidebar tucked away by default.
_PAGE_CONFIG = dict(
    page_title="Bipolar Assistant Chatbot",
    page_icon=":robot_face:",
    layout="wide",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_CONFIG)
|
|
|
|
|
# Embedding models offered in the sidebar selector; "Other" lets the user
# type an arbitrary model name instead.
model_options = [
    "Qwen/Qwen3-Embedding-0.6B",
    "jinaai/jina-embeddings-v3",
    "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
    "Other"
]
|
|
|
|
|
st.sidebar.title("Settings")
# NOTE(review): the original file's indentation was lost; this assumes all
# settings widgets live inside the sidebar context — confirm against the app.
with st.sidebar:
    st.subheader("Model Selection")
    embedder_name = st.selectbox("Select embedder model", model_options, index=0)

    if embedder_name == "Other":
        embedder_name = st.text_input('Enter the embedder model name')

    # Guard: with "Other" selected and the text box still empty, the original
    # code would try to load an embedder with an empty model name.
    if not embedder_name:
        st.warning("Please enter an embedder model name.")
    elif (not st.session_state.embedder_loaded or
            st.session_state.current_embedder_name != embedder_name):
        # (Re)load only when no embedder is loaded yet or the selection changed.
        with st.spinner(f"Loading embedding model: {embedder_name}..."):
            Rag.launch_depression_assistant(embedder_name=embedder_name)
            st.session_state.embedder_loaded = True
            st.session_state.current_embedder_name = embedder_name
            st.success(f"✅ Embedding model {embedder_name} loaded successfully!")
    else:
        st.info(f"📋 Current embedding model: {st.session_state.current_embedder_name}")

    # Offer generation models matching whichever LLM backend Rag configured.
    if isinstance(Rag.llm_client, OpenAI):
        model_list = ["openai/gpt-oss-20b"]
    elif isinstance(Rag.llm_client, Together):
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                      "deepseek-ai/deepseek-r1",
                      "meta/llama-3.3-70b-instruct"]
    else:
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"]

    selected_model = st.selectbox('Choose a model for generation',
                                  model_list,
                                  key='selected_model')

    # Generation hyperparameters passed through to Rag.depression_assistant.
    temperature = st.slider('temperature', min_value=0.01, max_value=1.0, value=0.05, step=0.01)
    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    max_length = st.slider('max_length', min_value=100, max_value=1000, value=500, step=10)

    if st.button("Save and Upload to Google Sheets"):
        try:
            # Upload only the rows added since the last successful upload so
            # repeated clicks don't duplicate earlier rows in the sheet.
            new_data = st.session_state.session_data[st.session_state.uploaded_rows_count:]
            if new_data:
                upload_to_google_sheets(new_data)
                st.session_state.uploaded_rows_count = len(st.session_state.session_data)
                st.success("Successfully uploaded to Google Sheets!")
            else:
                st.info("No new data to upload.")
        except Exception as e:
            # Surface upload failures in the UI rather than crashing the app.
            st.error(f"An error occurred: {e}")
|
|
|
|
|
st.markdown("## 💬 Bipolar Assistant Chatbot")

# Seed the conversation with a one-time welcome message from the assistant.
if "messages" not in st.session_state:
    _welcome = "Welcome to a prototype of the open-source and open-weight CANMAT/ISBD 2018 Bipolar Guideline chatbot. Please try asking it questions that can be answered by the guidelines. Improvements are ongoing - the visual aspect will change substantially soon. Please let John-Jose know any feedback at johnjose.nunez@ubc.ca. Thanks!"
    st.session_state.messages = [{"role": "assistant", "content": _welcome}]
|
|
|
|
|
# Two-pane layout: chat transcript on the left, retrieved sources on the right.
chat_col, sources_col = st.columns([1, 1])

with sources_col:
    st.markdown("### Sources")
    # A placeholder lets the chat handler refresh this panel in place after
    # retrieval, without waiting for a full rerun.
    sources_placeholder = st.empty()

    with sources_placeholder.container():
        if st.session_state.last_sources:
            for i, result in enumerate(st.session_state.last_sources):
                # Use .get throughout (matching the existing 'similarity'
                # access) so a malformed source dict cannot raise KeyError.
                st.markdown(f"**Source {i + 1}** | Similarity: {result.get('similarity', 'N/A')}")
                st.markdown(f"- **Section:** {result.get('section', 'N/A')}")
                st.markdown(f"> {result.get('text', '')}")
                st.markdown("---")
        else:
            st.markdown("*Sources will appear here after you ask a question.*")
|
|
|
|
|
with chat_col:
    # Replay the full transcript. Messages alternate user/assistant after the
    # welcome message at index 0, so assistant message i corresponds to
    # session_data row i // 2 - 1.
    for i, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            # Feedback widgets on every assistant reply except the welcome.
            if message["role"] == "assistant" and i > 0:
                feedback_options = ["Good", "Bad", "Neutral"]
                answer_feedback = st.radio(
                    "Rate your answer:",
                    options=feedback_options,
                    index=2,  # default to "Neutral"
                    key=f"answer_feedback_{i}",
                    horizontal=True,
                )
                source_feedback = st.radio(
                    "Rate your sources:",
                    options=feedback_options,
                    index=2,  # default to "Neutral"
                    key=f"source_feedback_{i}",
                    horizontal=True,
                )
                # Bounds-check before indexing: messages and session_data can
                # momentarily disagree (e.g. a rerun between the two appends),
                # and the original unguarded index would raise IndexError.
                row = i // 2 - 1
                if 0 <= row < len(st.session_state.session_data):
                    st.session_state.session_data[row]["feedback"] = answer_feedback
                    st.session_state.session_data[row]["source_feedback"] = source_feedback
|
|
|
|
|
# Handle a new user question: retrieve sources, stream the answer, record the
# turn, then rerun so it renders through the normal transcript loop above.
if user_input := st.chat_input("Ask me questions about the CANMAT bipolar guideline!"):
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Last four prior messages as conversational context, excluding the
    # user message just appended.
    history = st.session_state.messages[:-1][-4:]

    collected = ""
    t0 = time.perf_counter()
    results, response = Rag.depression_assistant(user_input, model_name=selected_model, max_tokens=max_length,
                                                 temperature=temperature, top_p=top_p, stream_flag=True,
                                                 chat_history=history)

    st.session_state.last_sources = results if results else []

    # Refresh the sources panel in place with this query's retrieval results.
    with sources_placeholder.container():
        if st.session_state.last_sources:
            for i, result in enumerate(st.session_state.last_sources):
                # .get keeps rendering robust to malformed source dicts,
                # consistent with the 'similarity' access.
                st.markdown(f"**Source {i + 1}** | Similarity: {result.get('similarity', 'N/A')}")
                st.markdown(f"- **Section:** {result.get('section', 'N/A')}")
                st.markdown(f"> {result.get('text', '')}")
                st.markdown("---")
        else:
            st.markdown("*Sources will appear here after you ask a question.*")

    # Stream tokens into one placeholder so the answer grows in place.
    placeholder = st.empty()
    for chunk in response:
        collected += chunk
        placeholder.markdown(collected)

    t1 = time.perf_counter()
    print(f"[Time] Retriever + Generator takes: {t1 - t0:.2f} seconds in total.")
    print(f"============== Finish R-A-Generation for Current Query {user_input} ==============")

    st.session_state.messages.append({"role": "assistant", "content": collected})

    # Record the completed turn for later upload; feedback defaults are
    # overwritten by the transcript's radio widgets.
    st.session_state.session_data.append(
        {
            "query": user_input,
            "response": collected,
            "sources": json.dumps(st.session_state.last_sources, indent=4),
            "feedback": "Neutral",
            "source_feedback": "Neutral",
        }
    )

    # Rerun so the new turn renders with its feedback widgets attached.
    st.rerun()