Spaces:
Sleeping
Sleeping
persist_directory (2) added
Browse files
app.py
CHANGED
|
@@ -245,15 +245,18 @@ def create_vector_database_ESG():
|
|
| 245 |
#len(docs)
|
| 246 |
print(f"length of documents loaded: {len(documents)}")
|
| 247 |
print(f"total number of document chunks generated :{len(docs)}")
|
|
|
|
| 248 |
embed_model = HuggingFaceEmbeddings()
|
| 249 |
|
| 250 |
vs = Chroma.from_documents(
|
| 251 |
documents=docs,
|
| 252 |
embedding=embed_model,
|
| 253 |
-
collection_name="
|
|
|
|
| 254 |
)
|
|
|
|
| 255 |
doc_retriever_ESG = vs.as_retriever()
|
| 256 |
-
|
| 257 |
index = VectorStoreIndex.from_documents(llama_parse_documents)
|
| 258 |
query_engine = index.as_query_engine()
|
| 259 |
|
|
@@ -274,19 +277,25 @@ def create_vector_database_financials():
|
|
| 274 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=15)
|
| 275 |
docs = text_splitter.split_documents(documents)
|
| 276 |
|
|
|
|
|
|
|
| 277 |
embed_model = HuggingFaceEmbeddings()
|
| 278 |
|
|
|
|
| 279 |
vs = Chroma.from_documents(
|
| 280 |
documents=docs,
|
| 281 |
embedding=embed_model,
|
| 282 |
-
collection_name="
|
|
|
|
| 283 |
)
|
|
|
|
| 284 |
doc_retriever_financials = vs.as_retriever()
|
| 285 |
|
|
|
|
| 286 |
index = VectorStoreIndex.from_documents(llama_parse_documents)
|
| 287 |
query_engine_financials = index.as_query_engine()
|
| 288 |
|
| 289 |
-
print('Vector DB created successfully
|
| 290 |
return doc_retriever_financials, query_engine_financials
|
| 291 |
|
| 292 |
#--------------
|
|
@@ -328,6 +337,7 @@ for uploaded_file in uploaded_files_financials:
|
|
| 328 |
#---------------
|
| 329 |
def ESG_strategy():
|
| 330 |
doc_retriever_ESG, _ = create_vector_database_ESG()
|
|
|
|
| 331 |
prompt_template = """<|system|>
|
| 332 |
You are a seasoned specialist in environmental, social and governance matters. You write expert analyses for institutional investors. Always use figures, nemerical and statistical data when possible. Output must have sub-headings in bold font and be fluent.<|end|>
|
| 333 |
<|user|>
|
|
@@ -505,15 +515,8 @@ with strategies_container:
|
|
| 505 |
with mrow1_col2:
|
| 506 |
if "ESG_analysis_button_key" in st.session_state.results and st.session_state.results["ESG_analysis_button_key"]:
|
| 507 |
|
| 508 |
-
doc_retriever_ESG, query_engine = create_vector_database_ESG()
|
| 509 |
-
|
| 510 |
-
file_path = os.path.join("data", "parsed_data_financials.pkl")
|
| 511 |
-
|
| 512 |
-
# Check if the file exists before running the function
|
| 513 |
-
if os.path.exists(file_path):
|
| 514 |
-
doc_retriever_financials, query_engine_financials = create_vector_database_financials()
|
| 515 |
-
else:
|
| 516 |
-
print(f"The file {file_path} does not exist. Skipping vector database creation.")
|
| 517 |
|
| 518 |
memory = ConversationBufferMemory(memory_key="chat_history", k=3, return_messages=True)
|
| 519 |
search = SerpAPIWrapper()
|
|
@@ -548,19 +551,17 @@ with strategies_container:
|
|
| 548 |
"""
|
| 549 |
)
|
| 550 |
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
| StrOutputParser()
|
| 563 |
-
)
|
| 564 |
|
| 565 |
ESG_chain = (
|
| 566 |
{
|
|
@@ -581,12 +582,11 @@ with strategies_container:
|
|
| 581 |
description="Useful for answering questions about specific ESG figures, data and statistics.",
|
| 582 |
)
|
| 583 |
|
| 584 |
-
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
)
|
| 590 |
|
| 591 |
tools = [
|
| 592 |
Tool(
|
|
@@ -594,23 +594,19 @@ with strategies_container:
|
|
| 594 |
func=ESG_chain.invoke,
|
| 595 |
description="Useful for answering general questions about environmental, social, and governance (ESG) matters related to the company. ",
|
| 596 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
Tool(
|
| 598 |
name="Search Tool",
|
| 599 |
func=search.run,
|
| 600 |
description="Useful when other tools do not provide the answer.",
|
| 601 |
),
|
| 602 |
vector_query_tool_ESG,
|
| 603 |
-
]
|
| 604 |
-
|
| 605 |
-
if os.path.exists(file_path):
|
| 606 |
-
tools.append(
|
| 607 |
-
Tool(
|
| 608 |
-
name="Financials QA System",
|
| 609 |
-
func=financials_chain.invoke,
|
| 610 |
-
description="Useful for answering general questions about financial or operational information concerning the company.",
|
| 611 |
-
),
|
| 612 |
vector_query_tool_financials,
|
| 613 |
-
|
| 614 |
|
| 615 |
# Initialize the agent with LCEL tools and memory
|
| 616 |
agent = initialize_agent(
|
|
|
|
| 245 |
#len(docs)
|
| 246 |
print(f"length of documents loaded: {len(documents)}")
|
| 247 |
print(f"total number of document chunks generated :{len(docs)}")
|
| 248 |
+
persist_directory = "./chroma_db_ESG" # Specify directory for Chroma persistence
|
| 249 |
embed_model = HuggingFaceEmbeddings()
|
| 250 |
|
| 251 |
vs = Chroma.from_documents(
|
| 252 |
documents=docs,
|
| 253 |
embedding=embed_model,
|
| 254 |
+
collection_name="rag_ESG",
|
| 255 |
+
persist_directory=persist_directory # Ensure persistence
|
| 256 |
)
|
| 257 |
+
|
| 258 |
doc_retriever_ESG = vs.as_retriever()
|
| 259 |
+
|
| 260 |
index = VectorStoreIndex.from_documents(llama_parse_documents)
|
| 261 |
query_engine = index.as_query_engine()
|
| 262 |
|
|
|
|
| 277 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=15)
|
| 278 |
docs = text_splitter.split_documents(documents)
|
| 279 |
|
| 280 |
+
# Add a persist directory for Chroma DB
|
| 281 |
+
persist_directory = "./chroma_db_financials" # Specify directory for persistence
|
| 282 |
embed_model = HuggingFaceEmbeddings()
|
| 283 |
|
| 284 |
+
# Initialize Chroma with persistence
|
| 285 |
vs = Chroma.from_documents(
|
| 286 |
documents=docs,
|
| 287 |
embedding=embed_model,
|
| 288 |
+
collection_name="rag_financials", # Use a unique collection name
|
| 289 |
+
persist_directory=persist_directory # Persist the data
|
| 290 |
)
|
| 291 |
+
|
| 292 |
doc_retriever_financials = vs.as_retriever()
|
| 293 |
|
| 294 |
+
# Build a VectorStore index for querying
|
| 295 |
index = VectorStoreIndex.from_documents(llama_parse_documents)
|
| 296 |
query_engine_financials = index.as_query_engine()
|
| 297 |
|
| 298 |
+
print('Vector DB for financials created successfully!')
|
| 299 |
return doc_retriever_financials, query_engine_financials
|
| 300 |
|
| 301 |
#--------------
|
|
|
|
| 337 |
#---------------
|
| 338 |
def ESG_strategy():
|
| 339 |
doc_retriever_ESG, _ = create_vector_database_ESG()
|
| 340 |
+
|
| 341 |
prompt_template = """<|system|>
|
| 342 |
You are a seasoned specialist in environmental, social and governance matters. You write expert analyses for institutional investors. Always use figures, nemerical and statistical data when possible. Output must have sub-headings in bold font and be fluent.<|end|>
|
| 343 |
<|user|>
|
|
|
|
| 515 |
with mrow1_col2:
|
| 516 |
if "ESG_analysis_button_key" in st.session_state.results and st.session_state.results["ESG_analysis_button_key"]:
|
| 517 |
|
| 518 |
+
doc_retriever_ESG, query_engine = create_vector_database_ESG()
|
| 519 |
+
doc_retriever_financials, query_engine_financials = create_vector_database_financials()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
|
| 521 |
memory = ConversationBufferMemory(memory_key="chat_history", k=3, return_messages=True)
|
| 522 |
search = SerpAPIWrapper()
|
|
|
|
| 551 |
"""
|
| 552 |
)
|
| 553 |
|
| 554 |
+
financials_chain = (
|
| 555 |
+
{
|
| 556 |
+
"context": doc_retriever_financials,
|
| 557 |
+
# Lambda function now accepts one argument (even if unused)
|
| 558 |
+
"chat_history": lambda _: format_chat_history(memory.load_memory_variables({})["chat_history"]),
|
| 559 |
+
"question": RunnablePassthrough(),
|
| 560 |
+
}
|
| 561 |
+
| prompt_financials
|
| 562 |
+
| llm_tool
|
| 563 |
+
| StrOutputParser()
|
| 564 |
+
)
|
|
|
|
|
|
|
| 565 |
|
| 566 |
ESG_chain = (
|
| 567 |
{
|
|
|
|
| 582 |
description="Useful for answering questions about specific ESG figures, data and statistics.",
|
| 583 |
)
|
| 584 |
|
| 585 |
+
vector_query_tool_financials = Tool(
|
| 586 |
+
name="Vector Query Engine Financials",
|
| 587 |
+
func=lambda query: query_engine_financials.query(query), # Use query_engine to query the vector database
|
| 588 |
+
description="Useful for answering questions about specific financial figures, data and statistics.",
|
| 589 |
+
)
|
|
|
|
| 590 |
|
| 591 |
tools = [
|
| 592 |
Tool(
|
|
|
|
| 594 |
func=ESG_chain.invoke,
|
| 595 |
description="Useful for answering general questions about environmental, social, and governance (ESG) matters related to the company. ",
|
| 596 |
),
|
| 597 |
+
Tool(
|
| 598 |
+
name="Financials QA System",
|
| 599 |
+
func=financials_chain.invoke,
|
| 600 |
+
description="Useful for answering general questions about financial or operational information concerning the company.",
|
| 601 |
+
),
|
| 602 |
Tool(
|
| 603 |
name="Search Tool",
|
| 604 |
func=search.run,
|
| 605 |
description="Useful when other tools do not provide the answer.",
|
| 606 |
),
|
| 607 |
vector_query_tool_ESG,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
vector_query_tool_financials,
|
| 609 |
+
]
|
| 610 |
|
| 611 |
# Initialize the agent with LCEL tools and memory
|
| 612 |
agent = initialize_agent(
|