File size: 8,922 Bytes
cc74784
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import asyncio
import os
import time
from dataclasses import dataclass
from typing import List, Dict, Optional, Any

import requests
from dotenv import load_dotenv
from IPython.display import Image, display
from pydantic import BaseModel, Field
from pydantic_ai import Agent
from pydantic_ai.common_tools.tavily import tavily_search_tool
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
from tavily import TavilyClient

# Load credentials from the local .env file into the environment.
load_dotenv()
google_api_key=os.getenv('google_api_key')  # Google API key (used for Gemini and Custom Search)
tavily_key=os.getenv('tavily_key')  # Tavily web-search API key
tavily_client = TavilyClient(api_key=tavily_key)  # direct Tavily client used by Research_node
llm=GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=google_api_key))  # shared model for every agent below
pse=os.getenv('pse')  # Programmable Search Engine id ("cx") for google_image_search

@dataclass
class State:
    # Shared mutable state threaded through every node of the research graph.
    query:str  # the user's research question
    preliminary_research: str  # summary text produced by preliminary_search_node
    research_plan: Dict  # set by Research_plan_node; holds a Research_plan instance at runtime despite the Dict annotation
    research_results: Dict  # set by Research_node; holds a Research_results instance at runtime
    validation : str  # not read or written by any visible node — TODO confirm whether still needed
    final: Dict  # the assembled paper dict — set by PaperGen_node
# Structured output of paragraph_gen_agent: one finished paragraph (title + body).
# NOTE: field descriptions feed the agent's output schema — do not edit casually.
class paragraph_content(BaseModel):
    title: str = Field(description='the title of the paragraph')
    content: str = Field(description='the content of the paragraph')

# One planned paragraph in the layout: a title plus guidance on what to write.
class paragraph(BaseModel):
    title: str = Field(description='the title of the paragraph')
    should_include: str = Field(description='a description of what the paragraph should include')  
# Paper skeleton produced by paper_layout_agent: overall title + ordered paragraph plans.
class Paper_layout(BaseModel):
    title: str = Field(description='the title of the paper')
    paragraphs: List[paragraph]= Field(description='the list of paragraphs of the paper')

# Agent that drafts the paper skeleton (title + paragraph titles only) from the gathered research.
paper_layout_agent=Agent(llm, result_type=Paper_layout, system_prompt="generate a paper layout based on the query, preliminary_search, search_results,include a Title for the paper, for the paragraphs only include the title, no content, no image, no table, start with introduction and end with conclusion")
# Agent that writes one paragraph at a time; it is shown what has already been written to avoid repetition.
paragraph_gen_agent=Agent(llm, result_type=paragraph_content, system_prompt="generate a paragraph synthesizing the research_results based on the title,what the paragraph should include, and what has already been written to avoid repetition")
class PaperGen_node(BaseNode[State]):
    """Final graph node: turns the accumulated research into the finished paper dict."""

    async def run(self, ctx: GraphRunContext[State]) -> End:
        """Generate the layout, write each paragraph in order, assemble the paper.

        Returns:
            End wrapping a dict with keys title, image_url, paragraphs, table
            and references (the latter three are None when absent).
        """
        prompt = (f'query:{ctx.state.query}, preliminary_search:{ctx.state.preliminary_research},search_results:{ctx.state.research_results.research_results}')
        result = await paper_layout_agent.run(prompt)
        paragraphs = []
        for planned in result.data.paragraphs:
            # Throttle calls — presumably to respect the Gemini rate limit (TODO confirm).
            # Fixed: the original used time.sleep(2), which blocks the event loop
            # inside an async function; asyncio.sleep yields control instead.
            await asyncio.sleep(2)
            paragraph_data = await paragraph_gen_agent.run(
                f'title:{planned.title}, should_include:{planned.should_include}, research_results:{ctx.state.research_results.research_results}, already_written:{paragraphs}')
            paragraphs.append(paragraph_data.data.model_dump())

        # Falsy placeholders ('' / {}) are normalised to None in the final output.
        paper = {'title': result.data.title,
                 'image_url': ctx.state.research_results.image_url if ctx.state.research_results.image_url else None,
                 'paragraphs': paragraphs,
                 'table': ctx.state.research_results.table if ctx.state.research_results.table else None,
                 'references': ctx.state.research_results.references if ctx.state.research_results.references else None}

        ctx.state.final = paper
        return End(ctx.state.final)

def google_image_search(query: str):
    """Search for images using the Google Custom Search API.

    Args:
        query: the image search query.

    Returns:
        The URL (str) of the first image result, or None when the request
        fails, the response is not JSON, or no image items are returned.
    """
    # Google Custom Search endpoint; pse and google_api_key are module-level.
    url = "https://www.googleapis.com/customsearch/v1"
    params = {
        "q": query,
        "cx": pse,
        "key": google_api_key,
        "searchType": "image",  # Search for images
        "num": 1,  # Number of results to fetch
    }

    try:
        # Timeout added so a hung request cannot stall the whole pipeline.
        response = requests.get(url, params=params, timeout=10)
        data = response.json()
    except (requests.RequestException, ValueError):
        # The image is optional in the pipeline, so fail soft: the original
        # crashed the run on any network/JSON error here.
        return None

    if 'items' in data:
        # Extract the first image result.
        return data['items'][0]['link']
    # Explicit None instead of the original implicit fall-through.
    return None

# A single row of the generated table.
class Table_row(BaseModel):
    data: List[str] = Field(description='the data of the row')
# Structured output of table_agent: rows plus the column headers.
class Table(BaseModel):
    rows: List[Table_row] = Field(description='the rows of the table')
    columns: List[str] = Field(description='the columns of the table')

