Spaces:
Sleeping
Sleeping
Vela
committed on
Commit
·
75115cd
1
Parent(s):
00f1bc6
enhanced graph
Browse files
- .gitignore +2 -1
- app.py +4 -4
- application/services/{gemini_model.py → gemini_api_service.py} +29 -6
- application/services/mongo_db_service.py +2 -1
- application/tools/emission_data_extractor.py +1 -1
- application/tools/web_search_tools.py +3 -1
- main.py +161 -0
- pages/chatbot.py +10 -16
- pages/multiple_pdf_extractor.py +2 -2
.gitignore
CHANGED
|
@@ -3,4 +3,5 @@
|
|
| 3 |
data
|
| 4 |
__pycache__/
|
| 5 |
logs/
|
| 6 |
-
test.py
|
|
|
|
|
|
| 3 |
data
|
| 4 |
__pycache__/
|
| 5 |
logs/
|
| 6 |
+
test.py
|
| 7 |
+
reports/
|
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import os
|
| 3 |
-
from application.services import
|
| 4 |
from google.genai.errors import ClientError
|
| 5 |
from application.utils import logger
|
| 6 |
from application.schemas.response_schema import (
|
|
@@ -44,7 +44,7 @@ if st.session_state.pdf_file:
|
|
| 44 |
with col1:
|
| 45 |
if st.button(f"Generate {MODEL_1} Response"):
|
| 46 |
with st.spinner(f"Calling {MODEL_1}..."):
|
| 47 |
-
result =
|
| 48 |
excel_file = streamlit_function.export_results_to_excel(result, MODEL_1, file_name)
|
| 49 |
st.session_state[f"{MODEL_1}_result"] = result
|
| 50 |
if st.session_state[f"{MODEL_1}_result"]:
|
|
@@ -54,7 +54,7 @@ if st.session_state.pdf_file:
|
|
| 54 |
with col2:
|
| 55 |
if st.button(f"Generate {MODEL_2} Response"):
|
| 56 |
with st.spinner(f"Calling {MODEL_2}..."):
|
| 57 |
-
result =
|
| 58 |
excel_file = streamlit_function.export_results_to_excel(result, MODEL_2, file_name)
|
| 59 |
st.session_state[f"{MODEL_2}_result"] = result
|
| 60 |
if st.session_state[f"{MODEL_2}_result"]:
|
|
@@ -65,7 +65,7 @@ if st.session_state.pdf_file:
|
|
| 65 |
try:
|
| 66 |
if st.button(f"Generate {MODEL_3} Response"):
|
| 67 |
with st.spinner(f"Calling {MODEL_3}..."):
|
| 68 |
-
result =
|
| 69 |
excel_file = streamlit_function.export_results_to_excel(result, MODEL_3, file_name)
|
| 70 |
st.session_state[f"{MODEL_3}_result"] = result
|
| 71 |
except ClientError as e:
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import os
|
| 3 |
+
from application.services import gemini_api_service, streamlit_function
|
| 4 |
from google.genai.errors import ClientError
|
| 5 |
from application.utils import logger
|
| 6 |
from application.schemas.response_schema import (
|
|
|
|
| 44 |
with col1:
|
| 45 |
if st.button(f"Generate {MODEL_1} Response"):
|
| 46 |
with st.spinner(f"Calling {MODEL_1}..."):
|
| 47 |
+
result = gemini_api_service.extract_emissions_data_as_json(API_1 , MODEL_1, st.session_state.pdf_file[0],FULL_RESPONSE_SCHEMA)
|
| 48 |
excel_file = streamlit_function.export_results_to_excel(result, MODEL_1, file_name)
|
| 49 |
st.session_state[f"{MODEL_1}_result"] = result
|
| 50 |
if st.session_state[f"{MODEL_1}_result"]:
|
|
|
|
| 54 |
with col2:
|
| 55 |
if st.button(f"Generate {MODEL_2} Response"):
|
| 56 |
with st.spinner(f"Calling {MODEL_2}..."):
|
| 57 |
+
result = gemini_api_service.extract_emissions_data_as_json(API_2, MODEL_2, st.session_state.pdf_file[0],FULL_RESPONSE_SCHEMA)
|
| 58 |
excel_file = streamlit_function.export_results_to_excel(result, MODEL_2, file_name)
|
| 59 |
st.session_state[f"{MODEL_2}_result"] = result
|
| 60 |
if st.session_state[f"{MODEL_2}_result"]:
|
|
|
|
| 65 |
try:
|
| 66 |
if st.button(f"Generate {MODEL_3} Response"):
|
| 67 |
with st.spinner(f"Calling {MODEL_3}..."):
|
| 68 |
+
result = gemini_api_service.extract_emissions_data_as_json(API_3, MODEL_3, st.session_state.pdf_file[0], FULL_RESPONSE_SCHEMA)
|
| 69 |
excel_file = streamlit_function.export_results_to_excel(result, MODEL_3, file_name)
|
| 70 |
st.session_state[f"{MODEL_3}_result"] = result
|
| 71 |
except ClientError as e:
|
application/services/{gemini_model.py → gemini_api_service.py}
RENAMED
|
@@ -5,6 +5,8 @@ from typing import Optional, Dict, Union, IO, List, BinaryIO
|
|
| 5 |
from google import genai
|
| 6 |
from google.genai import types
|
| 7 |
from application.utils import logger
|
|
|
|
|
|
|
| 8 |
|
| 9 |
logger=logger.get_logger()
|
| 10 |
|
|
@@ -136,11 +138,11 @@ def upload_file(
|
|
| 136 |
config: Optional[Dict[str, str]] = None
|
| 137 |
) -> Optional[types.File]:
|
| 138 |
"""
|
| 139 |
-
Uploads a file to the Gemini API, handling
|
| 140 |
|
| 141 |
Args:
|
| 142 |
-
file (Union[str, IO[bytes]]):
|
| 143 |
-
file_name (Optional[str]): Name for the file. If None,
|
| 144 |
config (Optional[Dict[str, str]]): Extra config like 'mime_type'.
|
| 145 |
|
| 146 |
Returns:
|
|
@@ -150,8 +152,14 @@ def upload_file(
|
|
| 150 |
Exception: If upload fails.
|
| 151 |
"""
|
| 152 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
if not file_name:
|
| 154 |
-
if
|
|
|
|
|
|
|
| 155 |
file_name = os.path.basename(file)
|
| 156 |
elif hasattr(file, "name"):
|
| 157 |
file_name = os.path.basename(file.name)
|
|
@@ -164,17 +172,32 @@ def upload_file(
|
|
| 164 |
config.update({"name": sanitized_name, "mime_type": mime_type})
|
| 165 |
gemini_file_key = f"files/{sanitized_name}"
|
| 166 |
|
|
|
|
| 167 |
if gemini_file_key in get_files():
|
| 168 |
logger.info(f"File already exists on Gemini: {gemini_file_key}")
|
| 169 |
return client.files.get(name=gemini_file_key)
|
| 170 |
|
| 171 |
logger.info(f"Uploading file to Gemini: {gemini_file_key}")
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
if isinstance(file, str):
|
|
|
|
|
|
|
| 174 |
with open(file, "rb") as f:
|
| 175 |
return client.files.upload(file=f, config=config)
|
| 176 |
-
|
| 177 |
-
|
|
|
|
| 178 |
|
| 179 |
except Exception as e:
|
| 180 |
logger.error(f"Failed to upload file '{file_name}': {e}")
|
|
|
|
| 5 |
from google import genai
|
| 6 |
from google.genai import types
|
| 7 |
from application.utils import logger
|
| 8 |
+
import requests
|
| 9 |
+
import io
|
| 10 |
|
| 11 |
logger=logger.get_logger()
|
| 12 |
|
|
|
|
| 138 |
config: Optional[Dict[str, str]] = None
|
| 139 |
) -> Optional[types.File]:
|
| 140 |
"""
|
| 141 |
+
Uploads a file to the Gemini API, handling local file paths, binary streams, and URLs.
|
| 142 |
|
| 143 |
Args:
|
| 144 |
+
file (Union[str, IO[bytes]]): Local file path, URL, or binary file object.
|
| 145 |
+
file_name (Optional[str]): Name for the file. If None, tries to infer it from the source.
|
| 146 |
config (Optional[Dict[str, str]]): Extra config like 'mime_type'.
|
| 147 |
|
| 148 |
Returns:
|
|
|
|
| 152 |
Exception: If upload fails.
|
| 153 |
"""
|
| 154 |
try:
|
| 155 |
+
# Determine if input is a URL
|
| 156 |
+
is_url = isinstance(file, str) and file.startswith(('http://', 'https://'))
|
| 157 |
+
|
| 158 |
+
# Determine file name if not provided
|
| 159 |
if not file_name:
|
| 160 |
+
if is_url:
|
| 161 |
+
file_name = os.path.basename(file.split("?")[0]) # Remove query params
|
| 162 |
+
elif isinstance(file, str):
|
| 163 |
file_name = os.path.basename(file)
|
| 164 |
elif hasattr(file, "name"):
|
| 165 |
file_name = os.path.basename(file.name)
|
|
|
|
| 172 |
config.update({"name": sanitized_name, "mime_type": mime_type})
|
| 173 |
gemini_file_key = f"files/{sanitized_name}"
|
| 174 |
|
| 175 |
+
# Check if file already exists
|
| 176 |
if gemini_file_key in get_files():
|
| 177 |
logger.info(f"File already exists on Gemini: {gemini_file_key}")
|
| 178 |
return client.files.get(name=gemini_file_key)
|
| 179 |
|
| 180 |
logger.info(f"Uploading file to Gemini: {gemini_file_key}")
|
| 181 |
|
| 182 |
+
# Handle URL
|
| 183 |
+
if is_url:
|
| 184 |
+
headers = {
|
| 185 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
| 186 |
+
}
|
| 187 |
+
response = requests.get(file, headers=headers)
|
| 188 |
+
response.raise_for_status()
|
| 189 |
+
file_content = io.BytesIO(response.content)
|
| 190 |
+
return client.files.upload(file=file_content, config=config)
|
| 191 |
+
|
| 192 |
+
# Handle local file path
|
| 193 |
if isinstance(file, str):
|
| 194 |
+
if not os.path.isfile(file):
|
| 195 |
+
raise FileNotFoundError(f"Local file '{file}' does not exist.")
|
| 196 |
with open(file, "rb") as f:
|
| 197 |
return client.files.upload(file=f, config=config)
|
| 198 |
+
|
| 199 |
+
# Handle already opened binary file object
|
| 200 |
+
return client.files.upload(file=file, config=config)
|
| 201 |
|
| 202 |
except Exception as e:
|
| 203 |
logger.error(f"Failed to upload file '{file_name}': {e}")
|
application/services/mongo_db_service.py
CHANGED
|
@@ -84,4 +84,5 @@ def retrieve_documents(collection_name: str, query: Optional[Dict] = None) -> Li
|
|
| 84 |
logger.exception(f"An error occurred while retrieving documents: {str(e)}")
|
| 85 |
return []
|
| 86 |
|
| 87 |
-
# all_docs = retrieve_documents("Zalando")
|
|
|
|
|
|
| 84 |
logger.exception(f"An error occurred while retrieving documents: {str(e)}")
|
| 85 |
return []
|
| 86 |
|
| 87 |
+
# all_docs = retrieve_documents("Zalando")
|
| 88 |
+
# print(all_docs)
|
application/tools/emission_data_extractor.py
CHANGED
|
@@ -6,7 +6,7 @@ import requests
|
|
| 6 |
from google import genai
|
| 7 |
from google.genai import types
|
| 8 |
from application.utils.logger import get_logger
|
| 9 |
-
from application.services.
|
| 10 |
from application.services.mongo_db_service import store_document
|
| 11 |
from application.schemas.response_schema import GEMINI_GHG_PARAMETERS
|
| 12 |
from langchain_core.tools import tool
|
|
|
|
| 6 |
from google import genai
|
| 7 |
from google.genai import types
|
| 8 |
from application.utils.logger import get_logger
|
| 9 |
+
from application.services.gemini_api_service import upload_file
|
| 10 |
from application.services.mongo_db_service import store_document
|
| 11 |
from application.schemas.response_schema import GEMINI_GHG_PARAMETERS
|
| 12 |
from langchain_core.tools import tool
|
application/tools/web_search_tools.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import Literal
|
|
| 8 |
from duckduckgo_search import DDGS
|
| 9 |
from tavily import TavilyClient
|
| 10 |
from langchain_core.tools import tool
|
|
|
|
| 11 |
|
| 12 |
logger = get_logger()
|
| 13 |
load_dotenv()
|
|
@@ -54,7 +55,8 @@ def get_top_companies_from_web(query: str):
|
|
| 54 |
|
| 55 |
output = response.output_text
|
| 56 |
# logger.info(f"Raw Output: {output}")
|
| 57 |
-
parsed_list =
|
|
|
|
| 58 |
logger.info(f"Parsed List: {parsed_list}")
|
| 59 |
result = CompanyListResponse(companies=parsed_list)
|
| 60 |
return result
|
|
|
|
| 8 |
from duckduckgo_search import DDGS
|
| 9 |
from tavily import TavilyClient
|
| 10 |
from langchain_core.tools import tool
|
| 11 |
+
import ast
|
| 12 |
|
| 13 |
logger = get_logger()
|
| 14 |
load_dotenv()
|
|
|
|
| 55 |
|
| 56 |
output = response.output_text
|
| 57 |
# logger.info(f"Raw Output: {output}")
|
| 58 |
+
parsed_list = ast.literal_eval(output.strip())
|
| 59 |
+
# parsed_list = eval(output.strip())
|
| 60 |
logger.info(f"Parsed List: {parsed_list}")
|
| 61 |
result = CompanyListResponse(companies=parsed_list)
|
| 62 |
return result
|
main.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import operator
|
| 3 |
+
import functools
|
| 4 |
+
from typing import Annotated, Sequence, TypedDict, Union, Optional
|
| 5 |
+
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from langchain_openai import ChatOpenAI
|
| 8 |
+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
| 9 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 10 |
+
from langchain_core.runnables import Runnable
|
| 11 |
+
from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser
|
| 12 |
+
from langgraph.graph import StateGraph, END
|
| 13 |
+
|
| 14 |
+
from application.agents.scraper_agent import scraper_agent
|
| 15 |
+
from application.agents.extractor_agent import extractor_agent
|
| 16 |
+
from application.utils.logger import get_logger
|
| 17 |
+
|
| 18 |
+
load_dotenv()
|
| 19 |
+
logger = get_logger()
|
| 20 |
+
|
| 21 |
+
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
|
| 22 |
+
if not OPENAI_API_KEY:
|
| 23 |
+
logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
|
| 24 |
+
raise EnvironmentError("OPENAI_API_KEY not found in environment variables.")
|
| 25 |
+
|
| 26 |
+
MEMBERS = ["Scraper", "Extractor"]
|
| 27 |
+
OPTIONS = ["FINISH"] + MEMBERS
|
| 28 |
+
|
| 29 |
+
SUPERVISOR_SYSTEM_PROMPT = (
|
| 30 |
+
"You are a supervisor tasked with managing a conversation between the following workers: {members}. "
|
| 31 |
+
"Given the user's request and the previous messages, determine what to do next:\n"
|
| 32 |
+
"- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.\n"
|
| 33 |
+
"- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.\n"
|
| 34 |
+
"- If the task is complete, choose 'FINISH'.\n"
|
| 35 |
+
"- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.\n"
|
| 36 |
+
"Each worker will perform its task and report back.\n"
|
| 37 |
+
"When you respond directly, make sure your message is friendly and helpful."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
FUNCTION_DEF = {
|
| 41 |
+
"name": "route_or_respond",
|
| 42 |
+
"description": "Select the next role OR respond directly.",
|
| 43 |
+
"parameters": {
|
| 44 |
+
"title": "RouteOrRespondSchema",
|
| 45 |
+
"type": "object",
|
| 46 |
+
"properties": {
|
| 47 |
+
"next": {
|
| 48 |
+
"title": "Next Worker",
|
| 49 |
+
"anyOf": [{"enum": OPTIONS}],
|
| 50 |
+
"description": "Choose next worker if needed."
|
| 51 |
+
},
|
| 52 |
+
"response": {
|
| 53 |
+
"title": "Supervisor Response",
|
| 54 |
+
"type": "string",
|
| 55 |
+
"description": "Respond directly if no worker action is needed."
|
| 56 |
+
}
|
| 57 |
+
},
|
| 58 |
+
"required": [],
|
| 59 |
+
},
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
class AgentState(TypedDict):
|
| 63 |
+
messages: Annotated[Sequence[BaseMessage], operator.add]
|
| 64 |
+
next: Optional[str]
|
| 65 |
+
response: Optional[str]
|
| 66 |
+
|
| 67 |
+
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
|
| 68 |
+
|
| 69 |
+
def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
|
| 70 |
+
logger.info(f"Agent {name} invoked.")
|
| 71 |
+
try:
|
| 72 |
+
result = agent.invoke(state)
|
| 73 |
+
logger.info(f"Agent {name} completed successfully.")
|
| 74 |
+
return {"messages": [HumanMessage(content=result["output"], name=name)]}
|
| 75 |
+
except Exception as e:
|
| 76 |
+
logger.exception(f"Agent {name} failed with error: {str(e)}")
|
| 77 |
+
raise
|
| 78 |
+
|
| 79 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 80 |
+
[
|
| 81 |
+
("system", SUPERVISOR_SYSTEM_PROMPT),
|
| 82 |
+
MessagesPlaceholder(variable_name="messages"),
|
| 83 |
+
(
|
| 84 |
+
"system",
|
| 85 |
+
"Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
|
| 86 |
+
),
|
| 87 |
+
]
|
| 88 |
+
).partial(options=str(OPTIONS), members=", ".join(MEMBERS))
|
| 89 |
+
|
| 90 |
+
# supervisor_chain = (
|
| 91 |
+
# prompt
|
| 92 |
+
# | llm.bind_functions(functions=[FUNCTION_DEF], function_call="route_or_respond")
|
| 93 |
+
# | JsonOutputFunctionsParser()
|
| 94 |
+
# )
|
| 95 |
+
|
| 96 |
+
supervisor_chain = (
|
| 97 |
+
prompt
|
| 98 |
+
| llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
|
| 99 |
+
| JsonOutputKeyToolsParser(key_name="route_or_respond")
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
def supervisor_node(state: AgentState) -> AgentState:
|
| 103 |
+
logger.info("Supervisor invoked.")
|
| 104 |
+
output = supervisor_chain.invoke(state)
|
| 105 |
+
logger.info(f"Supervisor output: {output}")
|
| 106 |
+
|
| 107 |
+
if isinstance(output, list) and len(output) > 0:
|
| 108 |
+
output = output[0]
|
| 109 |
+
|
| 110 |
+
next_step = output.get("next")
|
| 111 |
+
response = output.get("response")
|
| 112 |
+
|
| 113 |
+
if not next_step and not response:
|
| 114 |
+
raise ValueError(f"Supervisor produced invalid output: {output}")
|
| 115 |
+
|
| 116 |
+
return {
|
| 117 |
+
"messages": state["messages"],
|
| 118 |
+
"next": next_step,
|
| 119 |
+
"response": response,
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
workflow = StateGraph(AgentState)
|
| 123 |
+
|
| 124 |
+
workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
|
| 125 |
+
workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
|
| 126 |
+
workflow.add_node("supervisor", supervisor_node)
|
| 127 |
+
# workflow.add_node("supervisor", supervisor_chain)
|
| 128 |
+
workflow.add_node("supervisor_response", lambda state: {"messages": [AIMessage(content=state["response"], name="Supervisor")]})
|
| 129 |
+
|
| 130 |
+
for member in MEMBERS:
|
| 131 |
+
workflow.add_edge(member, "supervisor")
|
| 132 |
+
|
| 133 |
+
def router(state: AgentState):
|
| 134 |
+
if state.get("response"):
|
| 135 |
+
return "supervisor_response"
|
| 136 |
+
return state.get("next")
|
| 137 |
+
|
| 138 |
+
conditional_map = {member: member for member in MEMBERS}
|
| 139 |
+
conditional_map["FINISH"] = END
|
| 140 |
+
conditional_map["supervisor_response"] = "supervisor_response"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
workflow.add_conditional_edges("supervisor", router, conditional_map)
|
| 144 |
+
|
| 145 |
+
workflow.set_entry_point("supervisor")
|
| 146 |
+
|
| 147 |
+
graph = workflow.compile()
|
| 148 |
+
|
| 149 |
+
# # === Example Run ===
|
| 150 |
+
if __name__ == "__main__":
|
| 151 |
+
logger.info("Starting the graph execution...")
|
| 152 |
+
initial_message = HumanMessage(content="Can you get zalando pdf link")
|
| 153 |
+
input_state = {"messages": [initial_message]}
|
| 154 |
+
|
| 155 |
+
for step in graph.stream(input_state):
|
| 156 |
+
if "__end__" not in step:
|
| 157 |
+
logger.info(f"Graph Step Output: {step}")
|
| 158 |
+
print(step)
|
| 159 |
+
print("----")
|
| 160 |
+
|
| 161 |
+
logger.info("Graph execution completed.")
|
pages/chatbot.py
CHANGED
|
@@ -2,16 +2,10 @@ import streamlit as st
|
|
| 2 |
from dotenv import load_dotenv
|
| 3 |
|
| 4 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
| 5 |
-
# from application.agents.scraper_agent import app
|
| 6 |
-
# from application.utils.logger import get_logger
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
from application.utils.logger import get_logger
|
| 12 |
-
except ImportError as e:
|
| 13 |
-
st.error(f"Import Error: Ensure backend modules are accessible. Details: {e}")
|
| 14 |
-
st.stop()
|
| 15 |
|
| 16 |
logger = get_logger()
|
| 17 |
|
|
@@ -19,8 +13,8 @@ st.set_page_config(page_title="Sustainability AI Assistant", layout="wide")
|
|
| 19 |
st.title("♻️ Sustainability Report AI Assistant")
|
| 20 |
st.caption(
|
| 21 |
"Ask about sustainability reports by company or industry! "
|
| 22 |
-
"(e.g., 'Get report for Apple', 'Download report for Microsoft 2023', "
|
| 23 |
-
"'Find reports for top 3 airline companies', 'Download this pdf <link>')"
|
| 24 |
)
|
| 25 |
|
| 26 |
load_dotenv()
|
|
@@ -34,10 +28,10 @@ def initialize_chat_history():
|
|
| 34 |
def display_chat_history():
|
| 35 |
"""Render previous chat messages."""
|
| 36 |
for message in st.session_state.messages:
|
| 37 |
-
if isinstance(message, SystemMessage):
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
with st.chat_message("user"):
|
| 42 |
st.markdown(message.content)
|
| 43 |
elif isinstance(message, AIMessage):
|
|
@@ -77,10 +71,10 @@ def display_last_ai_response():
|
|
| 77 |
logger.warning("No AI message found in the final output.")
|
| 78 |
|
| 79 |
initialize_chat_history()
|
| 80 |
-
display_chat_history()
|
| 81 |
|
| 82 |
if user_query := st.chat_input("Your question about sustainability reports..."):
|
| 83 |
logger.info(f"User input received: {user_query}")
|
|
|
|
| 84 |
|
| 85 |
st.session_state.messages.append(HumanMessage(content=user_query))
|
| 86 |
|
|
|
|
| 2 |
from dotenv import load_dotenv
|
| 3 |
|
| 4 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
from application.agents.scraper_agent import app
|
| 7 |
+
from main import graph
|
| 8 |
+
from application.utils.logger import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
logger = get_logger()
|
| 11 |
|
|
|
|
| 13 |
st.title("♻️ Sustainability Report AI Assistant")
|
| 14 |
st.caption(
|
| 15 |
"Ask about sustainability reports by company or industry! "
|
| 16 |
+
"(e.g., 'Get sustainability report for Apple', 'Download sustainability report for Microsoft 2023', "
|
| 17 |
+
"'Find sustainability reports for top 3 airline companies', 'Download this pdf <link>')"
|
| 18 |
)
|
| 19 |
|
| 20 |
load_dotenv()
|
|
|
|
| 28 |
def display_chat_history():
|
| 29 |
"""Render previous chat messages."""
|
| 30 |
for message in st.session_state.messages:
|
| 31 |
+
# if isinstance(message, SystemMessage):
|
| 32 |
+
# # st.info(f"System: {message.content}")
|
| 33 |
+
# pass
|
| 34 |
+
if isinstance(message, HumanMessage):
|
| 35 |
with st.chat_message("user"):
|
| 36 |
st.markdown(message.content)
|
| 37 |
elif isinstance(message, AIMessage):
|
|
|
|
| 71 |
logger.warning("No AI message found in the final output.")
|
| 72 |
|
| 73 |
initialize_chat_history()
|
|
|
|
| 74 |
|
| 75 |
if user_query := st.chat_input("Your question about sustainability reports..."):
|
| 76 |
logger.info(f"User input received: {user_query}")
|
| 77 |
+
display_chat_history()
|
| 78 |
|
| 79 |
st.session_state.messages.append(HumanMessage(content=user_query))
|
| 80 |
|
pages/multiple_pdf_extractor.py
CHANGED
|
@@ -6,7 +6,7 @@ from application.schemas.response_schema import (
|
|
| 6 |
GEMINI_GOVERNANCE_PARAMETERS, GEMINI_MATERIALITY_PARAMETERS,
|
| 7 |
GEMINI_NET_ZERO_INTERVENTION_PARAMETERS
|
| 8 |
)
|
| 9 |
-
from application.services import
|
| 10 |
from application.utils import logger
|
| 11 |
|
| 12 |
logger = logger.get_logger()
|
|
@@ -58,7 +58,7 @@ if st.session_state.uploaded_files:
|
|
| 58 |
all_results = {}
|
| 59 |
|
| 60 |
for label, schema in RESPONSE_SCHEMAS.items():
|
| 61 |
-
result =
|
| 62 |
streamlit_function.export_results_to_excel(result, sheet_name=selected_model, filename=file_name, column=label)
|
| 63 |
all_results[label] = result
|
| 64 |
st.session_state[result_key] = all_results
|
|
|
|
| 6 |
GEMINI_GOVERNANCE_PARAMETERS, GEMINI_MATERIALITY_PARAMETERS,
|
| 7 |
GEMINI_NET_ZERO_INTERVENTION_PARAMETERS
|
| 8 |
)
|
| 9 |
+
from application.services import gemini_api_service, streamlit_function
|
| 10 |
from application.utils import logger
|
| 11 |
|
| 12 |
logger = logger.get_logger()
|
|
|
|
| 58 |
all_results = {}
|
| 59 |
|
| 60 |
for label, schema in RESPONSE_SCHEMAS.items():
|
| 61 |
+
result = gemini_api_service.extract_emissions_data_as_json("gemini", selected_model, pdf_file, schema)
|
| 62 |
streamlit_function.export_results_to_excel(result, sheet_name=selected_model, filename=file_name, column=label)
|
| 63 |
all_results[label] = result
|
| 64 |
st.session_state[result_key] = all_results
|