File size: 5,231 Bytes
6800d28
 
523e34e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6800d28
7cdf11d
6800d28
523e34e
6800d28
 
 
 
13317d6
b97b365
523e34e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b97b365
 
523e34e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc287f3
523e34e
854a96f
f61d1e4
523e34e
 
 
 
6800d28
523e34e
6800d28
 
523e34e
6800d28
523e34e
 
 
 
 
 
 
 
 
 
 
 
6800d28
523e34e
6800d28
 
523e34e
6800d28
523e34e
 
6800d28
523e34e
b97b365
6800d28
523e34e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langchain.globals import set_debug
from langchain.globals import set_verbose
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt.chat_agent_executor import AgentState

from smolagents import DuckDuckGoSearchTool
from smolagents import PythonInterpreterTool
from tools import analyze_audio
from tools import analyze_excel
from tools import analyze_image
from tools import analyze_video
from tools import download_file_for_task
from tools import read_file_contents
from tools import search_arxiv
from tools import search_tavily
from tools import search_wikipedia
from tools import SmolagentToolWrapper
from tools import tavily_extract_tool
from utils import get_llm
from config import GOOGLE_API_KEY, AGENT_MODEL_NAME


# Runtime configuration resolved from environment variables.
# NOTE(review): these assignments shadow the GOOGLE_API_KEY and
# AGENT_MODEL_NAME names imported from `config` above, so the values read
# here from os.environ are the ones actually used — confirm the duplicate
# import is intentional.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
if not GOOGLE_API_KEY:
    raise ValueError("GOOGLE_API_KEY environment variable is not set.")

# Model identifier for the agent's LLM; overridable via the environment.
AGENT_MODEL_NAME = os.getenv("AGENT_MODEL_NAME", "gemini-2.0-flash-lite")

# System prompt for the ReAct agent. It enforces terse, strictly formatted
# answers (bare numbers, single words, comma-separated lists), tool-first
# verification, and the "FINAL ANSWER: " prefix that
# MultiModalAgent.__call__ strips from the model's reply.
MULTIMODAL_TASK_SOLVER_PROMPT = """
You are a specialized multimodal task-solving AI assistant capable of handling complex data analysis and information retrieval tasks.
Core Operating Guidelines:
- Employ systematic analysis: Break down problems into logical steps
- Maintain brevity: Provide answers in the most concise format possible - raw numbers, single words, or comma-delimited lists
- Format compliance:
  * Numbers: No commas, units, or currency symbols
  * Lists: Pure comma-separated values without additional text
  * Text: Bare minimum words, no sentences or explanations
- Tool utilization: 
  * For multimedia content (images, audio, video) - use dedicated analysis tools
  * For data processing (Excel, structured data) - use appropriate parsers
  * For information retrieval - leverage search tools
- Verification principle: Never guess - use available tools to verify information
- Code usage: Implement Python code for calculations and data transformations
- Answer format: Always prefix final answers with 'FINAL ANSWER: '
- Counting queries: Return only the numerical count
- Listing queries: Return only the comma-separated items
- Sorting queries: Return only the ordered list

Sample Responses:
Q: Current Bitcoin price in USD? A: 47392
Q: Sort these colors: blue, red, azure A: azure, blue, red
Q: Capital of France? A: Paris
Q: Count vowels in 'hello' A: 2
Q: Temperature scale used in USA? A: Fahrenheit
Q: List prime numbers under 10 A: 2, 3, 5, 7
Q: Most streamed artist 2023? A: Taylor Swift
"""

# Uncomment to enable LangChain's global debug / verbose logging:
#set_debug(True)
#set_verbose(True)


class MultiModalTaskState(AgentState):
    """LangGraph agent state extended with per-task metadata fields."""

    # Unique identifier of the task being solved (also used as thread id).
    task_identifier: str
    # The user's question text.
    query_text: str
    # Path/name of an input file attached to the task, if any.
    input_file_path: str


class MultiModalAgent:
    """ReAct agent for multimodal question-answering tasks.

    Wraps a LangGraph ``create_react_agent`` around an LLM and a set of
    search / file / media-analysis tools, then parses the model's
    ``FINAL ANSWER:`` marker out of the reply.
    """

    # Marker the system prompt instructs the model to emit before its answer.
    # Hoisted to one constant so the membership check and the split below
    # can never drift apart.
    _ANSWER_PREFIX = "FINAL ANSWER: "

    def __init__(self, model_name: str | None = None):
        """Build the agent graph.

        Args:
            model_name: LLM model identifier; defaults to ``AGENT_MODEL_NAME``.
        """
        if model_name is None:
            model_name = AGENT_MODEL_NAME
        llm = self._get_llm(model_name)
        tools = self._get_tools()
        self.agent = create_react_agent(
            model=llm,
            tools=tools,
            prompt=MULTIMODAL_TASK_SOLVER_PROMPT,
            # Declare the extended state so the task metadata passed in
            # __call__ is actually part of the graph state instead of being
            # dropped as unknown keys.
            state_schema=MultiModalTaskState,
            # In-memory checkpointer: one conversation thread per task id.
            checkpointer=MemorySaver(),
        )

    def _get_llm(self, model_name: str):
        """Construct the LLM client for *model_name* using the configured key."""
        return get_llm(
            llm_provider_api_key=GOOGLE_API_KEY,
            model_name=model_name,
        )

    def _get_tools(self):
        """Assemble the agent's tool set, wrapped in a ToolNode."""
        tools = [
            SmolagentToolWrapper(DuckDuckGoSearchTool()),
            SmolagentToolWrapper(PythonInterpreterTool()),
            download_file_for_task,
            read_file_contents,
            analyze_audio,
            analyze_image,
            analyze_excel,
            analyze_video,
            search_arxiv,
            search_tavily,
            search_wikipedia,
            tavily_extract_tool,
        ]
        return ToolNode(tools)

    async def __call__(
        self, task_identifier: str, query_text: str, input_file_path: str | None = None
    ) -> str:
        """Run the agent on one task and return the extracted final answer.

        Args:
            task_identifier: Unique id for the task; used as the checkpoint
                thread id so repeated calls share conversation state.
            query_text: The question to answer.
            input_file_path: Optional file attached to the task.

        Returns:
            The text after ``FINAL ANSWER:`` in the model's last message,
            or the whole last message if the marker is absent.
        """
        execution_config = RunnableConfig(
            recursion_limit=64,
            configurable={"thread_id": task_identifier},
        )

        if not input_file_path:
            input_file_path = "None - no file present"

        user_input = HumanMessage(
            content=
            [
                {
                    "type": "text",
                    "text": f"Task Id: {task_identifier}, Question: {query_text}, Filename: {input_file_path}. If a filename is present (and is not 'None'), download the file for the task that's referenced in the question. If there isn't a filename present, please use tools where applicable."
                }
            ]
        )

        response = await self.agent.ainvoke(
            {
                "messages": [user_input],
                # Key names must match the fields declared on
                # MultiModalTaskState (the original passed "question",
                # "task_id" and "file_name", which matched no declared field).
                "query_text": query_text,
                "task_identifier": task_identifier,
                "input_file_path": input_file_path,
            },
            execution_config,
        )

        final_response = response["messages"][-1].content
        if self._ANSWER_PREFIX in final_response:
            return final_response.split(self._ANSWER_PREFIX, 1)[1].strip()
        return final_response