File size: 4,618 Bytes
49f468c
 
 
 
 
 
 
 
64358a0
49f468c
 
 
 
 
8a34a81
49f468c
64358a0
 
49f468c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b260e4
49f468c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b260e4
49f468c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Literal
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from visit_web_pages_tool import visit_webpage
import wikipedia
import json
from prompt import SYSTEM_PROMPT_MANAGER, SYSTEM_PROMPT_CLEANER
from langchain_community.tools import DuckDuckGoSearchResults
import os

class GraphState(TypedDict):
    """
        State shared by every node of the agent graph.

        history: the running conversation given to ``manager_agent.invoke``.
        It mixes plain ``{'role': ..., 'content': ...}`` dicts (system, human
        and tool messages) with the reply objects returned by the chat model.
    """
    history: List

# Manager model, reached through langchain's HuggingFaceEndpoint wrapper.
# The API token comes from the `hf_token` environment variable; os.environ[...]
# raises KeyError at import time when it is not set.
llm = HuggingFaceEndpoint(
    huggingfacehub_api_token=os.environ["hf_token"],
    repo_id="openai/gpt-oss-120b",
    task="text-generation",
    max_new_tokens=4096,
)
manager_agent = ChatHuggingFace(llm=llm)
# Second agent, currently disabled (kept for reference):
# cleaner_agent = ChatHuggingFace(llm=llm)

def llm_call(state: GraphState) -> GraphState:
    """
        Node used to generate the basic LLM calls from the manager agent.

        The whole conversation in ``state['history']`` is sent to
        ``manager_agent`` and its reply is appended to the history.

        Returns a new state whose history is the old one plus the reply. The
        incoming history list is NOT modified, so a caller-owned initial state
        (e.g. a module-level ``init_state``) can be invoked several times
        without accumulating previous runs' messages.
    """
    history = state['history']
    print(history)  # debug trace of the conversation sent to the model
    reply = manager_agent.invoke(history)
    return {**state, 'history': [*history, reply]}

def tool_call(state: GraphState) -> GraphState:
    """
        Node used to perform the tool call requested by the manager agent.

        The content of the last history entry is parsed as JSON. For the moment
        the only tool available is ``web_search``: a request shaped like
        ``{"action": "web_search", "query": ...}`` searches Wikipedia for
        ``query``, fetches the top hit with ``visit_webpage`` and appends the
        result as a ``tool`` message.

        Any other request (non-JSON content, JSON that is not an object, an
        unknown action, a missing/empty query) is answered with an
        ``'Invalid tool call'`` tool message. Wikipedia failures (no search hit,
        disambiguation page, missing page) are reported back to the agent as
        the tool output instead of aborting the graph run.

        Returns a new state; the incoming history list is not modified.
    """
    # First step: decode the agent's last message as a JSON tool request.
    try:
        request = json.loads(state['history'][-1].content)
    except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
        request = None

    # NOTE(review): 'blablabla' is a fixed placeholder id — the JSON protocol
    # used here carries no real tool-call id; confirm the chat template accepts it.
    if isinstance(request, dict) and request.get('action') == 'web_search' and request.get('query'):
        tool_message = {
            'role': 'tool',
            'name': "web_search",
            'content': _wikipedia_top_hit_markdown(request['query']),
            'tool_call_id': 'blablabla',
        }
    else:
        tool_message = {'role': 'tool', 'content': 'Invalid tool call', 'tool_call_id': 'blablabla'}
    return {**state, 'history': [*state['history'], tool_message]}


def _wikipedia_top_hit_markdown(query):
    """Return the first Wikipedia search hit for *query* via visit_webpage, or an error note for the agent."""
    results = wikipedia.search(query)
    if not results:
        return f"No Wikipedia page found for query: {query!r}"
    try:
        # auto_suggest=False: the title is an exact search result, and the
        # default auto-suggestion can silently swap it for a different page.
        page = wikipedia.page(results[0], auto_suggest=False)
    except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError) as err:
        return f"Wikipedia lookup failed for {results[0]!r}: {err}"
    return visit_webpage(page.url)

def force_final_answer(state: GraphState) -> GraphState:
    """
        Node used after an intermediate answer to ask the manager agent for
        its final answer.

        Appends a human message requesting the final answer. Returns a new
        state; the incoming history list is not modified.
    """
    nudge = {'role': 'human', 'content': 'Now provide the final answer based on the intermediate answer'}
    return {**state, 'history': [*state['history'], nudge]}

def router_edge_tool(state: GraphState) -> Literal["force_final_answer", "tool_call", "end"]:
    """
        Conditional edge taken after ``llm_call``.

        The content of the last history entry (the manager agent's reply) is
        parsed as JSON and its ``action`` field selects the next node:

        - ``"intermediate_answer"`` -> ``"force_final_answer"``
        - ``"web_search"``          -> ``"tool_call"``
        - anything else             -> ``"end"``

        Content that is not valid JSON, or JSON that is not an object, also
        routes to ``"end"`` instead of raising and aborting the graph run.
    """
    try:
        decision = json.loads(state['history'][-1].content)
    except (TypeError, ValueError):  # ValueError covers json.JSONDecodeError
        return "end"
    action = decision.get('action') if isinstance(decision, dict) else None
    if action == "intermediate_answer":
        return "force_final_answer"
    if action == "web_search":
        return "tool_call"
    return "end"


# Agent loop:
#   llm_call --router_edge_tool--> tool_call | force_final_answer | END
#   tool_call and force_final_answer both loop back into llm_call.
my_graph_build = StateGraph(GraphState)
for _node_name, _node_fn in (
    ("llm_call", llm_call),
    ("tool_call", tool_call),
    ("force_final_answer", force_final_answer),
):
    my_graph_build.add_node(_node_name, _node_fn)

_route_targets = {"force_final_answer": "force_final_answer", "tool_call": "tool_call", "end": END}
my_graph_build.add_conditional_edges("llm_call", router_edge_tool, _route_targets)

for _loop_back_source in ("tool_call", "force_final_answer"):
    my_graph_build.add_edge(_loop_back_source, "llm_call")

my_graph_build.set_entry_point("llm_call")
my_graph = my_graph_build.compile()

def _question_state(question):
    """Return a fresh GraphState holding the manager system prompt followed by *question* as a human message."""
    return GraphState(history=[
        {'role': 'system', 'content': SYSTEM_PROMPT_MANAGER},
        {'role': 'human', 'content': question},
    ])


# Sample questions for exercising the agent by hand.
init_state = _question_state('How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?')
init_state_2 = _question_state('Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M.? Give only the first name.')
init_state_3 = _question_state('What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?')
# The text of this question is written backwards.
init_state_4 = _question_state('.rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI')
init_state_5 = _question_state("What country had the least number of athletes at the 1928 Summer Olympics? If there's a tie for a number of athletes, return the first in alphabetical order. Give the IOC country code as your answer.")

# Manual checks (uncomment to run):
# print(manager_agent.invoke([
#     {'role': 'system', 'content': SYSTEM_PROMPT_MANAGER},
#     {'role': 'human', 'content': 'What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?'}
#     ]))
# print(my_graph.invoke(init_state))
# print(my_graph.invoke(init_state_5))