# NOTE(review): removed non-Python extraction residue that preceded the code
# (a "File size" line, a row of git-blame commit hashes, and a line-number
# gutter) — all three were syntax errors in a Python module.
import streamlit as st
import Rag
from openai import OpenAI
from together import Together
import time
import os
from dotenv import load_dotenv
from google_sheets_uploader import upload_to_google_sheets
import pandas as pd
import json
# Load API keys and other configuration from a local .env file.
load_dotenv()

# st.set_page_config must be the first Streamlit command executed on the page
# (Streamlit raises otherwise on some versions), so it runs before any
# st.session_state access.
st.set_page_config(
    page_title="Bipolar Assistant Chatbot",
    page_icon=":robot_face:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# One-time session-state defaults (st.session_state survives Streamlit reruns):
#   embedder_loaded / current_embedder_name -- which embedding model is active
#   last_sources        -- retrieval results shown in the sources column
#   session_data        -- one record per Q/A turn, uploadable to Google Sheets
#   uploaded_rows_count -- how many session_data rows were already uploaded
_SESSION_DEFAULTS = {
    "embedder_loaded": False,
    "current_embedder_name": None,
    "last_sources": [],
    "session_data": [],
    "uploaded_rows_count": 0,
}
for _key, _value in _SESSION_DEFAULTS.items():
    st.session_state.setdefault(_key, _value)
# Embedding models offered in the sidebar selector.  The sentinel "Other"
# lets the user type an arbitrary Hugging Face model id instead.
model_options = [
    "Qwen/Qwen3-Embedding-0.6B",
    "jinaai/jina-embeddings-v3",
    "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
    "Other",
]
st.sidebar.title("Settings")
with st.sidebar:
    st.subheader("Model Selection")
    embedder_name = st.selectbox("Select embedder model", model_options, index=0)
    if embedder_name == "Other":
        embedder_name = st.text_input('Enter the embedder model name')

    # (Re)load the embedder only when none is loaded yet or the user picked a
    # different model.  The `embedder_name and` guard skips the empty string an
    # untouched "Other" text box yields, so we never try to load a nameless model.
    if embedder_name and (not st.session_state.embedder_loaded or
                          st.session_state.current_embedder_name != embedder_name):
        with st.spinner(f"Loading embedding model: {embedder_name}..."):
            Rag.launch_depression_assistant(embedder_name=embedder_name)
            st.session_state.embedder_loaded = True
            st.session_state.current_embedder_name = embedder_name
        # NOTE(review): the original emoji here was mojibake-garbled across two
        # lines; restored as a check mark.
        st.success(f"✅ Embedding model {embedder_name} loaded successfully!")
    elif st.session_state.embedder_loaded:
        # NOTE(review): original emoji was mojibake; restored as a plausible glyph.
        st.info(f"🔄 Current embedding model: {st.session_state.current_embedder_name}")

    # Offer generation models that match whichever backend LLM client Rag built.
    if isinstance(Rag.llm_client, OpenAI):
        # NVIDIA client
        model_list = ["openai/gpt-oss-20b"]
    elif isinstance(Rag.llm_client, Together):
        # Together client
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                      "deepseek-ai/deepseek-r1",
                      "meta/llama-3.3-70b-instruct"]
    else:
        # Default or unknown client
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"]
    selected_model = st.selectbox('Choose a model for generation',
                                  model_list,
                                  key='selected_model')

    # Sampling / length controls forwarded verbatim to the generator.
    temperature = st.slider('temperature', min_value=0.01, max_value=1.0, value=0.05, step=0.01)
    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    max_length = st.slider('max_length', min_value=100, max_value=1000, value=500, step=10)

    # Delta upload: only rows appended since the last successful upload go out.
    if st.button("Save and Upload to Google Sheets"):
        try:
            new_data = st.session_state.session_data[st.session_state.uploaded_rows_count:]
            if new_data:
                upload_to_google_sheets(new_data)
                st.session_state.uploaded_rows_count = len(st.session_state.session_data)
                st.success("Successfully uploaded to Google Sheets!")
            else:
                st.info("No new data to upload.")
        except Exception as e:
            # Surface upload failures in the UI rather than crashing the app.
            st.error(f"An error occurred: {e}")
# Page heading.  NOTE(review): the original heading emoji was mojibake-garbled;
# restored as a speech balloon, which fits a chatbot title — confirm glyph.
st.markdown("## 💬 Bipolar Assistant Chatbot")

