Agents_DeepSearch / src /graph /state_vector_nodes.py
rgp230's picture
fix(downgrade_transformers): Downgrade transformers version to bypass the errors
833435e
from langchain_core.messages import SystemMessage,AIMessage,HumanMessage,ToolMessage
from langchain_core.output_parsers import NumberedListOutputParser,JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from state.state import StateVector
from streamlitui.constants import *
import torch
import torch.nn.functional as F
#import tensorflow as tf
import re
from langchain_openai import ChatOpenAI
from langchain_community.tools.semanticscholar.tool import SemanticScholarQueryRun
from langchain_community.utilities.semanticscholar import SemanticScholarAPIWrapper
from langchain_tavily import TavilySearch
import pandas as pd
import torch.nn.functional as F
import os
# Workaround: blank out torch.classes' __path__ — presumably to stop Streamlit's
# file watcher from trying to enumerate this C-extension module's path, which
# raises at startup. NOTE(review): confirm still needed for the pinned
# torch/streamlit versions.
torch.classes.__path__ = []
class question_model:
    """Graph nodes that turn a user's seed question into SDG-specific
    questions: a DistilBERT classifier detects the relevant goals, then a
    generative LLM produces the question list."""

    def __init__(self, loaded_tokenizer, loaded_model, llm, df_keys):
        """Keep references to the tokenizer, classifier, generative model and
        the per-goal keyword lookup table (a DataFrame)."""
        self.df_keys = df_keys
        self.genai_model = llm
        self.distilbert_model = loaded_model
        self.tokenizer = loaded_tokenizer
    def create_question_prompt_template(self, state:StateVector) -> StateVector:
        """
        Creates a prompt template based on the state vector.

        Appends to state['messages'] (in place): one framing SystemMessage,
        one SystemMessage per (topic, keywords) pair in state['topic_kw']
        asking for questions about that goal in state['country'], and a
        trailing AIMessage that primes the LLM to continue with the list.
        Returns the mutated state.
        """
        # Global framing message for the whole conversation.
        state['messages'].extend([SystemMessage(
            content="You are an AI assistant that helps users find information about the Sustainable Development Goals (SDGs)."
            )
        ])
        # One instruction per detected topic; keywords were gathered by check_inputs.
        # NOTE(review): the backslash continuation keeps the next line's leading
        # whitespace inside the runtime string — confirm the intended spacing.
        for topic, keywords in state['topic_kw'].items():
            state['messages'].append(SystemMessage(content=f"For the UN SDG Goal: {topic}\n. \
            Use the following keywords : {', '.join(keywords)}. Generate questions related to the topic in the country of {state['country']} using these keywords.\n"))
        # Priming AIMessage so the model answers with the question list directly.
        state['messages'].append(AIMessage(content="Based on the provided information, here is an enhanced list of the question: \n"))
        return state
    #Check input raw prompt and extract topics and keywords
    def check_inputs(self,state:StateVector) -> StateVector:
        """Check if topic and keywords are set.

        Runs the DistilBERT classifier on the lower-cased seed question and
        keeps every SDG goal whose softmax probability exceeds 0.4. For each
        kept goal, looks up its comma-separated keyword patterns in
        self.df_keys and retains the ones found in the seed question (falling
        back to the full pattern list when none match). Mutates state in
        place: sets state['topic'] (list of (goal_number, goal_name) tuples),
        state['topic_kw'] (goal_name -> keywords), and appends progress
        messages.

        Raises:
            ValueError: when the seed question is missing or shorter than 3
                characters after stripping.
        """
        if not state.get('seed_question') or len(state.get('seed_question').strip())<3:
            raise ValueError("Seed question is not set in the state vector.")
        #print(state.get('seed_question').lower())
        # Tokenize for the classifier; fixed 512-token padded window.
        predict_input = self.tokenizer(
            text=state.get('seed_question').lower(),
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors="pt")
        #print(predict_input)
        # Inference only — no gradients needed.
        with torch.no_grad():
            logits = self.distilbert_model(**predict_input).logits
        # Softmax over the goal axis; 0.4 is the acceptance threshold.
        prob_value=F.softmax(logits, dim=1).cpu().numpy()[0]
        Topic_Bool=prob_value>0.4
        Topics=[]
        Keywords={}
        # sdg_goals comes from streamlitui.constants; iteration order must
        # match the classifier's label order — TODO confirm.
        for index, key in enumerate(sdg_goals):
            if not Topic_Bool[index]:continue
            #print(sdg_goals[key])
            Topics.append((index+1,sdg_goals[key]))
        #print(Topics)
        for i,t in Topics:
            kw_patterns=self.df_keys[self.df_keys['topic_num']==i]['keywords'].values[0].split(',')
            # NOTE(review): patterns are joined into a regex without re.escape —
            # any metacharacter in a keyword changes the match; also the search
            # runs on the original-case question while the model saw the
            # lower-cased text — confirm both are intended.
            Keywords[t] = re.findall(r'%s' %("|".join(kw_patterns)),state['seed_question'])
            if not Keywords[t]:
                # No keyword literally present: fall back to the whole pattern list.
                Keywords[t] = kw_patterns
                state['messages'].append(AIMessage(content="Will add keywords for the topic: %s \n" % t ))
        state['topic'] = Topics
        state['topic_kw'] = Keywords
        if not state.get('country'):
            # Country missing: report and stop early; should_continue handles routing.
            state['messages'].append(AIMessage(content="Country is not set. Please provide a country. \n"))
            return state
        elif not state.get('topic'):
            # No goal passed the threshold; graph will terminate downstream.
            state['messages'].append(AIMessage(content="Missing topic please ask a question about the 17 Sustainable Development Goals. Graph will terminate. \n"))
        state['messages'].append(AIMessage(content="Topics are: %s and keywords found: %s.\n Proceeding to prompt creation. \n" \
            %(", ".join(Keywords.keys()), ", ".join([kw for kws in Keywords.values() for kw in kws]))))
        return state
