Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,52 +10,38 @@ from langchain_core.messages import HumanMessage, SystemMessage
|
|
| 10 |
from langchain_openai import ChatOpenAI
|
| 11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
|
| 13 |
-
# Function to load
|
| 14 |
-
def
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def load_documents(df, content_column):
|
| 20 |
-
docs = DataFrameLoader(df, page_content_column=content_column).load()
|
| 21 |
-
return docs
|
| 22 |
-
|
| 23 |
-
# Function to tokenize documents
|
| 24 |
-
# def tokenize_documents(docs):
|
| 25 |
-
# encoder = tiktoken.get_encoding("cl100k_base")
|
| 26 |
-
# tokens_per_docs = [len(encoder.encode(doc.page_content)) for doc in docs]
|
| 27 |
-
# total_tokens = sum(tokens_per_docs)
|
| 28 |
-
# cost_per_1000_tokens = 0.0001
|
| 29 |
-
# cost = (total_tokens / 1000) * cost_per_1000_tokens
|
| 30 |
-
# return tokens_per_docs, cost
|
| 31 |
-
|
| 32 |
-
# Function to create vector database
|
| 33 |
-
def create_vector_db(docs):
|
| 34 |
-
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
| 35 |
-
texts = text_splitter.split_documents(docs)
|
| 36 |
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
|
| 37 |
-
vectordb = Chroma
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
| 41 |
return vectordb
|
| 42 |
|
| 43 |
# Function to augment prompt
|
| 44 |
def augment_prompt(query, vectordb):
|
| 45 |
results = vectordb.similarity_search(query, k=3)
|
| 46 |
source_knowledge = "\n".join([x.page_content for x in results])
|
| 47 |
-
augmented_prompt = f"""
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
{source_knowledge}
|
| 52 |
|
| 53 |
-
|
|
|
|
| 54 |
return augmented_prompt
|
| 55 |
|
| 56 |
-
# Function to handle chat
|
| 57 |
-
def
|
| 58 |
-
chat = ChatOpenAI(model_name="gpt-3.5-turbo",openai_api_key=openai_api_key)
|
| 59 |
augmented_query = augment_prompt(query, vectordb)
|
| 60 |
prompt = HumanMessage(content=augmented_query)
|
| 61 |
messages = [
|
|
@@ -68,33 +54,17 @@ def chat_with_ai(query, vectordb,openai_api_key):
|
|
| 68 |
# Streamlit UI
|
| 69 |
st.title("Document Processing and AI Chat with LangChain")
|
| 70 |
|
| 71 |
-
#
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
st.write(
|
| 82 |
-
|
| 83 |
-
# Tokenize documents
|
| 84 |
-
# tokens_per_docs, cost = tokenize_documents(docs)
|
| 85 |
-
# st.write(f"Total tokens: {sum(tokens_per_docs)}")
|
| 86 |
-
# st.write(f"Estimated cost: ${cost:.4f}")
|
| 87 |
-
|
| 88 |
-
# Create vector database
|
| 89 |
-
vectordb = create_vector_db(docs)
|
| 90 |
-
st.write("Vector database created and persisted successfully!")
|
| 91 |
-
|
| 92 |
-
# Query input
|
| 93 |
-
query = st.text_input("Enter your query", "Recommend a company to work as a data scientist in the health sector")
|
| 94 |
-
|
| 95 |
-
if st.button("Get Answer"):
|
| 96 |
-
# Chat with AI
|
| 97 |
-
openai_api_key = os.getenv("OPENAI_API_KEY")
|
| 98 |
-
response = chat_with_ai(query, vectordb, openai_api_key)
|
| 99 |
-
st.write("Response from AI:")
|
| 100 |
-
st.write(response)
|
|
|
|
| 10 |
from langchain_openai import ChatOpenAI
|
| 11 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
|
| 13 |
+
# Function to load vector database
|
| 14 |
+
def load_vector_db(zip_file_path, extract_path):
|
| 15 |
+
with st.spinner("Loading vector store..."):
|
| 16 |
+
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
|
| 17 |
+
zip_ref.extractall(extract_path)
|
| 18 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
|
| 20 |
+
vectordb = Chroma(
|
| 21 |
+
persist_directory=extract_path,
|
| 22 |
+
embedding_function=embedding_function
|
| 23 |
+
)
|
| 24 |
+
st.success("Vector store loaded")
|
| 25 |
return vectordb
|
| 26 |
|
| 27 |
# Function to augment prompt
|
| 28 |
def augment_prompt(query, vectordb):
|
| 29 |
results = vectordb.similarity_search(query, k=3)
|
| 30 |
source_knowledge = "\n".join([x.page_content for x in results])
|
| 31 |
+
augmented_prompt = f"""
|
| 32 |
+
You are an AI assistant. Use the context provided below to answer the question as comprehensively as possible.
|
| 33 |
+
If the answer is not contained within the context, respond with "I don't know".
|
| 34 |
|
| 35 |
+
Context:
|
| 36 |
{source_knowledge}
|
| 37 |
|
| 38 |
+
Question: {query}
|
| 39 |
+
"""
|
| 40 |
return augmented_prompt
|
| 41 |
|
| 42 |
+
# Function to handle chat with OpenAI
|
| 43 |
+
def chat_with_openai(query, vectordb, openai_api_key):
|
| 44 |
+
chat = ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key)
|
| 45 |
augmented_query = augment_prompt(query, vectordb)
|
| 46 |
prompt = HumanMessage(content=augmented_query)
|
| 47 |
messages = [
|
|
|
|
| 54 |
# Streamlit UI
|
| 55 |
st.title("Document Processing and AI Chat with LangChain")
|
| 56 |
|
| 57 |
+
# Load vector database
|
| 58 |
+
zip_file_path = "chroma_db_compressed_.zip"
|
| 59 |
+
extract_path = "./chroma_db_extracted"
|
| 60 |
+
vectordb = load_vector_db(zip_file_path, extract_path)
|
| 61 |
|
| 62 |
+
# Query input
|
| 63 |
+
query = st.text_input("Enter your query", "Recommend a company to work as a data scientist in the health sector")
|
| 64 |
+
|
| 65 |
+
if st.button("Get Answer"):
|
| 66 |
+
# Chat with OpenAI
|
| 67 |
+
openai_api_key = st.secrets["OPENAI_API_KEY"]
|
| 68 |
+
response = chat_with_openai(query, vectordb, openai_api_key)
|
| 69 |
+
st.write("Response from AI:")
|
| 70 |
+
st.write(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|