feat(add_deep_rishsearch): Add the graph and call functions for DeepRishSearch
Changed files:

- requirements.txt +2 -1
- src/graph/__pycache__/graph_builder.cpython-312.pyc +0 -0
- src/graph/__pycache__/state_vector_nodes.cpython-312.pyc +0 -0
- src/graph/data_analyst_prompts.csv +0 -0
- src/graph/graph_builder.py +48 -6
- src/graph/state_vector_nodes.py +166 -6
- src/llm/__pycache__/llm_setup.cpython-312.pyc +0 -0
- src/llm/llm_setup.py +5 -5
- src/streamlit_app.py +48 -24
requirements.txt
CHANGED

@@ -6,10 +6,11 @@ langgraph
 langchain_core
 langchain_community
 langchain_huggingface
+langgraph-prebuilt
 streamlit
 transformers[torch]
 langchain_openai
-
+langchain_google_genai
 tf-keras
 tensorflow
 torch
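The two added dependencies carry the new research graph: langgraph-prebuilt supplies ToolNode and tools_condition (imported in src/graph/graph_builder.py below), and langchain_google_genai supplies the ChatGoogleGenerativeAI chat model (used in src/llm/llm_setup.py). A minimal sketch of what they enable; the API key value is an illustrative placeholder:

    # Sketch: the imports and model that the new requirements make available.
    from langgraph.prebuilt import ToolNode, tools_condition      # langgraph-prebuilt
    from langchain_google_genai import ChatGoogleGenerativeAI     # langchain_google_genai

    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",   # same model name used in src/llm/llm_setup.py
        google_api_key="YOUR_KEY",  # illustrative placeholder, not a real key
        temperature=0.1,
    )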
src/graph/__pycache__/graph_builder.cpython-312.pyc
CHANGED

Binary files a/src/graph/__pycache__/graph_builder.cpython-312.pyc and b/src/graph/__pycache__/graph_builder.cpython-312.pyc differ

src/graph/__pycache__/state_vector_nodes.cpython-312.pyc
CHANGED

Binary files a/src/graph/__pycache__/state_vector_nodes.cpython-312.pyc and b/src/graph/__pycache__/state_vector_nodes.cpython-312.pyc differ
src/graph/data_analyst_prompts.csv
ADDED

The diff for this file is too large to render. See raw diff.
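The file itself is not rendered here, but the new data_analysis method in src/graph/state_vector_nodes.py (further down) reads it with pandas and relies on three columns. A hedged sketch of the expected shape, inferred from that code; any additional columns are unknown:

    # Sketch: columns of data_analyst_prompts.csv that data_analysis() uses.
    import pandas as pd

    df = pd.read_csv("src/graph/data_analyst_prompts.csv")
    # 'country' (string), 'goal_number' (castable to int), and
    # 'analysis_prompt' (string) are filtered and concatenated downstream.
    print(df[["country", "goal_number", "analysis_prompt"]].head())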
src/graph/graph_builder.py
CHANGED

@@ -1,17 +1,21 @@
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.graph import StateGraph,START, END
 from state.state import StateVector
-from graph.state_vector_nodes import question_model
+from graph.state_vector_nodes import question_model,research_model
+from langgraph.prebuilt import tools_condition
+from langgraph.prebuilt import ToolNode
 
 class BuildGraphOptions():
     def __init__(self,question_model):
         self.graph_builder=StateGraph(StateVector) #For invocation
         self.questionmodel=question_model
-    def build_graph(self):
+    def load_research_model(self,research_model):
+        self.research_model=research_model
+    def build_question_graph(self):
         #workflow = self.graph_builder
 
         self.graph_builder.add_node("check_inputs", self.questionmodel.check_inputs)
-        self.graph_builder.add_node("…
+        self.graph_builder.add_node("create_question_prompt_template", self.questionmodel.create_question_prompt_template)
         self.graph_builder.add_node("generate_questions", self.questionmodel.generate_questions)
         # Set entry point
         self.graph_builder.set_entry_point("check_inputs")
@@ -21,10 +25,48 @@ class BuildGraphOptions():
             "check_inputs",
             self.questionmodel.should_continue,
             {
-                "…
+                "create_question_prompt_template": "create_question_prompt_template",
                 "terminate": END
             }
         )
-        self.graph_builder.add_edge("…
+        self.graph_builder.add_edge("create_question_prompt_template", "generate_questions")
         self.graph_builder.add_edge("generate_questions", END)
         return self.graph_builder.compile()
+    def build_research_graph(self,research_model:research_model):
+        self.graph_builder.add_node("check_inputs", self.questionmodel.check_inputs)
+        self.graph_builder.add_node("create_question_prompt_template", self.questionmodel.create_question_prompt_template)
+        self.graph_builder.add_node("generate_questions", self.questionmodel.generate_questions)
+        self.graph_builder.add_node("create_prompt_template", research_model.create_prompt_template)
+
+        # Set entry point
+        self.graph_builder.set_entry_point("check_inputs")
+        # Add conditional edges
+
+        self.graph_builder.add_conditional_edges(
+            "check_inputs",
+            self.questionmodel.should_continue,
+            {
+                "create_question_prompt_template": "create_question_prompt_template",
+                "terminate": END
+            }
+        )
+        self.graph_builder.add_edge("create_question_prompt_template", "generate_questions")
+        self.graph_builder.add_node(research_model.tool_calling_llm,
+                                    name="tool_calling_llm",
+                                    description="Invoke the LLM with curated questions to answer.",)
+        #llm_with_tools,tools=research_model.tool_calling_agent()
+        self.graph_builder.add_node("tools", ToolNode(research_model.tools))
+        self.graph_builder.add_node(research_model.summary_answer,name="summary_answer",
+                                    description="Summarize the answer from the LLM using the information gathered from the tools.",)
+        self.graph_builder.add_edge("generate_questions","create_prompt_template")
+        self.graph_builder.add_edge("create_prompt_template","tool_calling_llm")
+
+        self.graph_builder.add_conditional_edges(
+            "tool_calling_llm",
+            # If the latest message (result) from the assistant is a tool call -> tools_condition routes to tools
+            # If the latest message (result) from the assistant is not a tool call -> tools_condition routes to summary_answer with no retrieved docs
+            tools_condition,
+        )
+        self.graph_builder.add_edge("tools", "summary_answer")
+        self.graph_builder.add_edge("summary_answer", END)
+        return self.graph_builder.compile()
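A note on the conditional edge above: tools_condition with no path map routes tool-call messages to a node named "tools" and everything else to END, so as committed a response without tool calls ends the run instead of reaching summary_answer as the inline comment intends. A minimal sketch of an explicit path map that would match the comment, assuming the node names used in build_research_graph (not part of the commit):

    # Sketch: make the routing described in the comment explicit.
    from langgraph.graph import END
    from langgraph.prebuilt import tools_condition

    graph_builder.add_conditional_edges(
        "tool_calling_llm",
        tools_condition,
        # tools_condition returns "tools" when the last message carries tool
        # calls and END otherwise; this map sends the END outcome to the
        # summarization node instead of terminating the graph.
        {"tools": "tools", END: "summary_answer"},
    )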
src/graph/state_vector_nodes.py
CHANGED

@@ -1,19 +1,29 @@
-
-from langchain_core.…
+
+from langchain_core.messages import SystemMessage,AIMessage,HumanMessage,ToolMessage
+from langchain_core.output_parsers import NumberedListOutputParser,JsonOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+
 from state.state import StateVector
 from streamlitui.constants import *
 import torch
 import torch.nn.functional as F
 import tensorflow as tf
 import re
-class question_model():
+from langchain_openai import ChatOpenAI
+from langchain_community.tools.semanticscholar.tool import SemanticScholarQueryRun
+from langchain_community.utilities.semanticscholar import SemanticScholarAPIWrapper
+from langchain_community.tools.tavily_search import TavilySearchResults
+import pandas as pd
+
+
+class question_model:
     def __init__(self,loaded_tokenizer,loaded_model, llm, df_keys):
         #self.state=StateVector
         self.tokenizer=loaded_tokenizer
         self.distilbert_model=loaded_model
         self.genai_model=llm
         self.df_keys=df_keys
-    def …
+    def create_question_prompt_template(self, state:StateVector) -> StateVector:
         """
         Creates a prompt template based on the state vector.
         """
@@ -24,7 +34,7 @@ class question_model():
         for topic, keywords in state['topic_kw'].items():
             state['messages'].append(SystemMessage(content=f"For the UN SDG Goal: {topic}\n. \
             Use the following keywords : {', '.join(keywords)}. Generate questions related to the topic in the country of {state['country']} using these keywords."))
-        state['messages'].append(AIMessage(content="Based on the provided information, here is an enhanced…
+        state['messages'].append(AIMessage(content="Based on the provided information, here is an enhanced list of the question: \n"))
 
         return state
 
@@ -73,7 +83,7 @@ class question_model():
         """Determine whether to continue to prompt creation or terminate"""
         if not state.get('topic') or not state.get('topic_kw'):
             return "terminate"
-        return "…
+        return "create_question_prompt_template"
     def generate_questions(self, state:StateVector) -> StateVector:
         """
         Generates questions based on the provided topics and keywords.
@@ -90,3 +100,153 @@ class question_model():
         #ai_response="\n".join(state['questions'])
         #state['messages'].append(AIMessage(content="Generated questions: "+ai_response))
         return state
+
+class research_model:
+    def __init__(self,llm):
+        self.llm=llm
+        self.local_analysis_file='src/graph/data_analyst_prompts.csv'
+        self.tool_names=["direct_semantic_scholar_query", "direct_tavily_search"]
+        semantic_scholar_tool = SemanticScholarQueryRun(
+            api_wrapper=SemanticScholarAPIWrapper()
+        )
+        self.tools=[semantic_scholar_tool,self.direct_tavily_search]
+        # Bind the tools to the LLM
+        self.llm_with_tools = self.llm.bind_tools(self.tools)
+        #self.tavily_api_key=tavily_api_key
+
+    def direct_semantic_scholar_query(self,query: str):
+        """Direct invocation of SemanticScholarQueryRun without an agent"""
+        # Create the tool directly
+        tool = SemanticScholarQueryRun(
+            api_wrapper=SemanticScholarAPIWrapper()
+        )
+        # Invoke the tool directly
+        result = tool.invoke(query, k=10, output_parser=JsonOutputParser(), fields=["paperId","title","authors","url","abstract","year","paperId"], sort="year")
+        return result
+
+    def direct_tavily_search(self,query: str):
+        """Direct invocation of TavilySearchResults without an agent"""
+        # Create the tool directly
+        tavily = TavilySearchResults()
+        result = tavily.invoke(query, max_results=5, include_answer=True, include_snippet=True, include_source=True)
+        response=""
+        for r in result:
+            response += "Found a webpage: %s at %s" % (r['title'], r['url'])
+            response += "Summary of the page: %s" % r['content']
+            response += "Relevance score: %s" % r['score']
+        return response
+
+    def data_analysis(self,state:StateVector):
+        df_analyst=pd.read_csv(self.local_analysis_file)
+        analysis_prompt=[]
+        topics=state['topic']
+        for t in topics:
+            Goal_Number=t[0]
+            df_analyst=df_analyst[df_analyst['country']==state['country']]
+            df_analyst['goal_number']=df_analyst['goal_number'].astype(int)
+            df_analyst=df_analyst[df_analyst['goal_number']==Goal_Number]
+            #print(df_analyst.head())
+
+            if df_analyst.shape[0]>0:
+                analysis_prompt.extend(df_analyst['analysis_prompt'].to_list())
+        return "\n".join(analysis_prompt)
+
+    def create_prompt_template(self,state:StateVector) -> ChatPromptTemplate:
+        """
+        Creates a prompt template based on the provided questions.
+        """
+        topic_string = ", ".join(f"{name}" for num, name in state['topic'])
+        keywords=[]
+        kw_string=''
+        for i,v in state['topic_kw'].items():
+            keywords.append(",".join(v))
+        kw_string += f" with keywords: {', '.join(keywords)}"
+        messages = [
+            SystemMessage(content=f"You are an AI assistant that helps users find information about the Sustainable Development Goal: {topic_string}.\
+            Your task is to answer questions related to this goal using the provided tools with toolNames: {self.tool_names}\
+            You will be provided with a list of questions to answer below: \
+            questions = {state['questions']} "),
+            #AIMessage(content="Using publications on Semantic Scholar and my own reference data, I will answer the questions related to the Sustainable Development Goal: %s." % topic),
+            SystemMessage(content=f"Search for recent papers on {kw_string} in {state['country']}."),
+            SystemMessage(content=f"Search for recent news on {kw_string} in {state['country']}."),
+            SystemMessage(content=f"Search the internet for webpages on {kw_string} in {state['country']}."),
+            #HumanMessage(content="Please provide a comprehensive answer to the questions based on the information gathered from the tools.")
+        ]
+        state['messages'] = messages
+        return state
+
+    def tool_calling_agent(self):
+        """Show how to bind the tools to the LLM using tool calling"""
+        # Initialize LLM
+        '''
+        llm = ChatOpenAI(
+            temperature=0.1,
+            model_name="gpt-4o-mini",
+            openai_api_key=openai_api_key
+        )
+        '''
+        # Create the tool
+        semantic_scholar_tool = SemanticScholarQueryRun(
+            api_wrapper=SemanticScholarAPIWrapper()
+        )
+        self.tools=[semantic_scholar_tool,self.direct_tavily_search]
+        # Bind the tools to the LLM
+        llm_with_tools = self.llm.bind_tools(self.tools)
+
+        return llm_with_tools,self.tools
+
+    def tool_calling_llm(self,state:StateVector):
+        return {"messages":[self.llm_with_tools.invoke(state["messages"])]}
+
+    def summary_answer(self,state:StateVector)->StateVector:
+        """
+        Function to summarize the answer from the LLM.
+        This is a placeholder function that can be extended to include more complex summarization.
+        """
+        initial_system_message= state["messages"][0] # This is the system message that sets the context for the LLM with the listed questions
+        initial_system_message.content += "Please provide a comprehensive answer to the questions. \n"
+
+        tool_messages = [msg for msg in state["messages"] if isinstance(msg, ToolMessage)]
+        augmented_data=""
+        if tool_messages:
+            initial_system_message.content += "Use the following information gathered from the tools as reference information: \n"
+
+            for tool_msg in tool_messages:
+                print(tool_msg.content, type(tool_msg.content))
+                Label_Source=""
+                if 'semanticscholar' in tool_msg.name.lower():
+                    Label_Source="(Source: Scholarly Publication Abstracts from Semantic Scholar)"
+                    augmented_data += f"{tool_msg.content}\n"
+                elif 'tavily' in tool_msg.name.lower():
+                    Label_Source="(Source: News Search Results)"
+                    augmented_data += f"{tool_msg.content}\n"
+                else:
+                    print("Unknown Tool Call")
+
+            initial_system_message.content += f"{Label_Source} \n {augmented_data}\n"
+        analysis_prompt=self.data_analysis(state)
+        initial_system_message.content+=analysis_prompt
+        initial_system_message.content+="\n Assess if the resources indicate a general positive or negative trend and grade progress\
+        from 0-10 where 0 is very negative and 10 is very positive.\n"
+        initial_system_message.content+="\n Provide detailed answers to the questions and a list of references used."
+        state["messages"].append(initial_system_message)
+        '''
+        llm = ChatOpenAI(
+            temperature=0.4,
+            model_name="gpt-4o",
+            openai_api_key=openai_api_key
+        )
+        '''
+        #llm=ChatGoogleGenerativeAI(model='gemini-2.5-pro',google_api_key=google_api_key,temperature=0.3)
+
+        #print(state["messages"][-1].content)
+        airesponse = self.llm.invoke(state["messages"][-1].content)
+        # For simplicity, we just return the messages as they are
+        return {"messages": [airesponse]}
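One behavioral detail worth flagging in data_analysis above: df_analyst is overwritten by each filter inside the topic loop, so after the first topic the frame is already narrowed to one country and goal number and later topics can match nothing. A sketch of a per-topic filter that keeps the full frame intact (same column names as the commit; not part of it):

    # Sketch: re-filter from the full frame on every loop iteration.
    import pandas as pd

    def data_analysis(self, state):
        df_full = pd.read_csv(self.local_analysis_file)
        df_full['goal_number'] = df_full['goal_number'].astype(int)
        analysis_prompt = []
        for goal_number, _name in state['topic']:
            subset = df_full[
                (df_full['country'] == state['country'])
                & (df_full['goal_number'] == int(goal_number))
            ]
            if not subset.empty:
                analysis_prompt.extend(subset['analysis_prompt'].to_list())
        return "\n".join(analysis_prompt)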
src/llm/__pycache__/llm_setup.cpython-312.pyc
CHANGED

Binary files a/src/llm/__pycache__/llm_setup.cpython-312.pyc and b/src/llm/__pycache__/llm_setup.cpython-312.pyc differ
src/llm/llm_setup.py
CHANGED

@@ -1,7 +1,8 @@
 import os
 import streamlit as sl
 from langchain_openai import ChatOpenAI
-import google.generativeai as genai
+
+from langchain_google_genai import ChatGoogleGenerativeAI
 class ModelSelection:
     def __init__(self,user_contols_input):
         self.user_controls_input=user_contols_input
@@ -16,16 +17,15 @@ class ModelSelection:
         if selected_model=='OpenAI':
             try:
                 llm = ChatOpenAI(
-                    temperature=0.…
-                    model_name="gpt-…
+                    temperature=0.1,
+                    model_name="gpt-4o-mini",
                     openai_api_key=gen_api_key
                 )
                 return llm
             except Exception as e:
                 raise ValueError(f"Error Occurred With Exception : {e}")
         else:
-
-            llm = genai.GenerativeModel('gemini-2.5-pro')
+            llm = ChatGoogleGenerativeAI(model='gemini-2.5-flash',google_api_key=gen_api_key,temperature=0.1)
             return llm
 
 
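Whichever branch runs, the returned object is now a LangChain chat model, so the graph code downstream stays provider-agnostic. A small usage sketch (the prompt text is illustrative):

    # Sketch: ChatOpenAI and ChatGoogleGenerativeAI share this interface.
    from langchain_core.messages import HumanMessage

    response = llm.invoke([HumanMessage(content="Summarize progress on SDG 6.")])
    print(response.content)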
src/streamlit_app.py
CHANGED

@@ -5,11 +5,12 @@ from typing import List, Optional
 from transformers import DistilBertTokenizerFast, TFDistilBertForSequenceClassification
 from langchain_core.messages import AnyMessage, AIMessage,SystemMessage, HumanMessage,AIMessageChunk
 
+
 from streamlitui.constants import unsdg_countries
 from llm.llm_setup import ModelSelection
 import pandas as pd
 from state.state import StateVector
-from graph.state_vector_nodes import question_model
+from graph.state_vector_nodes import question_model,research_model
 from graph.graph_builder import BuildGraphOptions
 import re
 class StreamlitConfigUI:
@@ -104,27 +105,50 @@ if __name__=='__main__':
 
     if user_message and user_input['UN SDG Country']:
         state=StateVector(country=user_input['UN SDG Country'], seed_question=user_message, messages=[])
 
-        SmartQuestions=question_model(#StateVector=state,
-                        loaded_tokenizer=loaded_tokenizer,
-                        loaded_model=loaded_model, llm=llm, df_keys=df_keys)
-        builder=BuildGraphOptions(SmartQuestions)
-        graph=builder.build_graph()
-        with st.chat_message("assistant"):
-            intro="Hello, I am an assistant designed to help you learn about the 17 UN SDG goals listed here: https://sdgs.un.org/goals.\
-            You can ask me about any of the goals or specific topics, and I will provide you with information and resources related to them.\
-            I can also help you create a question related to the SDGs based on your input.\
-            Please provide me with a topic or question related to the SDGs and select a country, and I will do my best to assist you." #,
-            st.write(intro)
-        initial_input = {'country': user_input['UN SDG Country'], 'seed_question': user_message}
-        with st.chat_message("user"):st.write(user_message)
-        with st.chat_message("assistant"):
-            message_placeholder = st.empty()
-            accumulated_content = ""
-            for chunk in graph.stream(initial_input, stream_mode='messages'):
-                message, meta=chunk
-                if isinstance(message, AIMessage):
-                    accumulated_content += message.content
-                    message_placeholder.write(accumulated_content)
-                else:
-                    print("Deep Rishsearch")
+        if user_input['selected_usecase']=='AskSmart SDG Assistant':
+            SmartQuestions=question_model(#StateVector=state,
+                            loaded_tokenizer=loaded_tokenizer,
+                            loaded_model=loaded_model, llm=llm, df_keys=df_keys)
+            builder=BuildGraphOptions(SmartQuestions)
+            graph=builder.build_question_graph()
+            with st.chat_message("assistant"):
+                intro="Hello, I am an assistant designed to help you learn about the 17 UN SDG goals listed here: https://sdgs.un.org/goals.\
+                You can ask me about any of the goals or specific topics, and I will provide you with information and resources related to them.\
+                I can also help you create a question related to the SDGs based on your input.\
+                Please provide me with a topic or question related to the SDGs and select a country, and I will do my best to assist you." #,
+                st.write(intro)
+            initial_input = {'country': user_input['UN SDG Country'], 'seed_question': user_message}
+            with st.chat_message("user"):st.write(user_message)
+            with st.chat_message("assistant"):
+                message_placeholder = st.empty()
+                accumulated_content = ""
+                for chunk in graph.stream(initial_input, stream_mode='messages'):
+                    message, meta=chunk
+                    if isinstance(message, AIMessage):
+                        accumulated_content += message.content
+                        message_placeholder.write(accumulated_content)
+        elif user_input['selected_usecase']=='DeepRishSearch':
+            print("Deep Rishsearch")
+            SmartQuestions=question_model(#StateVector=state,
+                            loaded_tokenizer=loaded_tokenizer,
+                            loaded_model=loaded_model, llm=llm, df_keys=df_keys)
+            builder=BuildGraphOptions(SmartQuestions)
+            ResearchModel=research_model(llm=SmartQuestions.genai_model)
+            graph=builder.build_research_graph(ResearchModel)
+            with st.chat_message("assistant"):
+                intro="Hello, I am an assistant designed to help you learn about the 17 UN SDG goals listed here: https://sdgs.un.org/goals.\
+                You can ask me about any of the goals or specific topics, and I will provide you with information and resources related to them.\
+                I can also help you create a question related to the SDGs based on your input.\
+                Please provide me with a topic or question related to the SDGs and select a country, and I will do my best to assist you." #,
+                st.write(intro)
+            with st.chat_message("user"):st.write(user_message)
+            initial_input = StateVector({'country': user_input['UN SDG Country'], 'seed_question': user_message})
+            with st.chat_message("assistant"):
+                message_placeholder = st.empty()
+                accumulated_content = ""
+                for chunk in graph.stream(initial_input, stream_mode='messages'):
+                    message, meta=chunk
+                    if isinstance(message, AIMessage):
+                        accumulated_content += message.content
+                        message_placeholder.write(accumulated_content)
 
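For reference, stream_mode='messages' makes graph.stream yield (message, metadata) tuples per emitted message or token, which is why both branches unpack the chunk before the AIMessage check. A compact sketch of the same consumption pattern outside Streamlit, assuming a compiled graph and initial_input as above:

    # Sketch: consuming a LangGraph 'messages' stream without the UI.
    from langchain_core.messages import AIMessage, AIMessageChunk

    accumulated_content = ""
    for message, meta in graph.stream(initial_input, stream_mode="messages"):
        # Each item pairs a message (often an AIMessageChunk while the LLM
        # is streaming) with metadata such as the node that produced it.
        if isinstance(message, (AIMessage, AIMessageChunk)):
            accumulated_content += message.content
    print(accumulated_content)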