File size: 11,937 Bytes
52be2ce
 
 
 
 
0f9a313
 
 
 
a640fc0
0f9a313
52be2ce
 
 
ba78ba8
52be2ce
a640fc0
9905f36
f232eef
52be2ce
0f9a313
 
 
 
 
 
52be2ce
0f9a313
 
 
 
 
 
 
 
 
ba78ba8
52be2ce
0f9a313
 
 
 
 
 
 
 
 
f232eef
 
0f9a313
833435e
0f9a313
833435e
a640fc0
f232eef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f9a313
 
 
 
 
 
52be2ce
0f9a313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52be2ce
 
9905f36
52be2ce
 
 
 
 
 
 
 
 
9905f36
52be2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba78ba8
 
 
32493ae
ba78ba8
 
 
 
52be2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
969f9ee
 
52be2ce
 
 
 
969f9ee
52be2ce
969f9ee
ba78ba8
52be2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba78ba8
52be2ce
ba78ba8
52be2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba78ba8
52be2ce
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245

from langchain_core.messages import SystemMessage,AIMessage,HumanMessage,ToolMessage
from langchain_core.output_parsers import NumberedListOutputParser,JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate

from state.state import StateVector
from streamlitui.constants import *
import torch
import torch.nn.functional as F
#import tensorflow as tf
import re
from langchain_openai import ChatOpenAI
from langchain_community.tools.semanticscholar.tool import SemanticScholarQueryRun
from langchain_community.utilities.semanticscholar import SemanticScholarAPIWrapper
from langchain_tavily import TavilySearch
import pandas as pd
import torch.nn.functional as F
import os
# Workaround for Streamlit's source watcher crashing while walking
# torch.classes.__path__ (torch's custom-class loader is not a real package,
# so path inspection raises). Emptying the list makes the watcher skip it.
# NOTE(review): presumably needed by the Streamlit UI imports — confirm before removing.
torch.classes.__path__ = []
class question_model:
    """Graph nodes that classify a seed question into UN SDG topics and build
    the prompts for LLM-based question generation.

    Collaborators (injected, not constructed here):
        loaded_tokenizer -- tokenizer matching the classifier model
        loaded_model     -- multi-label SDG topic classifier (DistilBERT-style torch model)
        llm              -- chat model used to generate questions
        df_keys          -- DataFrame with 'topic_num' and comma-separated 'keywords' columns
    """

    def __init__(self, loaded_tokenizer, loaded_model, llm, df_keys):
        self.tokenizer = loaded_tokenizer
        self.distilbert_model = loaded_model
        self.genai_model = llm
        self.df_keys = df_keys

    def create_question_prompt_template(self, state: 'StateVector') -> 'StateVector':
        """
        Creates a prompt template based on the state vector.

        Appends a framing system message, one system message per detected topic
        (with its keywords and the target country), and a leading assistant
        message, then returns the mutated state.
        """
        state['messages'].extend([SystemMessage(
                    content="You are an AI assistant that helps users find information about the Sustainable Development Goals (SDGs)."
                )
            ])
        for topic, keywords in state['topic_kw'].items():
            state['messages'].append(SystemMessage(content=f"For the UN SDG Goal: {topic}\n. \
                                                Use the following keywords : {', '.join(keywords)}. Generate questions related to the topic in the country of {state['country']} using these keywords.\n"))
        state['messages'].append(AIMessage(content="Based on the provided information, here is an enhanced list of the question: \n"))

        return state

    # Check input raw prompt and extract topics and keywords
    def check_inputs(self, state: 'StateVector') -> 'StateVector':
        """Classify the seed question into SDG topics and extract keywords.

        Populates state['topic'] with (goal_number, goal_name) pairs whose
        classifier probability exceeds 0.4 and state['topic_kw'] with the
        keywords found in (or, as a fallback, assigned to) the question.
        Appends diagnostic AI messages when country or topic is missing.

        Raises:
            ValueError: if the seed question is unset or shorter than 3 chars.
        """
        if not state.get('seed_question') or len(state.get('seed_question').strip()) < 3:
            raise ValueError("Seed question is not set in the state vector.")
        predict_input = self.tokenizer(
            text=state.get('seed_question').lower(),
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors="pt")
        with torch.no_grad():
            logits = self.distilbert_model(**predict_input).logits
            # NOTE(review): softmax over the goal axis with a 0.4 cutoff — if the
            # model was trained multi-label, sigmoid may be intended; confirm.
            prob_value = F.softmax(logits, dim=1).cpu().numpy()[0]
            topic_mask = prob_value > 0.4
            topics = []
            keywords = {}
            for index, key in enumerate(sdg_goals):
                if not topic_mask[index]:
                    continue
                # Goal numbers are 1-based.
                topics.append((index + 1, sdg_goals[key]))
            for goal_num, goal_name in topics:
                kw_patterns = self.df_keys[self.df_keys['topic_num'] == goal_num]['keywords'].values[0].split(',')
                # Escape each keyword before building the alternation: raw CSV
                # keywords may contain regex metacharacters. Match
                # case-insensitively because classification ran on the
                # lower-cased question.
                pattern = "|".join(re.escape(kw) for kw in kw_patterns)
                keywords[goal_name] = re.findall(pattern, state['seed_question'], flags=re.IGNORECASE)
                if not keywords[goal_name]:
                    # No keyword literally present in the question: fall back to
                    # the goal's full keyword list.
                    keywords[goal_name] = kw_patterns
                    state['messages'].append(AIMessage(content="Will add keywords for the topic: %s \n" % goal_name))
            state['topic'] = topics
            state['topic_kw'] = keywords
            if not state.get('country'):
                state['messages'].append(AIMessage(content="Country is not set. Please provide a country. \n"))
                return state
            elif not state.get('topic'):
                # No early return: should_continue() routes to "terminate" next.
                state['messages'].append(AIMessage(content="Missing topic please ask a question about the 17 Sustainable Development Goals. Graph will terminate. \n"))
            state['messages'].append(AIMessage(content="Topics are: %s and keywords found: %s.\n Proceeding to prompt creation. \n" \
                                            %(", ".join(keywords.keys()), ", ".join([kw for kws in keywords.values() for kw in kws]))))
        return state

    def should_continue(self, state: 'StateVector') -> str:
        """Determine whether to continue to prompt creation or terminate."""
        if not state.get('topic') or not state.get('topic_kw'):
            return "terminate"
        return "create_question_prompt_template"

    def generate_questions(self, state: 'StateVector') -> 'StateVector':
        """
        Generates questions based on the provided topics and keywords.

        Runs the accumulated messages through the LLM and parses the reply as a
        numbered list, stored in state['questions'].
        """
        parser = NumberedListOutputParser()
        runner = self.genai_model | parser
        state['questions'] = runner.invoke(state['messages'])
        return state

