Spaces:
Sleeping
Sleeping
File size: 11,937 Bytes
52be2ce 0f9a313 a640fc0 0f9a313 52be2ce ba78ba8 52be2ce a640fc0 9905f36 f232eef 52be2ce 0f9a313 52be2ce 0f9a313 ba78ba8 52be2ce 0f9a313 f232eef 0f9a313 833435e 0f9a313 833435e a640fc0 f232eef 0f9a313 52be2ce 0f9a313 52be2ce 9905f36 52be2ce 9905f36 52be2ce ba78ba8 32493ae ba78ba8 52be2ce 969f9ee 52be2ce 969f9ee 52be2ce 969f9ee ba78ba8 52be2ce ba78ba8 52be2ce ba78ba8 52be2ce ba78ba8 52be2ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 |
from langchain_core.messages import SystemMessage,AIMessage,HumanMessage,ToolMessage
from langchain_core.output_parsers import NumberedListOutputParser,JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from state.state import StateVector
from streamlitui.constants import *
import torch
import torch.nn.functional as F
#import tensorflow as tf
import re
from langchain_openai import ChatOpenAI
from langchain_community.tools.semanticscholar.tool import SemanticScholarQueryRun
from langchain_community.utilities.semanticscholar import SemanticScholarAPIWrapper
from langchain_tavily import TavilySearch
import pandas as pd
import torch.nn.functional as F
import os
# Workaround for Streamlit's file watcher crashing while introspecting
# torch.classes (its __path__ lookup raises at runtime); emptying the path
# makes the module skippable. NOTE(review): presumably needed only when run
# under Streamlit — confirm before removing.
torch.classes.__path__ = []
class question_model:
    """Classify a seed question into UN SDG topics with a DistilBERT
    multi-label model, then build prompts and generate questions with an LLM.

    The methods are LangGraph node callables that read/mutate a StateVector
    (a dict-like state with 'messages', 'seed_question', 'country', ...).
    """
    def __init__(self, loaded_tokenizer, loaded_model, llm, df_keys):
        # loaded_tokenizer / loaded_model: fine-tuned DistilBERT classifier pair.
        # llm: chat model used for question generation.
        # df_keys: DataFrame with 'topic_num' and comma-separated 'keywords' columns.
        self.tokenizer = loaded_tokenizer
        self.distilbert_model = loaded_model
        self.genai_model = llm
        self.df_keys = df_keys

    def create_question_prompt_template(self, state: 'StateVector') -> 'StateVector':
        """Append system/AI messages instructing the LLM to generate questions
        for each detected topic, using that topic's keywords and the state's
        country. Returns the mutated state."""
        state['messages'].extend([SystemMessage(
            content="You are an AI assistant that helps users find information about the Sustainable Development Goals (SDGs)."
        )])
        for topic, keywords in state['topic_kw'].items():
            state['messages'].append(SystemMessage(
                content=f"For the UN SDG Goal: {topic}\n. Use the following keywords : {', '.join(keywords)}. Generate questions related to the topic in the country of {state['country']} using these keywords.\n"))
        state['messages'].append(AIMessage(content="Based on the provided information, here is an enhanced list of the question: \n"))
        return state

    def check_inputs(self, state: 'StateVector') -> 'StateVector':
        """Validate the seed question, classify it into SDG topics, and
        extract matching keywords into state['topic'] / state['topic_kw'].

        Raises:
            ValueError: if the seed question is missing or shorter than 3 chars.
        """
        seed_question = state.get('seed_question')
        if not seed_question or len(seed_question.strip()) < 3:
            raise ValueError("Seed question is not set in the state vector.")
        predict_input = self.tokenizer(
            text=seed_question.lower(),
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors="pt")
        with torch.no_grad():
            logits = self.distilbert_model(**predict_input).logits
        prob_value = F.softmax(logits, dim=1).cpu().numpy()[0]
        # Multi-label decision: keep every goal whose probability exceeds 0.4.
        topic_mask = prob_value > 0.4
        Topics = []
        Keywords = {}
        for index, key in enumerate(sdg_goals):
            if not topic_mask[index]:
                continue
            Topics.append((index + 1, sdg_goals[key]))
        for i, t in Topics:
            # Keywords are stored in the CSV as one comma-separated string.
            kw_patterns = [p.strip() for p in
                           self.df_keys[self.df_keys['topic_num'] == i]['keywords'].values[0].split(',')]
            # BUG FIX: escape each keyword so regex metacharacters in the CSV
            # cannot break the alternation pattern, and match case-insensitively,
            # consistent with the lower-cased input given to the classifier.
            pattern = "|".join(re.escape(p) for p in kw_patterns if p)
            Keywords[t] = re.findall(pattern, state['seed_question'], flags=re.IGNORECASE) if pattern else []
            if not Keywords[t]:
                # No keyword literally present in the question: fall back to the
                # full CSV keyword list and tell the user.
                Keywords[t] = kw_patterns
                state['messages'].append(AIMessage(content="Will add keywords for the topic: %s \n" % t))
        state['topic'] = Topics
        state['topic_kw'] = Keywords
        if not state.get('country'):
            state['messages'].append(AIMessage(content="Country is not set. Please provide a country. \n"))
            return state
        if not state.get('topic'):
            state['messages'].append(AIMessage(content="Missing topic please ask a question about the 17 Sustainable Development Goals. Graph will terminate. \n"))
            # BUG FIX: return here so the "Proceeding to prompt creation" summary
            # below is not appended right after the termination notice.
            return state
        state['messages'].append(AIMessage(
            content="Topics are: %s and keywords found: %s.\n Proceeding to prompt creation. \n"
                    % (", ".join(Keywords.keys()),
                       ", ".join([kw for kws in Keywords.values() for kw in kws]))))
        return state

    def should_continue(self, state: 'StateVector') -> str:
        """Routing predicate: 'terminate' unless both topics and keywords are set."""
        if not state.get('topic') or not state.get('topic_kw'):
            return "terminate"
        return "create_question_prompt_template"

    def generate_questions(self, state: 'StateVector') -> 'StateVector':
        """Invoke the LLM on the accumulated messages and parse the reply as a
        numbered list into state['questions']. Returns the mutated state."""
        parser = NumberedListOutputParser()
        runner = self.genai_model | parser
        state['questions'] = runner.invoke(state['messages'])
        return state