def should_continue(self, state:StateVector) -> str:
"""Determine whether to continue to prompt creation or terminate"""
if not state.get('topic') or not state.get('topic_kw'):
return "terminate"
return "create_question_prompt_template"
def generate_questions(self, state:StateVector) -> StateVector:
"""
Generates questions based on the provided topics and keywords.
This is a placeholder function that can be extended to include more complex question generation logic.
"""
parser=NumberedListOutputParser()
runner= self.genai_model | parser
#template= ChatPromptTemplate.from_messages(state['messages'][-2])
result = runner.invoke(state['messages'])
#print("Generated Question: %s" %result)
state['questions'] = result
#ai_response="\n".join(state['questions'])
#state['messages'].append(AIMessage(content="Generated questions: "+ai_response))
return state
class research_model:
    """Graph nodes that research the generated questions (Semantic Scholar +
    Tavily web search) and summarize the gathered evidence."""
    def __init__(self,llm,tavily_api_key):
        """Bind the research tools to the LLM and store configuration.

        Parameters:
            llm: chat model supporting .bind_tools().
            tavily_api_key: exported to the TAVILY_API_KEY env var for TavilySearch.
        """
        self.llm=llm
        # CSV of per-country, per-goal analyst prompts consumed by data_analysis().
        self.local_analysis_file='src/graph/data_analyst_prompts.csv'
        self.tool_names=["direct_semantic_scholar_query", "direct_tavily_search" ]
        semantic_scholar_tool = SemanticScholarQueryRun(
            api_wrapper=SemanticScholarAPIWrapper()
        )
        self.tools=[semantic_scholar_tool,self.direct_tavily_search]
        # Bind the tool to the LLM
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # NOTE(review): the key is exported only after tool binding; fine as long
        # as TavilySearch reads the env var at invocation time — confirm.
        os.environ['TAVILY_API_KEY']=tavily_api_key
def direct_semantic_scholar_query(self,query: str):
"""Direct invocation of SemanticScholarQueryRun without agent"""
# Create the tool directly
tool = SemanticScholarQueryRun(
api_wrapper=SemanticScholarAPIWrapper()
)
# Invoke the tool directly
result = tool.invoke(query, k=10, output_parser=JsonOutputParser(), fields=["paperId","title","authors", "url","abstract","year","paperId"],sort="year")
return result
def direct_tavily_search(self,query: str):
"""Direct invocation of TavilySearchResults without agent"""
# Create the tool directly
tavily = TavilySearch(max_results=5, include_answer=True, include_snippet=True, include_source=True)
result = tavily.invoke(query)
answer=result['answer']
response=f"Summary Answer for all webpages: {answer} \n"
for r in result['results']:
response +="Found a webpage: %s at %s \n" %(r['title'], r['url'])
response +="Summary of the page: %s \n" %r['content']
response +="Relevance score: %s\n" %r['score']
return response
def data_analysis(self,state:StateVector):
df_analyst=pd.read_csv(self.local_analysis_file)
analysis_prompt=[]
topics=state['topic']
for t in topics:
Goal_Number=t[0]
df_analyst=df_analyst[df_analyst['country']==state['country']]
df_analyst['goal_number']=df_analyst['goal_number'].astype(int)
df_analyst=df_analyst[df_analyst['goal_number']==Goal_Number]
#print(df_analyst.head())
if df_analyst.shape[0]>0:
analysis_prompt.extend(df_analyst['analysis_prompt'].to_list())
return "\n".join(analysis_prompt)
    def create_prompt_template(self,state:StateVector) -> StateVector:
        """
        Creates a prompt template based on the provided questions.

        Builds a fresh message list: one system message carrying the
        questions and tool names, plus two search instructions (scholarly
        papers and web/news) scoped to the detected keywords and country.
        NOTE: this REPLACES state['messages'], discarding the
        question-generation history.
        """
        # Human-readable topic list, e.g. "No Poverty, Zero Hunger".
        topic_string = ", ".join(f"{name}" for num, name in state['topic'])
        keywords=[]
        kw_string=''
        # One comma-joined keyword chunk per topic.
        for i,v in state['topic_kw'].items():
            keywords.append(",".join(v))
        # NOTE(review): original indentation was ambiguous; this append is
        # placed after the loop so the keyword clause is built exactly once —
        # confirm it was not meant to accumulate per topic.
        kw_string += f" with keywords: {', '.join(keywords)}"
        questions=state["questions"]
        country=state['country']
        messages = [
            SystemMessage(content= f"You are an AI assistant that helps users find information about the Sustainable Development Goal: {topic_string}.\
            Your task is to answer questions related to this goal using the provided tools with toolNames: {self.tool_names}\
            You will be provided with a list of questions to answer below: \
            questions = {questions} "),
            SystemMessage(content=f"Search for recent papers on {kw_string} in {country}."),
            SystemMessage(content=f"Search the internet for webpages or news on {kw_string} in {country}."),
        ]
        state['messages'] = messages
        return state
