File size: 7,314 Bytes
17e605d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fea783c
 
17e605d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import mimetypes
import base64
import yaml
from typing import TypedDict, Annotated
from dotenv import load_dotenv
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition
from langchain_core.messages.utils import (
    trim_messages,
    count_tokens_approximately
)

# Import our custom tools from their modules
from tools import webpage_reader_tool, python_repl_tool, transcribe_youtube_video_tool, wikipedia_query_tool, web_search_tool, read_excel_csv, arxiv_query_tool

# Load variables from a local .env file into the process environment at import
# time, i.e. before any FinalAgent (and its model client) is constructed.
# NOTE(review): presumably this supplies the provider API keys — confirm against deployment.
load_dotenv()

class FinalAgent:
    
    def __init__(self, model_type="GOOGLE", system_prompt_path="system_prompt.yaml", use_memory=False):
        """Build the chat model, bind the tools to it and compile the agent graph.

        Args:
            model_type: Model provider, one of "GOOGLE", "HUGGINGFACE" or "OLLAMA".
            system_prompt_path: Path to a YAML file; its ``system_prompt`` entry
                is prepended as the system message on every model call.
            use_memory: If True, compile the graph with an ``InMemorySaver``
                checkpointer; otherwise compile it without a checkpointer.

        Raises:
            ValueError: If ``model_type`` is not one of the supported providers.
            KeyError: If the loaded YAML mapping has no ``system_prompt`` entry.
        """
        # Read the prompt file as UTF-8 explicitly so the result does not
        # depend on the platform's default encoding.
        with open(system_prompt_path, 'r', encoding='utf-8') as stream:
            prompt_templates = yaml.safe_load(stream)
        # Look the prompt up once so a malformed prompt file fails here
        # instead of on the first question.
        system_prompt = prompt_templates['system_prompt']

        self.model_type = model_type

        if model_type == "HUGGINGFACE":
            from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
            # Initialize the Hugging Face endpoint and wrap it in a chat interface.
            llm = HuggingFaceEndpoint(
                repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"
            )
            chat = ChatHuggingFace(llm=llm, verbose=True)
        elif model_type == "OLLAMA":
            from langchain_ollama import ChatOllama
            chat = ChatOllama(model="qwen3:8b")
        elif model_type == "GOOGLE":
            from langchain_google_genai import ChatGoogleGenerativeAI
            from langchain_core.rate_limiters import InMemoryRateLimiter
            rate_limiter = InMemoryRateLimiter(
                # Max allowed rate per free API: 10 requests per minute, but we
                # use 6 to avoid hitting the limit on subsequent answers.
                requests_per_second=6 / 60,
                # Wake up every 100 ms to check whether allowed to make a request.
                check_every_n_seconds=0.1,
                max_bucket_size=10,  # Controls the maximum burst size.
            )
            chat = ChatGoogleGenerativeAI(model="gemini-2.5-flash", rate_limiter=rate_limiter)
        else:
            raise ValueError(f'Model provider can be only one between GOOGLE, OLLAMA or HUGGINGFACE, received {model_type}')

        tools = [
            webpage_reader_tool,
            transcribe_youtube_video_tool,
            web_search_tool,
            wikipedia_query_tool,
            arxiv_query_tool,
            read_excel_csv,
            python_repl_tool,
        ]
        chat_with_tools = chat.bind_tools(tools)

        # Token budget handed to trim_messages; kept as an int (the original
        # passed the float 1e6 for the GOOGLE provider).
        max_input_tokens = 1_000_000 if model_type == "GOOGLE" else 126_000

        class AgentState(TypedDict):
            # Conversation history; add_messages merges new messages into the list.
            messages: Annotated[list[AnyMessage], add_messages]

        def assistant(state: AgentState):
            """Trim the history to the token budget and ask the model for the next message."""
            messages = trim_messages(
                state["messages"],
                strategy="last",
                token_counter=count_tokens_approximately,
                max_tokens=max_input_tokens,
                start_on="human",
                end_on=("human", "tool"),
            )
            reply = chat_with_tools.invoke([SystemMessage(content=system_prompt)] + messages)
            return {"messages": [reply]}

        builder = StateGraph(AgentState)

        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))

        # START -> assistant; tools_condition decides whether the assistant's
        # reply goes to the tools node; tool results always return to the assistant.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")

        # Passing checkpointer=None is the same as compiling without one.
        checkpointer = InMemorySaver() if use_memory else None
        self.agent = builder.compile(checkpointer=checkpointer)
        print("FinalAgent initialized.")
    
    def clear_memory(self, thread_id: str) -> None:
        """ Clear the memory for a given thread_id. """
        memory = self.agent.checkpointer
        if memory is None:
            return
        try:
            # If it's an InMemorySaver (which MemorySaver is an alias for),
            # we can directly clear the storage and writes
            if hasattr(memory, 'storage') and hasattr(memory, 'writes'):
                # Clear all checkpoints for this thread_id (all namespaces)
                memory.storage.pop(thread_id, None)

                # Clear all writes for this thread_id (for all namespaces)
                keys_to_remove = [key for key in memory.writes.keys() if key[0] == thread_id]
                for key in keys_to_remove:
                    memory.writes.pop(key, None)

                print(f"Memory cleared for thread_id: {thread_id}")
                return

        except Exception as e:
            print(f"Error clearing InMemorySaver storage for thread_id {thread_id}: {e}")
    
    def __call__(self, question: str, attached_file: dict, recursion_limit=9, thread_id=None) -> str:
        """Answer ``question``, folding an optional attached file into the prompt.

        Args:
            question: The user's question.
            attached_file: Mapping with ``name`` (file name, ``""`` when nothing
                is attached), ``content`` (raw bytes or ``None``) and optionally
                ``path``. A falsy value is treated as "no attachment".
            recursion_limit: Passed to the graph as ``recursion_limit`` when
                greater than 0; otherwise no limit is set explicitly.
            thread_id: Optional conversation id, passed to the graph as
                ``configurable.thread_id``. LangGraph needs it when the graph
                was compiled with a checkpointer (``use_memory=True``); it is
                the same id ``clear_memory`` takes.

        Returns:
            ``str()`` of the content of the last message returned by the graph.
        """
        print(f"Agent received question (first 100 chars): {question[:100]}...")

        if attached_file and attached_file.get('name') and attached_file.get('content') is not None:
            file_name = attached_file['name']
            raw_content = attached_file['content']
            extension = file_name.split('.')[-1]
            # guess_type() yields None for unknown extensions; use "" for the
            # prefix checks so such files land in the generic fallback branch.
            mime_type, _ = mimetypes.guess_type(file_name)
            kind = mime_type or ""
            if kind.startswith(("image/", "audio/", "video/")):
                # Media file - convert to base64
                encoded_file = base64.b64encode(raw_content).decode('utf-8')
                if self.model_type == "GOOGLE":
                    # Send the media as a structured content part next to the text.
                    question = [
                        {"type": "text", "text": question},
                        {
                            "type": "image" if kind.startswith("image/") else "media",
                            "source_type": "base64",
                            "data": encoded_file,
                            "mime_type": mime_type,
                        },
                    ]
                else:
                    question = f"{question}\n\nAttached file extension:{extension} - Attached file base64 encoded: \n{encoded_file}"
            elif kind.startswith("text/"):
                # Text-based file (like .py, .txt); undecodable bytes are
                # replaced instead of aborting the whole question.
                file_text = raw_content.decode('utf-8', errors='replace')
                question = f"{question}\n\nAttached file extension:{extension} - Attached file content: \n{file_text}"
            else:
                encoded_file = base64.b64encode(raw_content).decode('utf-8')
                print(f"Unsupported file {file_name} type: {mime_type}. Only images, audio, video, and text files are supported.")
                # 'path' is optional: fall back to an empty path instead of raising KeyError.
                file_path = attached_file.get('path', '')
                question = f"{question}\n\nAttached file extension: {extension}. File path: {file_path} - Attached file base64 encoded:\n{encoded_file}"

        state = {"messages": [HumanMessage(content=question)]}
        config = {}
        if recursion_limit > 0:
            config["recursion_limit"] = recursion_limit
        if thread_id is not None:
            config["configurable"] = {"thread_id": thread_id}

        # Pass a config only when there is something in it.
        if config:
            agent_reply = self.agent.invoke(state, config)
        else:
            agent_reply = self.agent.invoke(state)
        return str(agent_reply['messages'][-1].content)