class research_model:
    """Research agent: binds Semantic Scholar and Tavily web-search tools to an
    LLM, builds research prompts from the graph state, and summarizes results."""
    def __init__(self, llm, tavily_api_key):
        # llm: chat model supporting .bind_tools().
        self.llm = llm
        # CSV of per-country / per-goal analysis prompts used by data_analysis().
        self.local_analysis_file = 'src/graph/data_analyst_prompts.csv'
        # NOTE(review): these labels do not match the runtime names of the bound
        # tools (the Semantic Scholar tool is bound as a SemanticScholarQueryRun
        # instance, not the local direct_semantic_scholar_query wrapper) —
        # confirm how the prompt in create_prompt_template consumes them.
        self.tool_names = ["direct_semantic_scholar_query", "direct_tavily_search"]
        semantic_scholar_tool = SemanticScholarQueryRun(
            api_wrapper=SemanticScholarAPIWrapper()
        )
        self.tools = [semantic_scholar_tool, self.direct_tavily_search]
        # Bind the tools so the LLM can emit tool calls against them.
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # TavilySearch reads its API key from the environment.
        os.environ['TAVILY_API_KEY'] = tavily_api_key

    def direct_semantic_scholar_query(self, query: str):
        """Direct invocation of SemanticScholarQueryRun without an agent."""
        tool = SemanticScholarQueryRun(
            api_wrapper=SemanticScholarAPIWrapper()
        )
        # NOTE(review): the extra kwargs (k, fields, sort, output_parser) are
        # passed through tool.invoke — verify the tool actually honors them.
        result = tool.invoke(query, k=10, output_parser=JsonOutputParser(),
                             fields=["paperId", "title", "authors", "url", "abstract", "year", "paperId"],
                             sort="year")
        return result

    def direct_tavily_search(self, query: str):
        """Direct invocation of TavilySearch without an agent; formats the
        overall answer plus each result's title/url/summary/score as text."""
        tavily = TavilySearch(max_results=5, include_answer=True, include_snippet=True, include_source=True)
        result = tavily.invoke(query)
        answer = result['answer']
        response = f"Summary Answer for all webpages: {answer} \n"
        for r in result['results']:
            response += "Found a webpage: %s at %s \n" % (r['title'], r['url'])
            response += "Summary of the page: %s \n" % r['content']
            response += "Relevance score: %s\n" % r['score']
        return response

    def data_analysis(self, state: 'StateVector'):
        """Collect country/goal-specific analysis prompts for every detected
        topic and join them with newlines.

        BUG FIX: the original filtered one DataFrame in place inside the loop,
        so after the first topic the frame was already restricted to that
        topic's goal number and later topics could never match. Filter by
        country once, then take a fresh per-goal subset for each topic.
        """
        df_analyst = pd.read_csv(self.local_analysis_file)
        df_country = df_analyst[df_analyst['country'] == state['country']].copy()
        df_country['goal_number'] = df_country['goal_number'].astype(int)
        analysis_prompt = []
        for goal_number, _topic_name in state['topic']:
            df_goal = df_country[df_country['goal_number'] == goal_number]
            if df_goal.shape[0] > 0:
                analysis_prompt.extend(df_goal['analysis_prompt'].to_list())
        return "\n".join(analysis_prompt)

    def create_prompt_template(self, state: 'StateVector') -> 'StateVector':
        """Build the system messages (question list + search instructions) for
        the tool-calling LLM and store them in state['messages'].

        BUG FIX: the original rebuilt kw_string inside the loop from the whole
        accumulated keyword list, duplicating earlier keywords once per topic;
        the keyword summary is now built exactly once. The return annotation
        is also corrected: this returns the mutated state, not a
        ChatPromptTemplate.
        """
        topic_string = ", ".join(f"{name}" for num, name in state['topic'])
        keyword_groups = [",".join(kws) for kws in state['topic_kw'].values()]
        kw_string = f" with keywords: {', '.join(keyword_groups)}"
        questions = state["questions"]
        country = state['country']
        messages = [
            SystemMessage(content=f"You are an AI assistant that helps users find information about the Sustainable Development Goal: {topic_string}. Your task is to answer questions related to this goal using the provided tools with toolNames: {self.tool_names} You will be provided with a list of questions to answer below: questions = {questions} "),
            SystemMessage(content=f"Search for recent papers on {kw_string} in {country}."),
            SystemMessage(content=f"Search the internet for webpages or news on {kw_string} in {country}."),
        ]
        state['messages'] = messages
        return state

    def tool_calling_agent(self):
        """Bind the research tools to the LLM for tool calling; returns
        (llm_with_tools, tools). Mirrors the binding done in __init__."""
        semantic_scholar_tool = SemanticScholarQueryRun(
            api_wrapper=SemanticScholarAPIWrapper()
        )
        self.tools = [semantic_scholar_tool, self.direct_tavily_search]
        llm_with_tools = self.llm.bind_tools(self.tools)
        return llm_with_tools, self.tools

    def tool_calling_llm(self, state: 'StateVector'):
        """LangGraph node: run the tool-bound LLM on the current messages."""
        return {"messages": [self.llm_with_tools.invoke(state["messages"])]}

    def summary_answer(self, state: 'StateVector') -> 'StateVector':
        """Fold tool outputs and analysis prompts into the first system message,
        then ask the plain LLM for a graded, referenced summary answer.

        Mutates state['messages'][0] in place, appends it again as the latest
        message, and invokes the LLM on that augmented content only.
        """
        # The first message is the system message carrying the question list.
        initial_system_message = state["messages"][0]
        initial_system_message.content += "Please provide a comprehensive answer to the questions. \n"
        tool_messages = [msg for msg in state["messages"] if isinstance(msg, ToolMessage)]
        if tool_messages:
            initial_system_message.content += "Use the following information gathered from the tools as reference information: \n"
            for tool_msg in tool_messages:
                # Label each tool result with its provenance.
                if 'semanticscholar' in tool_msg.name.lower():
                    Label_Source = "(Source: Scholarly Publication Abstracts from Semantic Scholar)"
                elif 'tavily' in tool_msg.name.lower():
                    Label_Source = "(Source: News Search Results)"
                else:
                    print("Unknown Tool Call")
                    Label_Source = ""
                # BUG FIX: append each tool result exactly once; the original
                # accumulated all results in one buffer and re-appended the
                # whole buffer per message, duplicating earlier tool outputs.
                initial_system_message.content += f"{Label_Source} \n {tool_msg.content}\n"
        analysis_prompt = self.data_analysis(state)
        initial_system_message.content += analysis_prompt
        initial_system_message.content += "\n Assess if the resources indicate a general positive or negative trend and grade progress from 0-10 where 0 is very negative and 10 is very positive.\n"
        initial_system_message.content += "\n Provide detailed answers to the questions and a list of references used."
        state["messages"].append(initial_system_message)
        airesponse = self.llm.invoke(state["messages"][-1].content)
        return {"messages": [airesponse]}