def tool_calling_agent(self):
"""Show how to bind the tool to LLM using tool calling"""
# Initialize LLM
'''
llm = ChatOpenAI(
temperature=0.1,
model_name="gpt-4o-mini",
openai_api_key=openai_api_key
)
'''
# Create the tool
semantic_scholar_tool = SemanticScholarQueryRun(
api_wrapper=SemanticScholarAPIWrapper()
)
self.tools=[semantic_scholar_tool,self.direct_tavily_search]
# Bind the tool to the LLM
llm_with_tools = self.llm.bind_tools(self.tools)
return llm_with_tools,self.tools
def tool_calling_llm(self,state:StateVector):
return {"messages":[self.llm_with_tools.invoke(state["messages"])]}
    def summary_answer(self,state:StateVector)->StateVector:
        """
        Function to summarize the answer from the LLM.
        This is a placeholder function that can be extended to include more complex summarization.

        Mutates the first system message in place: appends the tool outputs
        (labelled by source), the analyst prompts from data_analysis(), and
        grading/answer instructions, then sends ONLY that message's content
        to the LLM and returns the reply as a messages update.
        """
        initial_system_message= state["messages"][0] # This is the system message that sets the context for the LLM with the listed questions
        initial_system_message.content += "Please provide a comprehensive answer to the questions. \n"
        # Gather every tool result accumulated by the tool-calling loop.
        tool_messages = [msg for msg in state["messages"] if isinstance(msg, ToolMessage)]
        augmented_data=""
        if tool_messages:
            initial_system_message.content += "Use the following information gathered from the tools as reference information: \n"
            for tool_msg in tool_messages:
                print(tool_msg.content, type(tool_msg.content))
                Label_Source=""
                # Label each chunk by which tool produced it (matched by tool name).
                if 'semanticscholar' in tool_msg.name.lower():
                    Label_Source="(Source: Scholarly Publication Abstracts from Semantic Scholar)"
                    augmented_data+= f"{tool_msg.content}\n"
                elif 'tavily' in tool_msg.name.lower():
                    Label_Source="(Source: News Search Results)"
                    augmented_data += f"{tool_msg.content}\n"
                else:
                    print("Unknown Tool Call")
                # NOTE(review): augmented_data accumulates across iterations, so
                # each later tool re-appends all earlier tool output here —
                # looks like unintended duplication; confirm. (Original
                # indentation was ambiguous; this line is placed inside the loop
                # to keep Label_Source paired with its data.)
                initial_system_message.content += f"{Label_Source} \n {augmented_data}\n"
        # Country/goal-specific analyst prompts from the local CSV.
        analysis_prompt=self.data_analysis(state)
        initial_system_message.content+=analysis_prompt
        initial_system_message.content+="\n Assess if the resources indicate a general positive or negative trend and grade progress\
        from 0-10 where 0 is very negative and 10 is very positive.\n"
        initial_system_message.content+="\n Provide detailed answers to the questions and a list of references used."
        print(initial_system_message.content)
        # NOTE(review): the mutated message is appended again, so it appears
        # twice in the history, and only its content (not the full message
        # list) is sent to the LLM — confirm both are intended.
        state["messages"].append(initial_system_message)
        airesponse = self.llm.invoke(state["messages"][-1].content)
        # For simplicity, we just return the messages as they are
        return {"messages": [airesponse]}