# bipolar/src/app.py
# Streamlit chat UI for the CANMAT/ISBD 2018 Bipolar Guideline RAG assistant.
import streamlit as st
import Rag
from openai import OpenAI
from together import Together
import time
import os
from dotenv import load_dotenv
from google_sheets_uploader import upload_to_google_sheets
import pandas as pd
import json
# Load environment variables (API keys, credentials) from a local .env file.
load_dotenv()

# Per-session defaults: seed each key once; existing values survive reruns.
_SESSION_DEFAULTS = {
    "embedder_loaded": False,       # whether an embedding model has been loaded
    "current_embedder_name": None,  # name of the currently loaded embedder
    "last_sources": [],             # retrieval results from the last question
    "session_data": [],             # one row per Q/A turn, for Sheets upload
    "uploaded_rows_count": 0,       # how many session_data rows were uploaded
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Page-level Streamlit configuration (title, icon, wide layout).
# NOTE(review): st.set_page_config should be the first Streamlit command on
# the page — confirm the session_state reads above do not trip this check.
st.set_page_config(
    page_title="Bipolar Assistant Chatbot",
    page_icon=":robot_face:",
    layout="wide",
    initial_sidebar_state="collapsed"
)
# Embedding models offered in the sidebar selectbox;
# "Other" unlocks a free-form text input for an arbitrary model name.
model_options = [
    "Qwen/Qwen3-Embedding-0.6B",
    "jinaai/jina-embeddings-v3",
    "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    "BAAI/bge-base-en-v1.5",
    "sentence-transformers/all-mpnet-base-v2",
    "Other"
]
st.sidebar.title("Settings")
with st.sidebar:
    st.subheader("Model Selection")
    # Embedding model used by the RAG retriever; "Other" allows a custom name.
    embedder_name = st.selectbox("Select embedder model", model_options, index=0)
    if embedder_name == "Other":
        embedder_name = st.text_input('Enter the embedder model name')

    # FIX: selecting "Other" before typing anything yields "" from
    # st.text_input, which previously triggered a bogus model load with an
    # empty name. Guard against a falsy name before (re)loading.
    if not embedder_name:
        st.warning("Please enter an embedder model name.")
    elif (not st.session_state.embedder_loaded or
          st.session_state.current_embedder_name != embedder_name):
        # (Re)load only when nothing is loaded yet or the selection changed.
        with st.spinner(f"Loading embedding model: {embedder_name}..."):
            Rag.launch_depression_assistant(embedder_name=embedder_name)
            st.session_state.embedder_loaded = True
            st.session_state.current_embedder_name = embedder_name
            st.success(f"✅ Embedding model {embedder_name} loaded successfully!")
    else:
        st.info(f"📋 Current embedding model: {st.session_state.current_embedder_name}")

    # Offer generation models that match whichever LLM backend Rag configured.
    if isinstance(Rag.llm_client, OpenAI):
        # NVIDIA client
        model_list = ["openai/gpt-oss-20b"]
    elif isinstance(Rag.llm_client, Together):
        # Together client
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
                      "deepseek-ai/deepseek-r1",
                      "meta/llama-3.3-70b-instruct"]
    else:
        # Default or unknown client
        model_list = ["meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"]
    selected_model = st.selectbox('Choose a model for generation',
                                  model_list,
                                  key='selected_model')

    # Sampling parameters forwarded to Rag.depression_assistant below.
    temperature = st.slider('temperature', min_value=0.01, max_value=1.0, value=0.05, step=0.01)
    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
    max_length = st.slider('max_length', min_value=100, max_value=1000, value=500, step=10)

    # Upload only the rows appended since the last successful upload, so
    # repeated clicks never duplicate data in the sheet.
    if st.button("Save and Upload to Google Sheets"):
        try:
            new_data = st.session_state.session_data[st.session_state.uploaded_rows_count:]
            if new_data:
                upload_to_google_sheets(new_data)
                st.session_state.uploaded_rows_count = len(st.session_state.session_data)
                st.success("Successfully uploaded to Google Sheets!")
            else:
                st.info("No new data to upload.")
        except Exception as e:
            st.error(f"An error occurred: {e}")
st.markdown("## 💬 Bipolar Assistant Chatbot")

