|
|
import os |
|
|
|
|
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
from langgraph.graph import StateGraph, START, MessagesState, END |
|
|
from langchain.agents import create_agent |
|
|
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace |
|
|
from langchain_community.tools import DuckDuckGoSearchRun |
|
|
from langchain_ollama import ChatOllama |
|
|
from langchain.agents.middleware.types import AgentState |
|
|
from langchain.messages import HumanMessage, AIMessage, SystemMessage |
|
|
|
|
|
|
|
|
from prompts import system_prompt, qa_system_prompt |
|
|
from my_tools import wiki_search, arxiv_search, web_search, visit_webpage, translate_to_english |
|
|
|
|
|
# Hugging Face API token; HuggingFaceEndpoint picks it up from the environment,
# so this module-level binding mainly documents the requirement.
hf_token = os.environ.get("HF_TOKEN")
|
|
|
|
|
class GraphMessagesState(MessagesState):
    """Graph state schema: the message history from ``MessagesState`` plus
    the original user question, carried alongside the transcript so the
    QA node can restate it verbatim."""

    # The raw question as received from the caller; set once in
    # BasicAgent.generate_answer and read back in BasicAgent.call_qa.
    question: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasicAgent:
    """Two-stage LangGraph agent.

    Stage 1 ("assistant"): a Qwen3-8B tool-calling agent that may loop
    through the tool node (wiki/arxiv/web search, page visits, translation)
    until it stops requesting tools.
    Stage 2 ("assistant_qa"): a small TinyLlama model that reads the full
    transcript and produces the final answer to the original question.
    """

    def __init__(self):
        # Primary tool-calling model (Qwen3-8B via the HF inference endpoint).
        # Authenticates via the HF_TOKEN environment variable.
        qwen_endpoint = HuggingFaceEndpoint(
            repo_id="Qwen/Qwen3-8B",
            task="text-generation",
            max_new_tokens=512,
            do_sample=False,
            repetition_penalty=1.03,
        )
        assistant_llm = ChatHuggingFace(llm=qwen_endpoint, verbose=True)

        # Smaller model used only for the final question-answering pass.
        qa_endpoint = HuggingFaceEndpoint(
            repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            task="text-generation",
            max_new_tokens=512,
            do_sample=False,
            repetition_penalty=1.03,
        )
        self.llm_qa = ChatHuggingFace(llm=qa_endpoint, verbose=True)

        tools = [
            wiki_search,
            arxiv_search,
            web_search,
            visit_webpage,
            translate_to_english,
        ]

        builder = StateGraph(GraphMessagesState)

        # The prebuilt agent handles its own tool-call loop internally as well;
        # the explicit tools node below covers tool calls surfacing in state.
        assistant = create_agent(
            assistant_llm,
            tools,
            system_prompt=system_prompt,
        )
        builder.add_node("assistant", assistant)
        builder.add_node("assistant_qa", self.call_qa)
        builder.add_node("tools", ToolNode(tools))

        builder.add_edge(START, "assistant")
        # If the assistant's last message requests tools, execute them and
        # return to the assistant; otherwise hand off to the QA pass.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
            {
                "tools": "tools",
                END: "assistant_qa",
            },
        )
        builder.add_edge("tools", "assistant")
        # The QA model has no tools bound, so tools_condition normally routes
        # straight to END; the "tools" mapping is kept as a safety net.
        builder.add_conditional_edges(
            "assistant_qa",
            tools_condition,
            {
                "tools": "tools",
                END: END,
            },
        )
        self.agent = builder.compile()

        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        """Run the full graph on *question* and return the final answer text."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        fixed_answer = self.generate_answer(question)

        print(f"Agent returning fixed answer: {fixed_answer}")
        return fixed_answer

    def call_qa(self, graph_state: GraphMessagesState) -> dict:
        """Graph node: answer the original question from the transcript.

        Builds a prompt of [QA system prompt] + [transcript minus the
        original system message] + [the question restated], invokes the QA
        model, and returns a LangGraph state update.

        Returns:
            A state-update dict appending the QA model's reply to "messages".
        """
        parsed_messages = [
            SystemMessage(content=qa_system_prompt)
        ]
        # Skip messages[0]: it is the assistant's system prompt (injected in
        # generate_answer), which the QA system prompt replaces here.
        parsed_messages.extend(graph_state["messages"][1:])
        parsed_messages.append(HumanMessage(content=f"Question: {graph_state['question']}"))
        print(f"\n\n\n parsed_messages => {parsed_messages}")

        response = self.llm_qa.invoke(
            parsed_messages,
        )
        print(f"LLAMA 2 -> QA Agent raw response: {response}")
        # FIX: a graph node must return an update keyed by state channels.
        # The previous `response.model_dump()` returned the message's own
        # fields ("content", "type", ...), none of which are state keys, so
        # the update was rejected at runtime. Append the reply instead.
        return {"messages": [response]}

    def generate_answer(self, question: str) -> str:
        """Invoke the compiled graph and return the last message's content.

        The system prompt is included in the input messages deliberately:
        call_qa relies on messages[0] being that system message when it
        slices the transcript with [1:].
        """
        response = self.agent.invoke(
            {
                "messages": [
                    {
                        "role": "system",
                        "content": system_prompt,
                    },
                    {
                        "role": "human",
                        "content": question,
                    },
                ],
                "question": question,
            },
        )
        print(f"Agent raw response: {response}")
        return response["messages"][-1].content
|
|
|