# Seed the transcript with a fixed assistant welcome message on first load only.
if "messages" not in st.session_state:
    st.session_state.messages = [{
        "role": "assistant",
        "content": "Welcome to a prototype of the open-source and open-weight CANMAT/ISBD 2018 Bipolar Guideline chatbot. Please try asking it questions that can be answered by the guidelines. Improvements are ongoing - the visual aspect will change substantially soon. Please let John-Jose know any feedback at johnjose.nunez@ubc.ca. Thanks!"
    }]
# Two equal columns: the conversation on the left, retrieved sources on the right.
chat_col, sources_col = st.columns([1, 1])

with sources_col:
    st.markdown("### Sources")
    # An st.empty() placeholder lets the query handler refresh this panel
    # in place after retrieval, without a full page rerun.
    sources_placeholder = st.empty()
    with sources_placeholder.container():
        if not st.session_state.last_sources:
            st.markdown("*Sources will appear here after you ask a question.*")
        else:
            # Render each retrieved passage with its similarity score.
            for rank, src in enumerate(st.session_state.last_sources, start=1):
                st.markdown(f"**Source {rank}** | Similarity: {src.get('similarity', 'N/A')}")
                st.markdown(f"- **Section:** {src['section']}")
                st.markdown(f"> {src['text']}")
                st.markdown("---")
with chat_col:
    # Replay the whole transcript; Streamlit re-executes the script per event.
    for i, message in enumerate(st.session_state.messages):
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            # Feedback widgets for every assistant answer except the welcome
            # message at index 0.  Messages alternate welcome/user/assistant,
            # so the assistant message at index i maps to session_data row
            # i // 2 - 1.
            if message["role"] == "assistant" and i > 0:
                feedback_options = ["Good", "Bad", "Neutral"]
                answer_feedback = st.radio(
                    "Rate your answer:",
                    options=feedback_options,
                    index=2,  # Default to Neutral
                    key=f"answer_feedback_{i}",
                    horizontal=True,
                )
                source_feedback = st.radio(
                    "Rate your sources:",
                    options=feedback_options,
                    index=2,  # Default to Neutral
                    key=f"source_feedback_{i}",
                    horizontal=True,
                )
                row = i // 2 - 1
                # Bounds guard: if messages and session_data ever fall out of
                # sync (e.g. an interrupted rerun), skip recording rather than
                # raising IndexError and breaking the whole page.
                if 0 <= row < len(st.session_state.session_data):
                    st.session_state.session_data[row]["feedback"] = answer_feedback
                    st.session_state.session_data[row]["source_feedback"] = source_feedback
# Handle a new question: echo it, retrieve sources, stream the answer,
# record the turn, then rerun so the feedback widgets render.
if user_input := st.chat_input("Ask me questions about the CANMAT bipolar guideline!"):
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})

    # Up to the last four messages *before* the new question form the history.
    history = st.session_state.messages[:-1][-4:]

    collected = ""
    t0 = time.perf_counter()
    results, response = Rag.depression_assistant(user_input, model_name=selected_model, max_tokens=max_length,
                                                 temperature=temperature, top_p=top_p, stream_flag=True,
                                                 chat_history=history)
    st.session_state.last_sources = results if results else []

    # Refresh the sources panel in place with this query's retrieval results.
    with sources_placeholder.container():
        if not st.session_state.last_sources:
            st.markdown("*Sources will appear here after you ask a question.*")
        else:
            for rank, src in enumerate(st.session_state.last_sources, start=1):
                st.markdown(f"**Source {rank}** | Similarity: {src.get('similarity', 'N/A')}")
                st.markdown(f"- **Section:** {src['section']}")
                st.markdown(f"> {src['text']}")
                st.markdown("---")

    # Stream the generated answer chunk-by-chunk into a single placeholder.
    placeholder = st.empty()
    for chunk in response:
        collected += chunk
        placeholder.markdown(collected)

    t1 = time.perf_counter()
    print(f"[Time] Retriever + Generator takes: {t1 - t0:.2f} seconds in total.")
    print(f"============== Finish R-A-Generation for Current Query {user_input} ==============")

    # Persist the full turn so the feedback radios and the Google Sheets
    # uploader can see it, then rerun to redraw the transcript.
    st.session_state.messages.append({"role": "assistant", "content": collected})
    st.session_state.session_data.append(
        {
            "query": user_input,
            "response": collected,
            "sources": json.dumps(st.session_state.last_sources, indent=4),
            "feedback": "Neutral",
            "source_feedback": "Neutral",
        }
    )
    st.rerun()