# Seed the chat history with a one-time assistant welcome message.
if "messages" not in st.session_state:
    _welcome = "Welcome to a prototype of the open-source and open-weight CANMAT/ISBD 2018 Bipolar Guideline chatbot. Please try asking it questions that can be answered by the guidelines. Improvements are ongoing - the visual aspect will change substantially soon. Please let John-Jose know any feedback at johnjose.nunez@ubc.ca. Thanks!"
    st.session_state.messages = [{"role": "assistant", "content": _welcome}]
# Two-pane layout: chat on the left, retrieved guideline sources on the right.
chat_col, sources_col = st.columns([1, 1])
with sources_col:
    st.markdown("### Sources")
    # Placeholder lets the panel be refreshed in place after a new question.
    sources_placeholder = st.empty()
    with sources_placeholder.container():
        if not st.session_state.last_sources:
            st.markdown("*Sources will appear here after you ask a question.*")
        else:
            for rank, src in enumerate(st.session_state.last_sources, start=1):
                st.markdown(f"**Source {rank}** | Similarity: {src.get('similarity', 'N/A')}")
                st.markdown(f"- **Section:** {src['section']}")
                st.markdown(f"> {src['text']}")
                st.markdown("---")
with chat_col:
    # Replay the conversation; each assistant turn after the welcome message
    # (index 0) gets answer- and source-rating widgets.
    feedback_options = ["Good", "Bad", "Neutral"]
    for idx, msg in enumerate(st.session_state.messages):
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])
            if msg["role"] != "assistant" or idx == 0:
                continue
            # After the welcome, messages come in (user, assistant) pairs, so
            # assistant message idx maps to session_data row idx // 2 - 1.
            row = st.session_state.session_data[idx // 2 - 1]
            row["feedback"] = st.radio(
                "Rate your answer:",
                options=feedback_options,
                index=2,  # Default to Neutral
                key=f"answer_feedback_{idx}",
                horizontal=True,
            )
            row["source_feedback"] = st.radio(
                "Rate your sources:",
                options=feedback_options,
                index=2,  # Default to Neutral
                key=f"source_feedback_{idx}",
                horizontal=True,
            )
# Handle a new user question: retrieve sources, stream the answer into the
# chat pane, log the turn for feedback/upload, then rerun to redraw widgets.
if user_input := st.chat_input("Ask me questions about the CANMAT bipolar guideline!"):
    st.chat_message("user").markdown(user_input)
    st.session_state.messages.append({"role": "user", "content": user_input})
    # Pass the last 4 prior messages (excluding the question itself) as context.
    history = st.session_state.messages[:-1][-4:]
    collected = ""
    t0 = time.perf_counter()
    # Returns retrieval results plus a streaming generator of answer chunks.
    results, response = Rag.depression_assistant(user_input, model_name=selected_model, max_tokens=max_length,
                                                 temperature=temperature, top_p=top_p, stream_flag=True,
                                                 chat_history=history)
    st.session_state.last_sources = results if results else []
    # Refresh the right-hand sources panel in place with the new passages.
    with sources_placeholder.container():
        if st.session_state.last_sources:
            for i, result in enumerate(st.session_state.last_sources):
                st.markdown(f"**Source {i + 1}** | Similarity: {result.get('similarity', 'N/A')}")
                st.markdown(f"- **Section:** {result['section']}")
                st.markdown(f"> {result['text']}")
                st.markdown("---")
        else:
            st.markdown("*Sources will appear here after you ask a question.*")
    # Stream the answer: re-render the accumulating text after every chunk.
    placeholder = st.empty()
    for chunk in response:
        collected += chunk
        placeholder.markdown(collected)
    t1 = time.perf_counter()
    print(f"[Time] Retriever + Generator takes: {t1 - t0:.2f} seconds in total.")
    print(f"============== Finish R-A-Generation for Current Query {user_input} ==============")
    st.session_state.messages.append({"role": "assistant", "content": collected})
    # Log the full turn; feedback defaults are overwritten by the radio
    # widgets on the next render pass.
    st.session_state.session_data.append(
        {
            "query": user_input,
            "response": collected,
            "sources": json.dumps(st.session_state.last_sources, indent=4),
            "feedback": "Neutral",
            "source_feedback": "Neutral",
        }
    )
    # Rerun so the new assistant message renders with its feedback widgets.
    st.rerun()