|
|
import os
import re

import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_groq import ChatGroq
from langchain_text_splitters import CharacterTextSplitter
|
|
|
|
|
|
|
|
# Page chrome. set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Serenity AI", page_icon="🌿")

st.title("🌿 Serenity: Your CBT Companion")

# Sidebar: disclaimer plus the Groq credential that gates the chat below.
with st.sidebar:
    st.header("About")
    # Embeddings and the vector store run locally, but every chat message is
    # sent to Groq's hosted API (ChatGroq). The notice must not claim that all
    # data stays local.
    st.info(
        "This is a support tool, NOT a doctor. Knowledge-base embeddings are "
        "computed locally, but your messages are sent to Groq's API to "
        "generate replies."
    )
    # Strip surrounding whitespace. Keys pasted from a dashboard or e-mail often
    # carry a trailing space/newline, which breaks authentication. A
    # whitespace-only entry would also otherwise look like a provided key.
    groq_api_key = st.text_input("Groq API Key", type="password").strip()
|
|
|
|
|
|
|
|
@st.cache_resource
def setup_rag():
    """Build a retriever over the bundled CBT resource text.

    Steps:
      1. Read ``cbt_resources.txt`` from the directory containing this script.
      2. Split it into chunks of up to 500 characters with a 50-character
         overlap.
      3. Embed the chunks with the ``all-MiniLM-L6-v2`` HuggingFace model.
      4. Index them in a Chroma vector store.

    Decorated with ``st.cache_resource`` so the model and index are built once
    and reused across Streamlit reruns. NOTE: a ``None`` result is cached as
    well, so after fixing a missing or empty file the cache must be cleared
    (or the server restarted).

    Returns:
        A LangChain retriever backed by the Chroma store, or ``None`` if the
        resource file is missing or contains no text. In that case an error
        is shown with ``st.error``.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(current_dir, "cbt_resources.txt")

    # isfile (not exists) so that a directory with this name is reported
    # cleanly instead of crashing in open().
    if not os.path.isfile(file_path):
        st.error(f"Error: Could not find 'cbt_resources.txt' at {file_path}")
        return None

    # Explicit encoding. The platform default (e.g. cp1252 on Windows) can
    # fail on, or mis-decode, non-ASCII characters in the resource file.
    with open(file_path, "r", encoding="utf-8") as f:
        raw_text = f.read()

    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_text(raw_text)

    # An empty or whitespace-only file produces no chunks, and the vector store
    # cannot be built from zero documents. Report it clearly instead of
    # surfacing an opaque store error.
    if not texts:
        st.error(f"Error: 'cbt_resources.txt' at {file_path} contains no text")
        return None

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    db = Chroma.from_texts(texts, embeddings)

    return db.as_retriever()
|
|
|
|
|
|
|
|
# Crisis-language screen, applied to every user message before it can reach
# the LLM.
# - Whole-word matching (\b) avoids false alarms on words that merely contain
#   "die" ("diet", "studied", "soldier", "audience"); a plain substring test
#   flagged all of those.
# - The "suicid" stem also covers "suicidal", which the literal "suicide" did
#   not match.
_CRISIS_PATTERN = re.compile(
    r"\b(?:suicid\w*|kill(?:ing)?\s+myself|end(?:ing)?\s+it\s+all|die)\b",
    re.IGNORECASE,
)

# Fixed referral shown instead of a model reply when crisis language is found.
_CRISIS_RESPONSE = (
    "I'm really concerned about you, but I am an AI. "
    "Please call your local emergency number immediately "
    "(like 988 in the US). You are not alone."
)


def _is_crisis_message(text):
    """Return True if *text* contains crisis language that must bypass the LLM."""
    return _CRISIS_PATTERN.search(text) is not None


if groq_api_key:
    try:
        llm = ChatGroq(
            temperature=0.6,
            model_name="llama-3.3-70b-versatile",
            groq_api_key=groq_api_key,
        )

        # Conversation memory lives in session_state so it survives Streamlit
        # reruns (each widget interaction re-executes this script).
        if "memory" not in st.session_state:
            st.session_state.memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
            )

        retriever = setup_rag()

        # setup_rag() returns None (after showing its own error) when the
        # resource file is unusable; the chat UI is skipped in that case.
        if retriever:
            chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                retriever=retriever,
                memory=st.session_state.memory,
                verbose=False,
            )

            # UI transcript, kept separately from the chain's memory so that
            # canned crisis replies are displayed too.
            if "messages" not in st.session_state:
                st.session_state.messages = [
                    {"role": "assistant", "content": "Hello. I'm Serenity. How are you feeling today?"}
                ]

            # Re-render the full transcript on every rerun.
            for msg in st.session_state.messages:
                st.chat_message(msg["role"]).write(msg["content"])

            if prompt := st.chat_input():
                st.session_state.messages.append({"role": "user", "content": prompt})
                st.chat_message("user").write(prompt)

                if _is_crisis_message(prompt):
                    # Never delegate crisis messages to the model; reply with
                    # the fixed referral instead.
                    response = _CRISIS_RESPONSE
                else:
                    with st.spinner("Thinking..."):
                        response_dict = chain.invoke({"question": prompt})
                    response = response_dict["answer"]

                st.session_state.messages.append({"role": "assistant", "content": response})
                st.chat_message("assistant").write(response)

    except Exception as e:
        # Top-level UI boundary: show any failure (bad key, network, vector
        # store) in the page rather than as a raw traceback.
        st.error(f"An error occurred: {e}")

else:
    st.warning("Please enter your Groq API Key to start.")