Rudraprasad committed on
Commit
4816b83
·
verified ·
1 Parent(s): 70b9c58

Update agents.py

Browse files
Files changed (1) hide show
  1. agents.py +76 -59
agents.py CHANGED
@@ -1,62 +1,79 @@
1
- # agent.py
2
-
3
  import os
4
- from typing import TypedDict, Annotated, List
5
- from langchain_community.tools import DuckDuckGoSearchResults
6
- from langchain_groq import ChatGroq
7
- from langgraph.graph import StateGraph, END, START
8
- from langgraph.graph.message import AnyMessage, add_messages
9
- from langchain_core.messages import HumanMessage, AIMessage
10
-
11
# --- Setup API Key ---
# SECURITY: never commit a secret to source control — the previously
# hardcoded key is compromised the moment it is pushed and must be revoked.
# Read the key from the environment instead (set GROQ_API_KEY before running).
groq_api_key = os.environ.get("GROQ_API_KEY")

# --- Setup Search Tool ---
search_tool = DuckDuckGoSearchResults()

# --- Setup Base LLM Model (for generation) ---
# temperature=0 keeps generations deterministic for extraction-style prompts.
base_llm = ChatGroq(
    model_name="llama3-8b-8192",
    api_key=groq_api_key,
    temperature=0,
)
 
 
23
 
24
# --- Final Content Extractor Tool (LLM based) ---
class FinalContentExtractor:
    """Distill a direct answer from raw search output using a wrapped chat model."""

    def __init__(self, model):
        # Any chat model exposing .invoke(messages) -> message-with-.content works.
        self.model = model

    def extract(self, question: str, search_result: str) -> str:
        """Ask the model for the most direct answer to *question* given *search_result*."""
        prompt = "\n\n".join([
            "You are an intelligent assistant. Given the user's question and the search result, extract the most direct and relevant answer.",
            f"Question: {question}",
            f"Search Result: {search_result}",
            "Answer:",
        ])
        reply = self.model.invoke([HumanMessage(content=prompt)])
        return reply.content.strip()
39
-
40
# --- Define State Class for the Graph ---
class AgentState(TypedDict):
    """Graph state: the running conversation history.

    The `add_messages` annotation makes LangGraph merge messages returned
    by nodes into this list rather than overwriting it.
    """

    messages: Annotated[List[AnyMessage], add_messages]
43
-
44
# --- Define Search Node ---
def search_node(state: AgentState):
    """Run a web search on the latest user message and append the results.

    Returns only the newly created AIMessage: the ``add_messages`` reducer on
    ``AgentState.messages`` merges it into the existing history, so echoing
    the entire prior message list back (as the old code did) was redundant
    and relied on the reducer's de-duplication by message id.
    """
    user_message = state["messages"][-1]
    query = user_message.content
    search_result = search_tool.invoke({"query": query})

    search_summary = f"Top Search Results:\n{search_result}"
    ai_message = AIMessage(content=search_summary)
    # add_messages appends; do not resend the prior history.
    return {"messages": [ai_message]}
53
-
54
# --- Build the Graph ---
# Minimal single-node pipeline: START -> search -> END.
builder = StateGraph(AgentState)
builder.add_node("search", search_node)
builder.add_edge(START, "search")
builder.add_edge("search", END)
graph = builder.compile()

# --- Export these for app.py ---
__all__ = ["graph", "FinalContentExtractor", "base_llm"]
 
1
+ import importlib
 
2
  import os
3
+
4
+ import requests
5
+ import yaml
6
+ import pandas as pd
7
+
8
+ from config import DEFAULT_API_URL
9
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, VisitWebpageTool, WikipediaSearchTool, Tool, OpenAIServerModel, SpeechToTextTool
10
+
11
class GetTaskFileTool(Tool):
    """Download the attachment for a task and return its absolute path on disk."""

    name = "get_task_file_tool"
    description = """This tool downloads the file content associated with the given task_id if exists. Returns absolute file path"""
    inputs = {
        "task_id": {"type": "string", "description": "Task id"},
        "file_name": {"type": "string", "description": "File name"},
    }
    output_type = "string"

    def forward(self, task_id: str, file_name: str) -> str:
        # Fetch the task's attachment; raise_for_status surfaces HTTP errors.
        resp = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=15)
        resp.raise_for_status()
        # Persist under the requested name in the current working directory.
        with open(file_name, "wb") as out:
            out.write(resp.content)
        # Absolute path lets downstream tools locate the file unambiguously.
        return os.path.abspath(file_name)
26
+
27
class LoadXlsxFileTool(Tool):
    """Load an .xlsx spreadsheet into a pandas DataFrame."""

    name = "load_xlsx_file_tool"
    description = """This tool loads xlsx file into pandas and returns it"""
    inputs = {
        "file_path": {"type": "string", "description": "File path"}
    }
    output_type = "object"

    def forward(self, file_path: str) -> object:
        # Delegate all parsing to pandas; the agent can then query the frame.
        frame = pd.read_excel(file_path)
        return frame
37
+
38
class LoadTextFileTool(Tool):
    """Read an arbitrary text file and return its contents as a string."""

    name = "load_text_file_tool"
    description = """This tool loads any text file"""
    inputs = {
        "file_path": {"type": "string", "description": "File path"}
    }
    output_type = "string"

    def forward(self, file_path: str) -> str:
        # Annotation fixed from `-> object` to `-> str`: the method returns
        # str and output_type declares "string" (matches GetTaskFileTool).
        # Assumes UTF-8 encoded input — TODO confirm for non-UTF-8 tasks.
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
49
+
50
+
51
# Load smolagents' stock CodeAgent prompt templates shipped inside the package.
prompts = yaml.safe_load(
    importlib.resources.files("smolagents.prompts").joinpath("code_agent.yaml").read_text()
)

# Prepend the answer-format contract so every run ends with a normalized
# "FINAL ANSWER: ..." line, keeping the rest of the stock system prompt intact.
_answer_contract = "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. "
prompts["system_prompt"] = _answer_contract + prompts["system_prompt"]
56
 
57
def init_agent():
    """Build and return the CodeAgent wired to Gemini via its OpenAI-compatible API.

    Reads the Gemini key from the API_KEY environment variable.
    """
    # Gemini exposed through the OpenAI-compatible endpoint.
    llm = OpenAIServerModel(
        model_id="gemini-2.0-flash",
        api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
        api_key=os.getenv("API_KEY"),
        temperature=0.7
    )
    # Search, browsing, transcription, and file-handling tools for the agent.
    toolbox = [
        DuckDuckGoSearchTool(),
        VisitWebpageTool(),
        WikipediaSearchTool(),
        GetTaskFileTool(),
        SpeechToTextTool(),
        LoadXlsxFileTool(),
        LoadTextFileTool(),
    ]
    return CodeAgent(
        tools=toolbox,
        model=llm,
        prompt_templates=prompts,
        max_steps=15,
        # pandas is needed by generated code that inspects loaded spreadsheets.
        additional_authorized_imports=["pandas"]
    )