class research_model:
    """Graph nodes that gather evidence for SDG questions (Semantic Scholar +
    Tavily web search) and produce a summarized, graded answer with the LLM."""

    def __init__(self, llm, tavily_api_key):
        self.llm = llm
        # CSV of per-country / per-goal analyst prompt snippets, consumed by data_analysis().
        self.local_analysis_file = 'src/graph/data_analyst_prompts.csv'
        self.tool_names = ["direct_semantic_scholar_query", "direct_tavily_search"]
        semantic_scholar_tool = SemanticScholarQueryRun(
            api_wrapper=SemanticScholarAPIWrapper()
        )
        self.tools = [semantic_scholar_tool, self.direct_tavily_search]
        # Bind the tools so the LLM can emit structured tool calls.
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # TavilySearch reads its API key from the environment at call time.
        os.environ['TAVILY_API_KEY'] = tavily_api_key

    def direct_semantic_scholar_query(self, query: str):
        """Direct invocation of SemanticScholarQueryRun without an agent.

        Returns the raw tool output for `query`.
        """
        tool = SemanticScholarQueryRun(
            api_wrapper=SemanticScholarAPIWrapper()
        )
        # NOTE(review): k / fields / sort / output_parser look like
        # SemanticScholarAPIWrapper settings, not BaseTool.invoke kwargs —
        # verify against the installed langchain version before relying on them.
        result = tool.invoke(query, k=10, output_parser=JsonOutputParser(), fields=["paperId","title","authors", "url","abstract","year","paperId"],sort="year")

        return result

    def direct_tavily_search(self, query: str):
        """Direct invocation of TavilySearch without an agent.

        Returns a formatted text block: the overall summary answer followed by
        title / URL / page summary / relevance score for each result.
        """
        tavily = TavilySearch(max_results=5, include_answer=True, include_snippet=True, include_source=True)
        result = tavily.invoke(query)
        answer = result['answer']
        response = f"Summary Answer for all webpages: {answer} \n"
        for r in result['results']:
            response += "Found a webpage: %s at %s \n" % (r['title'], r['url'])
            response += "Summary of the page: %s \n" % r['content']
            response += "Relevance score: %s\n" % r['score']
        return response

    def data_analysis(self, state: 'StateVector') -> str:
        """Collect analyst prompt snippets for every detected topic.

        Reads the local CSV once, restricts it to state['country'], then pulls
        the 'analysis_prompt' rows for each goal number in state['topic'].
        Fix: the original re-filtered the *already filtered* frame on every
        loop iteration, so only the first topic could ever match.

        Returns:
            Newline-joined prompt snippets ('' when nothing matches).
        """
        df_analyst = pd.read_csv(self.local_analysis_file)
        # Filter by country once; each topic filters from this base frame.
        df_country = df_analyst[df_analyst['country'] == state['country']].copy()
        df_country['goal_number'] = df_country['goal_number'].astype(int)
        analysis_prompt = []
        for goal_number, _goal_name in state['topic']:
            matches = df_country[df_country['goal_number'] == goal_number]
            if matches.shape[0] > 0:
                analysis_prompt.extend(matches['analysis_prompt'].to_list())
        return "\n".join(analysis_prompt)

    def create_prompt_template(self,state: 'StateVector') -> 'StateVector':
            """
            Creates a prompt template based on the provided questions.

            Replaces state['messages'] with three system messages: the task
            framing (listing the generated questions), a scholarly-search
            instruction, and a web-search instruction. Returns the updated
            state (the original return annotation claimed ChatPromptTemplate,
            which was wrong).
            """
            topic_string = ", ".join(f"{name}" for num, name in state['topic'])
            keywords=[]
            kw_string=''
            for i,v in state['topic_kw'].items():
                keywords.append(",".join(v))
            kw_string += f" with keywords: {', '.join(keywords)}"
            questions=state["questions"]
            country=state['country']
            messages = [
                    SystemMessage(content= f"You are an AI assistant that helps users find information about the Sustainable Development Goal: {topic_string}.\
                                Your task is to answer questions related to this goal using the provided tools with toolNames: {self.tool_names}\
                                    You will be provided with a list of questions to answer below: \
                                    questions = {questions} "),

                    SystemMessage(content=f"Search for recent papers on {kw_string} in {country}."),
                    SystemMessage(content=f"Search the internet for webpages or news on {kw_string} in {country}."),
                ]
            state['messages'] = messages
            return state

    def tool_calling_agent(self):
        """Build and return (llm_with_tools, tools) for tool calling.

        Mirrors the binding done in __init__ (and rebinds self.tools as a side
        effect); kept for callers that construct the pair explicitly.
        """
        semantic_scholar_tool = SemanticScholarQueryRun(
            api_wrapper=SemanticScholarAPIWrapper()
        )
        self.tools = [semantic_scholar_tool, self.direct_tavily_search]
        llm_with_tools = self.llm.bind_tools(self.tools)

        return llm_with_tools, self.tools

    def tool_calling_llm(self, state: 'StateVector'):
        """Graph node: run the tool-bound LLM over the conversation so far."""
        return {"messages": [self.llm_with_tools.invoke(state["messages"])]}

    def summary_answer(self, state: 'StateVector') -> 'StateVector':
        """
        Function to summarize the answer from the LLM.

        Folds labelled tool results, analyst prompt snippets and grading
        instructions into the first system message, then asks the LLM for a
        comprehensive answer. Returns {'messages': [ai_response]}.
        """
        # First message is the system message that set the context and listed the questions.
        initial_system_message = state["messages"][0]
        initial_system_message.content += "Please provide a comprehensive answer to the questions. \n"

        tool_messages = [msg for msg in state["messages"] if isinstance(msg, ToolMessage)]
        if tool_messages:
            initial_system_message.content += "Use the following information gathered from the tools as reference information: \n"
            for tool_msg in tool_messages:
                # Label each tool result by provenance and append it exactly once.
                # (The original appended a cumulative buffer inside the loop,
                # duplicating earlier tool output on every later iteration.)
                if 'semanticscholar' in tool_msg.name.lower():
                    Label_Source = "(Source: Scholarly Publication Abstracts from Semantic Scholar)"
                elif 'tavily' in tool_msg.name.lower():
                    Label_Source = "(Source: News Search Results)"
                else:
                    Label_Source = ""
                    print("Unknown Tool Call")
                initial_system_message.content += f"{Label_Source} \n {tool_msg.content}\n"
        analysis_prompt = self.data_analysis(state)
        initial_system_message.content += analysis_prompt
        initial_system_message.content += "\n Assess if the resources indicate a general positive or negative trend and grade progress\
            from 0-10 where 0 is very negative and 10 is very positive.\n"
        initial_system_message.content += "\n Provide detailed answers to the questions and a list of references used."
        # Re-appending puts the augmented system message last; the LLM is then
        # invoked on that single message's content only.
        state["messages"].append(initial_system_message)
        airesponse = self.llm.invoke(state["messages"][-1].content)
        return {"messages": [airesponse]}