File size: 5,016 Bytes
e21909c
 
 
b0809d3
e21909c
 
 
 
 
 
 
4870d9e
 
 
 
e21909c
 
24c9d4e
 
e21909c
 
 
 
 
 
 
8804351
e21909c
 
 
 
 
 
89ce62d
 
 
 
 
 
 
 
623ea05
e21909c
 
 
 
 
 
4870d9e
 
 
 
 
e21909c
 
4870d9e
e21909c
4870d9e
 
 
 
 
e21909c
4870d9e
e21909c
 
 
 
 
 
 
 
 
4870d9e
 
 
 
e21909c
 
4870d9e
 
 
 
 
 
 
 
e21909c
 
 
 
 
 
 
 
 
 
 
 
4870d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1c760a
4870d9e
 
 
 
e21909c
 
 
 
 
4870d9e
 
 
 
 
e21909c
4870d9e
 
 
 
f1c760a
e21909c
 
4870d9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os

from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import StateGraph, START, MessagesState, END
from langchain.agents import create_agent
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_ollama import ChatOllama
from langchain.agents.middleware.types import AgentState
from langchain.messages import HumanMessage, AIMessage, SystemMessage


from prompts import system_prompt, qa_system_prompt
from my_tools import wiki_search, arxiv_search, web_search, visit_webpage, translate_to_english

# Hugging Face Inference API token, read at import time.
# NOTE(review): `hf_token` is never referenced elsewhere in this file —
# presumably HuggingFaceEndpoint reads HF_TOKEN from the environment on its
# own; confirm whether this binding is still needed.
hf_token = os.getenv("HF_TOKEN")

# Graph state: the standard LangGraph message list plus the original user
# question, kept verbatim so downstream nodes (see the QA node) can restate it.
class GraphMessagesState(MessagesState):
    # The raw question as passed to the agent, independent of message history.
    question: str


# --- Basic Agent Definition ---
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
class BasicAgent:
    """LangGraph agent: a tool-calling assistant model whose conversation is
    then condensed by a second, lightweight QA model into the final answer.

    Graph shape: START -> assistant <-> tools; when the assistant stops
    requesting tools, control passes to the QA node, which may itself route
    back to tools before ending.
    """

    def __init__(self):
        """Build both Hugging Face chat models and compile the agent graph."""
        # Main reasoning / tool-calling model.
        assistant_endpoint = HuggingFaceEndpoint(
            repo_id="Qwen/Qwen3-8B",
            task="text-generation",
            max_new_tokens=512,
            do_sample=False,
            repetition_penalty=1.03,
        )
        llm = ChatHuggingFace(llm=assistant_endpoint, verbose=True)

        # Small model used only for the final question-answering pass.
        qa_endpoint = HuggingFaceEndpoint(
            repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            task="text-generation",
            max_new_tokens=512,
            do_sample=False,
            repetition_penalty=1.03,
        )
        self.llm_qa = ChatHuggingFace(llm=qa_endpoint, verbose=True)

        tools = [
            wiki_search,
            arxiv_search,
            web_search,
            visit_webpage,
            translate_to_english,
        ]

        builder = StateGraph(GraphMessagesState)

        # Prebuilt ReAct-style agent node wrapping the main model + tools.
        assistant_node = create_agent(
            llm,
            tools,
            system_prompt=system_prompt,
        )
        builder.add_node("assistant", assistant_node)
        builder.add_node("assistant_qa", self.call_qa)
        builder.add_node("tools", ToolNode(tools))

        # Define edges: these determine how the control flow moves.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            # If the latest message requires a tool, route to tools;
            # otherwise hand the transcript to the QA node.
            tools_condition,
            {
                "tools": "tools",
                END: "assistant_qa",
            },
        )
        builder.add_edge("tools", "assistant")
        builder.add_conditional_edges(
            "assistant_qa",
            tools_condition,
            {
                "tools": "tools",
                END: END,
            },
        )
        self.agent = builder.compile()

        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        """Answer *question* by running the compiled graph end to end."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        fixed_answer = self.generate_answer(question)

        print(f"Agent returning fixed answer: {fixed_answer}")
        return fixed_answer

    def call_qa(self, graph_state: GraphMessagesState) -> dict:
        """Graph node: ask the QA model for a final answer.

        Rebuilds the prompt as: QA system prompt, then the prior conversation
        minus its original system message, then the question restated as the
        last human turn.

        Returns the QA model's response serialized via ``model_dump()``.
        """
        parsed_messages = [
            SystemMessage(content=qa_system_prompt)
        ]
        # Skip index 0: generate_answer() placed the main system prompt there,
        # and the QA pass substitutes its own system prompt instead.
        parsed_messages.extend(graph_state["messages"][1:])
        parsed_messages.append(HumanMessage(content=f"Question: {graph_state['question']}"))
        print(f"\n\n\n parsed_messages => {parsed_messages}")

        response = self.llm_qa.invoke(parsed_messages)
        print(f"LLAMA 2 -> QA Agent raw response: {response}")
        # NOTE(review): LangGraph nodes conventionally return a partial state
        # update such as {"messages": [response]}; model_dump() instead returns
        # the message's own fields as a dict. Confirm this merges into
        # GraphMessagesState as intended.
        return response.model_dump()

    def generate_answer(self, question: str) -> str:
        """Invoke the compiled graph and return the final message's content."""
        response = self.agent.invoke(
            {
                # The system prompt is also baked into create_agent(); it is
                # repeated here deliberately so call_qa()'s [1:] slice drops
                # it rather than the human question.
                "messages": [
                    {
                        "role": "system",
                        "content": system_prompt,
                    },
                    {
                        "role": "human",
                        "content": question,
                    },
                ],
                "question": question,
            },
        )
        print(f"Agent raw response: {response}")
        return response["messages"][-1].content