# Container for everything Research_node gathers. Not an agent result_type —
# it is constructed and filled in by hand.
# Fixed: the original passed default_factory=None, which sets NO factory and
# therefore left every field required; real defaults are supplied instead.
class Research_results(BaseModel):
    research_results: List[str] = Field(default_factory=list, description='the research results')
    image_url: Optional[str] = Field(default=None, description='the image url if needed else return None')
    table: Optional[dict] = Field(default=None, description='the table dataframe in a dictionary format')
    references: Optional[str] = Field(default=None, description='the references (urls) of the research_results')

table_agent=Agent(llm, result_type=Table, system_prompt="generate a detailed table in dictionary format based on the research and the query")

class Research_node(BaseNode[State]):
    """Executes the research plan: web searches, optional image search, optional table."""

    async def run(self, ctx: GraphRunContext[State]) -> PaperGen_node:
        """Run every planned search query, keep high-scoring hits, then resolve
        the optional image URL and table before handing off to PaperGen_node.
        """
        research_results = Research_results(research_results=[], image_url='', table={}, references='')

        # Fixed: the original reset its url list INSIDE the outer loop, so only
        # the last query's references survived, and an empty plan raised
        # NameError when the list was read after the loop. It also shadowed the
        # outer loop variable `i` with the inner one.
        urls = []
        for planned in ctx.state.research_plan.search_queries:
            response = tavily_client.search(planned.search_query)
            for item in response.get('results'):
                # Keep only reasonably relevant hits; a missing score counts as 0
                # (the original raised TypeError on None > 0.50).
                if (item.get('score') or 0) > 0.50:
                    urls.append(item.get('url'))
                    research_results.research_results.append(item.get('content'))

        # De-duplicate contents and references (set() does not preserve order,
        # matching the original behavior).
        research_results.research_results = list(set(research_results.research_results))
        research_results.references = ', '.join(set(urls))
        ctx.state.research_results = research_results

        if ctx.state.research_plan.image_search_query:
            ctx.state.research_results.image_url = google_image_search(ctx.state.research_plan.image_search_query)

        if ctx.state.research_plan.table:
            result = await table_agent.run(f'research_results:{ctx.state.research_results.research_results},query:{ctx.state.query}')
            ctx.state.research_results.table = {'data': [row.data for row in result.data.rows], 'columns': result.data.columns}

        return PaperGen_node()

# One web-search query inside a Research_plan.
class search_query(BaseModel):
    search_query: str = Field(description='the detailed web search query for the research')
    
# Plan emitted by research_plan_agent: search queries plus optional table/image requests.
# NOTE: field descriptions feed the agent's output schema — do not edit casually.
# Fixed: default_factory=None sets NO factory, which made these Optional fields
# required; default=None gives them an actual default.
class Research_plan(BaseModel):
    search_queries: List[search_query] = Field(description='the detailed web search queries for the research')
    table: Optional[str] = Field(default=None, description='if a table is needed, return yes else return None')
    image_search_query: Optional[str] = Field(default=None, description='if image is needed, generate a image search query, optional')

research_plan_agent=Agent(llm, result_type=Research_plan, system_prompt='generate a detailed research plan breaking down the research into smaller parts based on the query and the preliminary search, include a table and image search query if the user wants it')

class Research_plan_node(BaseNode[State]):
    """Second graph node: asks research_plan_agent to break the query into a plan."""

    async def run(self, ctx: GraphRunContext[State]) -> Research_node:
        """Build the planning prompt from state, store the resulting plan, and
        hand off to Research_node."""
        plan = await research_plan_agent.run(
            f'query:{ctx.state.query}, preliminary_search:{ctx.state.preliminary_research}'
        )
        ctx.state.research_plan = plan.data
        return Research_node()
    
    
search_agent=Agent(llm, tools=[tavily_search_tool(tavily_key)], system_prompt="do a websearch based on the query")

class preliminary_search_node(BaseNode[State]):
    """First graph node: broad web search to get an overview of the subject."""

    async def run(self, ctx: GraphRunContext[State]) -> Research_plan_node:
        """Run a preliminary search for the user's query, stash the summary in
        state.preliminary_research, and hand off to the planning node."""
        # Typos fixed in the LLM-facing prompt ("reseach" -> "research",
        # "informations" -> "information").
        prompt = (' Do a preliminary search to get a global idea of the subject that the user wants to do research on as well as the necessary information to do a search on.\n'
                  f'The subject is based on the query: {ctx.state.query}, return the results of the search.')
        result = await search_agent.run(prompt)
        ctx.state.preliminary_research = result.data
        return Research_plan_node()

class Deep_research_engine:
    """Facade wiring the four research nodes into a runnable pydantic_graph Graph."""

    def __init__(self):
        self.graph = Graph(nodes=[preliminary_search_node, Research_plan_node, Research_node, PaperGen_node])
        self.state = self._fresh_state('')

    @staticmethod
    def _fresh_state(query: str) -> State:
        """Build a clean State for one run.

        Fixed: the original seeded the Dict-annotated fields with []/'' —
        empty dicts match the State annotations, and building a fresh State
        per query stops results from one chat leaking into the next.
        """
        return State(query=query, preliminary_research='', research_plan={},
                     research_results={}, validation='', final={})

    async def chat(self, query: str):
        """Chat with the deep research engine.

        Args:
            query (str): The query to search for
        Returns:
            str: The response from the deep research engine
        """
        # Reset state each call so a previous run's data cannot leak in.
        self.state = self._fresh_state(query)
        response = await self.graph.run(preliminary_search_node(), state=self.state)
        return response.output

    def display_graph(self):
        """Display the graph of the deep research engine.

        Returns:
            Image: The image of the graph
        """
        image = self.graph.mermaid_image()
        return display(Image(image))