File size: 4,785 Bytes
da83164
 
 
dd75c3c
da83164
a6dbfdf
 
da83164
a6dbfdf
dd75c3c
a6dbfdf
 
da83164
a6dbfdf
 
 
da83164
dd75c3c
da83164
a6dbfdf
dd75c3c
 
 
 
a6dbfdf
dd75c3c
 
 
da83164
 
dd75c3c
a6dbfdf
dd75c3c
 
 
 
 
 
 
 
 
a6dbfdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd75c3c
 
 
 
 
a6dbfdf
 
 
dd75c3c
 
 
 
 
da83164
dd75c3c
 
 
da83164
 
dd75c3c
 
 
da83164
dd75c3c
 
da83164
a6dbfdf
dd75c3c
 
 
 
a6dbfdf
dd75c3c
da83164
a6dbfdf
 
 
da83164
a6dbfdf
 
 
 
 
 
 
dd75c3c
 
a6dbfdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da83164
a6dbfdf
da83164
dd75c3c
 
 
a6dbfdf
dd75c3c
da83164
 
dd75c3c
 
 
da83164
dd75c3c
a6dbfdf
 
dd75c3c
a6dbfdf
 
 
 
 
 
 
 
 
 
dd75c3c
a6dbfdf
dd75c3c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
from transformers import pipeline
from typing import Annotated, TypedDict, Optional, Any

from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langgraph.graph.message import add_messages
# from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_core.tools import Tool
from math_tools import add, subtract, multiply, modulus, divide
from search_tools import wiki_search, web_search, arvix_search, question_search, vector_store
# from init_models import image_to_text_model

# API credentials come from the environment; each is None when unset.
hf_token = os.getenv("HF_TOKEN")

google_api_key = os.getenv("GOOGLE_API_KEY")

# -----------------------------
# CODE LLM TOOL
# -----------------------------
def run_code_llm(prompt: str) -> str:
    """Answer a coding/programming question via the Qwen coder model.

    Args:
        prompt: The user's coding question, passed verbatim as the user turn.

    Returns:
        The model's reply text.
    """
    # `prompt` (was `input`) avoids shadowing the builtin; Tool invokes the
    # function with a single positional string, so callers are unaffected.
    # NOTE(review): the endpoint is rebuilt on every call — fine for low
    # volume, but could be hoisted to module level if this tool gets hot.
    coder = HuggingFaceEndpoint(
        repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
        huggingfacehub_api_token=hf_token
    )
    chat = ChatHuggingFace(llm=coder, verbose=True)
    result = chat.invoke([{"role": "user", "content": prompt}])
    return result.content


# Wrap the coder model as a LangChain tool so the agent can call it.
code_llm_tool = Tool(
    func=run_code_llm,
    name="code_llm",
    description="Use this tool to answer coding or programming questions.",
)

## Classify images

## Classify videos

## Classify other items


# def run_image_to_text_llm(prompt: str) -> str:
#     """Call the image-to-text model directly as a tool."""
#     raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

#     # conditional image captioning
#     text = "a photography of"
#     inputs = processor(raw_image, text, return_tensors="pt").to("cuda")

#     out = model.generate(**inputs)


# Toolbox handed both to the chat model (via bind_tools) and to the
# ToolNode: arithmetic helpers, the coder-LLM tool, and search tools.
tools = [
    add,
    code_llm_tool,
    divide,
    subtract,
    multiply,
    modulus,
    arvix_search,
    web_search,
    question_search,
    wiki_search
]

# -----------------------------
# AGENT WRAPPER
# -----------------------------
class CurrentAgent:
    """Holds the chat model with the shared toolbox bound to it."""

    def __init__(self):
        # Bind the tools so the model can emit structured tool calls.
        self.current_chat = ChatOpenAI(model="gpt-5-nano").bind_tools(tools)


# -----------------------------
# STATE
# -----------------------------
class AgentState(TypedDict):
    """Shared graph state threaded through every node."""

    # Lazily-created model wrapper; None until a node instantiates it.
    ai_agent: Optional[CurrentAgent]
    # Classification label; only passed through by the nodes in this file —
    # presumably written elsewhere in the pipeline (confirm against callers).
    classification: str
    # Conversation history; add_messages merges/appends returned messages.
    messages: Annotated[list[AnyMessage], add_messages]


# -----------------------------
# GENERAL ASSISTANT NODE
# -----------------------------
def general_assistant(state: AgentState) -> AgentState:
    """Assistant node: run the tool-bound chat model over the history.

    Lazily creates the agent wrapper on first run, then invokes the model
    on the full message list. Returns a state update; with add_messages,
    the returned message is appended to the history.
    """
    # .get() tolerates an initial state that carries only "messages"
    # (the bare state["ai_agent"] lookup raised KeyError in that case).
    if state.get("ai_agent") is None:
        state["ai_agent"] = CurrentAgent()

    response = state["ai_agent"].current_chat.invoke(state["messages"])

    return {
        "ai_agent": state["ai_agent"],
        # Default to "" so a state started without the key still works.
        "classification": state.get("classification", ""),
        "messages": [response]  # with add_messages, this will be appended
    }

# Load the agent's system prompt text from disk (relative to the CWD).
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message built from the prompt.
# NOTE(review): sys_msg is not referenced by any node in this file —
# confirm it is consumed elsewhere before relying on it.
sys_msg = SystemMessage(content=system_prompt)



def retriever(state: AgentState):
    """Retriever node: append a similar Q/A example from the vector store.

    Uses the latest human message as the query, runs a k=1 similarity
    search, and appends either the best match or a "not found" notice as
    an extra reference message for the assistant node.
    """
    # .get() tolerates an initial state that carries only "messages"
    # (the bare state["ai_agent"] lookup raised KeyError in that case).
    if state.get("ai_agent") is None:
        state["ai_agent"] = CurrentAgent()

    # Find the latest human message to use as the search query.
    user_messages = [m for m in state["messages"] if isinstance(m, HumanMessage)]
    if not user_messages:
        # Nothing to search with; pass the history through unchanged.
        return {"messages": state["messages"]}

    query = user_messages[-1].content

    # Single nearest neighbour from the vector store.
    similar_docs = vector_store.similarity_search(query, k=1)

    if similar_docs:
        context = similar_docs[0].page_content
        response = (
            "Here is a similar question and answer for reference:\n\n"
            f"{context}"
        )
    else:
        response = "No similar questions were found in the vector database."

    example_msg = HumanMessage(content=response)

    return {
        "ai_agent": state["ai_agent"],
        # Default to "" so a state started without the key still works.
        "classification": state.get("classification", ""),
        "messages": state["messages"] + [example_msg]
    }


# -----------------------------
# WORKFLOW
# -----------------------------
def build_workflow() -> Any:
    """Build and compile the agent graph.

    Flow: START -> retriever -> general_assistant; tools_condition then
    routes to the tool node (which loops back to the assistant) or to END.
    """
    graph = StateGraph(AgentState)

    graph.add_node("retriever", retriever)
    graph.add_node("general_assistant", general_assistant)
    graph.add_node("tools", ToolNode(tools))

    graph.add_edge(START, "retriever")
    graph.add_edge("retriever", "general_assistant")
    # tools_condition already routes to END when the model makes no tool
    # call, so no unconditional general_assistant -> END edge is needed.
    # The original extra add_edge("general_assistant", END) made the
    # assistant fan out to END on every turn, even with a tool call pending.
    graph.add_conditional_edges(
        "general_assistant",
        tools_condition,
    )
    graph.add_edge("tools", "general_assistant")

    return graph.compile()