|
|
import mimetypes |
|
|
import base64 |
|
|
import yaml |
|
|
from typing import TypedDict, Annotated |
|
|
from dotenv import load_dotenv |
|
|
from langgraph.checkpoint.memory import InMemorySaver |
|
|
from langgraph.graph.message import add_messages |
|
|
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage |
|
|
from langgraph.prebuilt import ToolNode |
|
|
from langgraph.graph import START, StateGraph |
|
|
from langgraph.prebuilt import tools_condition |
|
|
from langchain_core.messages.utils import ( |
|
|
trim_messages, |
|
|
count_tokens_approximately |
|
|
) |
|
|
|
|
|
|
|
|
from tools import webpage_reader_tool, python_repl_tool, transcribe_youtube_video_tool, wikipedia_query_tool, web_search_tool, read_excel_csv, arxiv_query_tool |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
class FinalAgent:
    """Tool-calling ReAct-style agent built on LangGraph.

    Wraps a chat model from one of three providers (Google Gemini,
    HuggingFace endpoint, or local Ollama), binds it to the project's
    tool set, and compiles an assistant<->tools graph.  Optionally keeps
    per-thread conversation state in an ``InMemorySaver`` checkpointer.
    """

    def __init__(self, model_type="GOOGLE", system_prompt_path="system_prompt.yaml", use_memory=False):
        """
        Args:
            model_type: One of "GOOGLE", "HUGGINGFACE" or "OLLAMA".
            system_prompt_path: Path to a YAML file containing a
                'system_prompt' key used as the system message.
            use_memory: When True, compile the graph with an
                ``InMemorySaver`` checkpointer so conversations persist
                per thread_id.

        Raises:
            ValueError: If ``model_type`` is not a supported provider.
        """
        with open(system_prompt_path, 'r', encoding='utf-8') as stream:
            prompt_templates = yaml.safe_load(stream)

        self.model_type = model_type

        # Provider-specific imports are deferred so only the selected
        # provider's package needs to be installed.
        if model_type == "HUGGINGFACE":
            from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace

            llm = HuggingFaceEndpoint(
                repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"
            )
            chat = ChatHuggingFace(llm=llm, verbose=True)
        elif model_type == "OLLAMA":
            from langchain_ollama import ChatOllama

            chat = ChatOllama(model="qwen3:8b")
        elif model_type == "GOOGLE":
            from langchain_google_genai import ChatGoogleGenerativeAI
            from langchain_core.rate_limiters import InMemoryRateLimiter

            # Client-side throttle: ~6 requests per minute (free-tier
            # Gemini quota), with a small burst allowance.
            rate_limiter = InMemoryRateLimiter(
                requests_per_second=6 / 60,
                check_every_n_seconds=0.1,
                max_bucket_size=10,
            )
            chat = ChatGoogleGenerativeAI(model="gemini-2.5-flash", rate_limiter=rate_limiter)
        else:
            raise ValueError(f'Model provider can be only one between GOOGLE, OLLAMA or HUGGINGFACE, received {model_type}')

        tools = [webpage_reader_tool,
                 transcribe_youtube_video_tool,
                 web_search_tool,
                 wikipedia_query_tool,
                 arxiv_query_tool,
                 read_excel_csv,
                 python_repl_tool,]
        chat_with_tools = chat.bind_tools(tools)

        class AgentState(TypedDict):
            # Conversation history; `add_messages` appends rather than replaces.
            messages: Annotated[list[AnyMessage], add_messages]

        def assistant(state: AgentState):
            """LLM node: trim history to the context window, then invoke the model."""
            messages = trim_messages(
                state["messages"],
                strategy="last",
                token_counter=count_tokens_approximately,
                # Gemini 2.5 Flash has a ~1M-token context; other models
                # here are assumed to fit ~128K (small safety margin).
                max_tokens=1_000_000 if self.model_type == "GOOGLE" else 126000,
                start_on="human",
                end_on=("human", "tool"),
            )
            return {
                "messages": [chat_with_tools.invoke([SystemMessage(content=prompt_templates['system_prompt'])] + messages)],
            }

        builder = StateGraph(AgentState)

        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))

        # assistant -> (tools | END) depending on whether the last
        # message contains tool calls; tools always loop back.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")

        if use_memory:
            self.agent = builder.compile(checkpointer=InMemorySaver())
        else:
            self.agent = builder.compile()
        print("FinalAgent initialized.")

    def clear_memory(self, thread_id: str) -> None:
        """Clear the checkpointer's memory for a given thread_id.

        No-op when the agent was compiled without a checkpointer.
        Best-effort: errors are logged, never raised.
        """
        memory = self.agent.checkpointer
        if memory is None:
            return
        try:
            # InMemorySaver keeps checkpoints in `.storage` (keyed by
            # thread_id) and pending writes in `.writes` (keyed by tuples
            # whose first element is the thread_id).
            if hasattr(memory, 'storage') and hasattr(memory, 'writes'):
                memory.storage.pop(thread_id, None)
                keys_to_remove = [key for key in memory.writes.keys() if key[0] == thread_id]
                for key in keys_to_remove:
                    memory.writes.pop(key, None)
                print(f"Memory cleared for thread_id: {thread_id}")
            return
        except Exception as e:
            # Best-effort cleanup; surface the problem without crashing.
            print(f"Error clearing InMemorySaver storage for thread_id {thread_id}: {e}")

    def __call__(self, question: str, attached_file: dict, recursion_limit=9) -> str:
        """Run the agent on a question, optionally with an attached file.

        Args:
            question: The user's question.
            attached_file: Dict with 'name' (str), 'content' (bytes or
                None) and optionally 'path' keys describing an attachment.
            recursion_limit: Max graph steps; values <= 0 use the
                LangGraph default.

        Returns:
            The final assistant message content as a string.
        """
        print(f"Agent received question (first 100 chars): {question[:100]}...")

        if attached_file['name'] != "" and attached_file['content'] is not None:
            mime_type, _ = mimetypes.guess_type(attached_file['name'])
            # guess_type returns None for unknown extensions; default to a
            # generic binary type so we hit the fallback branch instead of
            # crashing on `None.startswith`.
            mime_type = mime_type or "application/octet-stream"
            if mime_type.startswith("image/") or mime_type.startswith("audio/") or mime_type.startswith("video/"):
                encoded_file = base64.b64encode(attached_file['content']).decode('utf-8')
                if self.model_type == "GOOGLE":
                    # Gemini accepts multimodal content blocks directly.
                    question = [{"type": "text", "text": question},
                                {"type": "image" if mime_type.startswith("image/") else "media",
                                 "source_type": "base64",
                                 "data": encoded_file,
                                 "mime_type": mime_type,},
                                ]
                else:
                    # Text-only models get the payload inlined as base64.
                    question = f"{question}\n\nAttached file extension:{attached_file['name'].split('.')[-1]} - Attached file base64 encoded: \n{encoded_file}"
            elif mime_type.startswith("text/"):
                question = f"{question}\n\nAttached file extension:{attached_file['name'].split('.')[-1]} - Attached file content: \n{attached_file['content'].decode('utf-8')}"
            else:
                encoded_file = base64.b64encode(attached_file['content']).decode('utf-8')
                print(f"Unsupported file {attached_file['name']} type: {mime_type}. Only images, audio, video, and text files are supported.")
                # 'path' may be absent from the dict; degrade gracefully.
                question = f"{question}\n\nAttached file extension: {attached_file['name'].split('.')[-1]}. File path: {attached_file.get('path', '')} - Attached file base64 encoded:\n{encoded_file}"

        if recursion_limit > 0:
            agent_reply = self.agent.invoke({"messages": [HumanMessage(content=question)]}, {"recursion_limit": recursion_limit})
        else:
            agent_reply = self.agent.invoke({"messages": [HumanMessage(content=question)]})
        return str(agent_reply['messages'][-1].content)