import os
import streamlit as st
from streamlit.logger import get_logger
from langchain.callbacks.base import BaseCallbackHandler
from langchain_community.graphs import Neo4jGraph
from dotenv import load_dotenv
from utils import (
    create_vector_index,
)
from chains import (
    load_embedding_model,
    load_llm,
    configure_llm_only_chain,
    configure_qa_rag_chain,
)
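
# Connection and model settings are read from the .env file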
load_dotenv(".env")
url = os.getenv("NEO4J_URI")
username = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")
ollama_base_url = os.getenv("OLLAMA_BASE_URL")
embedding_model_name = os.getenv("EMBEDDING_MODEL")
llm_name = os.getenv("LLM")
# Remapping for Langchain Neo4j integration
os.environ["NEO4J_URL"] = url
logger = get_logger(__name__)
# if Neo4j is local, you can go to http://localhost:7474/ to browse the database
neo4j_graph = Neo4jGraph(
    url=url, username=username, password=password, refresh_schema=False
)
embeddings, dimension = load_embedding_model(
    embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
)
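
# Make sure the Neo4j vector index used for similarity search exists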
create_vector_index(neo4j_graph)
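

# Callback handler that streams LLM tokens into a Streamlit container as they arrive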
class StreamHandler(BaseCallbackHandler):
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)
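

# Load the LLM and build two answer modes: an LLM-only chain and a RAG chain backed by Neo4j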
llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})
llm_chain = configure_llm_only_chain(llm)
rag_chain = configure_qa_rag_chain(
    llm, embeddings, embeddings_store_url=url, username=username, password=password
)
# Streamlit UI
styl = f"""
<style>
    /* not great support for :has yet (hello FireFox), but using it for now */
    .element-container:has([aria-label="Select RAG mode"]) {{
        position: fixed;
        bottom: 33px;
        background: white;
        z-index: 101;
    }}
    .stChatFloatingInputContainer {{
        bottom: 20px;
    }}

    /* Generate ticket text area */
    textarea[aria-label="Description"] {{
        height: 200px;
    }}

    .element-container:has([aria-label="What coding issue can I help you resolve today?"]) {{
        bottom: 45px;
    }}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
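

# Chat input: echo the user's question and stream the selected chain's answer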
def chat_input():
    user_input = st.chat_input("What coding issue can I help you resolve today?")

    if user_input:
        with st.chat_message("user"):
            st.write(user_input)

        with st.chat_message("assistant"):
            st.caption(f"RAG: {name}")
            stream_handler = StreamHandler(st.empty())
            result = output_function(
                {"question": user_input, "chat_history": []}, callbacks=[stream_handler]
            )["answer"]
            output = result
            st.session_state["user_input"].append(user_input)
            st.session_state["generated"].append(output)
            st.session_state["rag_mode"].append(name)
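

# Replay the most recent exchanges from session state and offer a ticket-draft shortcut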
def display_chat():
    # Session state
    if "generated" not in st.session_state:
        st.session_state["generated"] = []

    if "user_input" not in st.session_state:
        st.session_state["user_input"] = []

    if "rag_mode" not in st.session_state:
        st.session_state["rag_mode"] = []

    if st.session_state["generated"]:
        size = len(st.session_state["generated"])
        # Display only the last three exchanges
        for i in range(max(size - 3, 0), size):
            with st.chat_message("user"):
                st.write(st.session_state["user_input"][i])

            with st.chat_message("assistant"):
                st.caption(f"RAG: {st.session_state['rag_mode'][i]}")
                st.write(st.session_state["generated"][i])

        with st.expander("Not finding what you're looking for?"):
            st.write(
                "Automatically generate a draft for an internal ticket to our support team."
            )
            st.button(
                "Generate ticket",
                type="primary",
                key="show_ticket",
                on_click=open_sidebar,
            )

        with st.container():
            st.write(" ")
def mode_select() -> str:
    options = ["Disabled", "Enabled"]
    return st.radio("Select RAG mode", options, horizontal=True)


name = mode_select()
if name == "LLM only" or name == "Disabled":
    output_function = llm_chain
elif name == "Vector + Graph" or name == "Enabled":
    output_function = rag_chain
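

# The ticket-draft sidebar is opened and closed via buttons in the chat UI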
def open_sidebar():
    st.session_state.open_sidebar = True


def close_sidebar():
    st.session_state.open_sidebar = False


if "open_sidebar" not in st.session_state:
    st.session_state.open_sidebar = False
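
# Draft a support ticket from the latest question and show it in the sidebar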
if st.session_state.open_sidebar:
    new_title, new_question = generate_ticket(
        neo4j_graph=neo4j_graph,
        llm_chain=llm_chain,
        input_question=st.session_state["user_input"][-1],
    )
    with st.sidebar:
        st.title("Ticket draft")
        st.write("Auto generated draft ticket")
        st.text_input("Title", new_title)
        st.text_area("Description", new_question)
        st.button(
            "Submit to support team",
            type="primary",
            key="submit_ticket",
            on_click=close_sidebar,
        )
display_chat()
chat_input()