ktluege committed on
Commit
9feda56
·
verified ·
1 Parent(s): 3269d51

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +75 -14
agent.py CHANGED
@@ -39,6 +39,9 @@ interpreter_instance = CodeInterpreter()
39
 
40
  @tool
41
  def wiki_search(query: str) -> str:
 
 
 
42
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
43
  return {"wiki_results": "\n\n---\n\n".join(
44
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n{d.page_content}\n</Document>'
@@ -47,6 +50,9 @@ def wiki_search(query: str) -> str:
47
 
48
  @tool
49
  def web_search(query: str) -> str:
 
 
 
50
  docs = TavilySearchResults(max_results=3).invoke(query=query)
51
  return {"web_results": "\n\n---\n\n".join(
52
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n{d.page_content}\n</Document>'
@@ -55,6 +61,9 @@ def web_search(query: str) -> str:
55
 
56
  @tool
57
  def arxiv_search(query: str) -> str:
 
 
 
58
  docs = ArxivLoader(query=query, load_max_docs=3).load()
59
  return {"arxiv_results": "\n\n---\n\n".join(
60
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n{d.page_content[:1000]}\n</Document>'
@@ -63,31 +72,61 @@ def arxiv_search(query: str) -> str:
63
 
64
  @tool
65
  def execute_code_multilang(code: str, language: str = "python") -> str:
 
 
 
66
  return interpreter_instance.execute_code(code, language=language)
67
 
68
  # example numeric tools
69
  @tool
70
  def multiply(a: float, b: float) -> float:
 
 
 
71
  return a * b
 
72
  @tool
73
  def add(a: float, b: float) -> float:
 
 
 
74
  return a + b
 
75
  @tool
76
  def subtract(a: float, b: float) -> float:
 
 
 
77
  return a - b
 
78
  @tool
79
  def divide(a: float, b: float) -> float:
 
 
 
80
  if b == 0:
81
  raise ValueError("Cannot divide by zero.")
82
  return a / b
 
83
  @tool
84
  def modulus(a: int, b: int) -> int:
 
 
 
85
  return a % b
 
86
  @tool
87
  def power(a: float, b: float) -> float:
 
 
 
88
  return a ** b
 
89
  @tool
90
  def square_root(a: float) -> float | complex:
 
 
 
91
  if a >= 0:
92
  return a ** 0.5
93
  return cmath.sqrt(a)
@@ -95,6 +134,9 @@ def square_root(a: float) -> float | complex:
95
  # file and document tools (save/read, download, OCR, CSV/Excel)
96
  @tool
97
  def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
 
 
 
98
  temp_dir = tempfile.gettempdir()
99
  filepath = os.path.join(temp_dir, filename or f"file_{uuid.uuid4().hex[:8]}.txt")
100
  with open(filepath, "w") as f:
@@ -103,19 +145,26 @@ def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
103
 
104
  @tool
105
  def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
 
 
 
106
  try:
107
  fname = filename or os.path.basename(urlparse(url).path) or f"file_{uuid.uuid4().hex[:8]}"
108
  path = os.path.join(tempfile.gettempdir(), fname)
109
  resp = requests.get(url, stream=True)
110
  resp.raise_for_status()
111
  with open(path, "wb") as f:
112
- for chunk in resp.iter_content(8192): f.write(chunk)
 
113
  return f"Downloaded to {path}"
114
  except Exception as e:
115
  return str(e)
116
 
117
  @tool
118
  def extract_text_from_image(image_path: str) -> str:
 
 
 
119
  try:
120
  img = Image.open(image_path)
121
  return pytesseract.image_to_string(img)
@@ -124,28 +173,37 @@ def extract_text_from_image(image_path: str) -> str:
124
 
125
  @tool
126
  def analyze_csv_file(file_path: str, query: str) -> str:
 
 
 
127
  df = pd.read_csv(file_path)
128
  return f"Rows: {len(df)}, Columns: {list(df.columns)}\n{df.describe()}"
129
 
130
  @tool
131
  def analyze_excel_file(file_path: str, query: str) -> str:
 
 
 
132
  df = pd.read_excel(file_path)
133
  return f"Rows: {len(df)}, Columns: {list(df.columns)}\n{df.describe()}"
134
 
135
  # image analysis/transforms
136
  @tool
137
  def analyze_image(image_base64: str) -> Dict[str, Any]:
138
- from image_processing import decode_image
 
 
139
  img = decode_image(image_base64)
140
- w,h = img.size
141
- return {"dimensions": (w,h), "mode": img.mode}
142
 
143
  @tool
144
  def transform_image(image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
145
- from image_processing import decode_image, encode_image
 
 
146
  img = decode_image(image_base64)
147
- # ... apply op as before ...
148
- # return {"image": encode_image(path)}
149
  return {"error": "placeholder"}
150
 
151
  # combine all tools into list
@@ -162,7 +220,7 @@ tools = [
162
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
163
  sys_msg = SystemMessage(content=f.read())
164
 
165
-
166
  emb = HuggingFaceEmbeddings(
167
  model_name="sentence-transformers/all-mpnet-base-v2"
168
  )
@@ -171,7 +229,6 @@ sup = create_client(
171
  os.getenv("SUPABASE_URL"),
172
  os.getenv("SUPABASE_SERVICE_ROLE_KEY")
173
  )
174
- # vectorstore setup (Supabase)
175
  vector_store = SupabaseVectorStore(
176
  client=sup,
177
  embedding=emb,
@@ -180,10 +237,10 @@ vector_store = SupabaseVectorStore(
180
  )
181
 
182
  def build_graph():
183
- """Build the LangGraph agent using OpenAI ChatGPT only."""
 
 
184
  # Initialize the OpenAI LLM
185
- from langchain_openai import ChatOpenAI
186
-
187
  llm = ChatOpenAI(
188
  model="gpt-3.5-turbo",
189
  temperature=0,
@@ -196,9 +253,7 @@ def build_graph():
196
  query = state["messages"][0].content
197
  hits = vector_store.similarity_search(query, k=1)
198
  if hits:
199
- # return the raw snippet as the single message
200
  return {"messages": [sys_msg, HumanMessage(content=hits[0].page_content)]}
201
- # fall back to LLM-with-tools on the user prompt
202
  resp = llm_with_tools.invoke([sys_msg] + state["messages"])
203
  return {"messages": [resp]}
204
 
@@ -219,3 +274,9 @@ def build_graph():
219
 
220
  return builder.compile()
221
 
 
 
 
 
 
 
 
39
 
40
  @tool
41
  def wiki_search(query: str) -> str:
42
+ """
43
+ Search Wikipedia for a query and return up to 2 formatted results.
44
+ """
45
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
46
  return {"wiki_results": "\n\n---\n\n".join(
47
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n{d.page_content}\n</Document>'
 
50
 
51
  @tool
52
  def web_search(query: str) -> str:
53
+ """
54
+ Search the web via Tavily for a query and return up to 3 formatted results.
55
+ """
56
  docs = TavilySearchResults(max_results=3).invoke(query=query)
57
  return {"web_results": "\n\n---\n\n".join(
58
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n{d.page_content}\n</Document>'
 
61
 
62
  @tool
63
  def arxiv_search(query: str) -> str:
64
+ """
65
+ Search arXiv for a query and return up to 3 formatted results.
66
+ """
67
  docs = ArxivLoader(query=query, load_max_docs=3).load()
68
  return {"arxiv_results": "\n\n---\n\n".join(
69
  f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n{d.page_content[:1000]}\n</Document>'
 
72
 
73
  @tool
74
  def execute_code_multilang(code: str, language: str = "python") -> str:
75
+ """
76
+ Execute code in multiple languages (Python, Bash, SQL, C, Java) and return execution output.
77
+ """
78
  return interpreter_instance.execute_code(code, language=language)
79
 
80
  # example numeric tools
81
  @tool
82
  def multiply(a: float, b: float) -> float:
83
+ """
84
+ Multiply two numbers and return the product.
85
+ """
86
  return a * b
87
+
88
  @tool
89
  def add(a: float, b: float) -> float:
90
+ """
91
+ Add two numbers and return the sum.
92
+ """
93
  return a + b
94
+
95
  @tool
96
  def subtract(a: float, b: float) -> float:
97
+ """
98
+ Subtract the second number from the first and return the result.
99
+ """
100
  return a - b
101
+
102
  @tool
103
  def divide(a: float, b: float) -> float:
104
+ """
105
+ Divide the first number by the second; raises error if division by zero.
106
+ """
107
  if b == 0:
108
  raise ValueError("Cannot divide by zero.")
109
  return a / b
110
+
111
  @tool
112
  def modulus(a: int, b: int) -> int:
113
+ """
114
+ Return the remainder of a divided by b.
115
+ """
116
  return a % b
117
+
118
  @tool
119
  def power(a: float, b: float) -> float:
120
+ """
121
+ Raise a to the power of b and return the result.
122
+ """
123
  return a ** b
124
+
125
  @tool
126
  def square_root(a: float) -> float | complex:
127
+ """
128
+ Return the square root of a number; returns complex for negative inputs.
129
+ """
130
  if a >= 0:
131
  return a ** 0.5
132
  return cmath.sqrt(a)
 
134
  # file and document tools (save/read, download, OCR, CSV/Excel)
135
  @tool
136
  def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
137
+ """
138
+ Save content to a temporary file and return the file path.
139
+ """
140
  temp_dir = tempfile.gettempdir()
141
  filepath = os.path.join(temp_dir, filename or f"file_{uuid.uuid4().hex[:8]}.txt")
142
  with open(filepath, "w") as f:
 
145
 
146
  @tool
147
  def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
148
+ """
149
+ Download a file from a URL, save locally, and return the file path or error string.
150
+ """
151
  try:
152
  fname = filename or os.path.basename(urlparse(url).path) or f"file_{uuid.uuid4().hex[:8]}"
153
  path = os.path.join(tempfile.gettempdir(), fname)
154
  resp = requests.get(url, stream=True)
155
  resp.raise_for_status()
156
  with open(path, "wb") as f:
157
+ for chunk in resp.iter_content(8192):
158
+ f.write(chunk)
159
  return f"Downloaded to {path}"
160
  except Exception as e:
161
  return str(e)
162
 
163
  @tool
164
  def extract_text_from_image(image_path: str) -> str:
165
+ """
166
+ Extract and return text from an image file using OCR.
167
+ """
168
  try:
169
  img = Image.open(image_path)
170
  return pytesseract.image_to_string(img)
 
173
 
174
  @tool
175
  def analyze_csv_file(file_path: str, query: str) -> str:
176
+ """
177
+ Analyze a CSV file: return row/column counts and summary statistics.
178
+ """
179
  df = pd.read_csv(file_path)
180
  return f"Rows: {len(df)}, Columns: {list(df.columns)}\n{df.describe()}"
181
 
182
  @tool
183
  def analyze_excel_file(file_path: str, query: str) -> str:
184
+ """
185
+ Analyze an Excel file: return row/column counts and summary statistics.
186
+ """
187
  df = pd.read_excel(file_path)
188
  return f"Rows: {len(df)}, Columns: {list(df.columns)}\n{df.describe()}"
189
 
190
  # image analysis/transforms
191
  @tool
192
  def analyze_image(image_base64: str) -> Dict[str, Any]:
193
+ """
194
+ Analyze a base64-encoded image: return dimensions and mode.
195
+ """
196
  img = decode_image(image_base64)
197
+ w, h = img.size
198
+ return {"dimensions": (w, h), "mode": img.mode}
199
 
200
  @tool
201
  def transform_image(image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
202
+ """
203
+ Apply a transformation to a base64-encoded image; placeholder implementation.
204
+ """
205
  img = decode_image(image_base64)
206
+ # operations logic here
 
207
  return {"error": "placeholder"}
208
 
209
  # combine all tools into list
 
220
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
221
  sys_msg = SystemMessage(content=f.read())
222
 
223
+ # vectorstore setup (Supabase)
224
  emb = HuggingFaceEmbeddings(
225
  model_name="sentence-transformers/all-mpnet-base-v2"
226
  )
 
229
  os.getenv("SUPABASE_URL"),
230
  os.getenv("SUPABASE_SERVICE_ROLE_KEY")
231
  )
 
232
  vector_store = SupabaseVectorStore(
233
  client=sup,
234
  embedding=emb,
 
237
  )
238
 
239
  def build_graph():
240
+ """
241
+ Build the LangGraph agent using OpenAI ChatGPT only.
242
+ """
243
  # Initialize the OpenAI LLM
 
 
244
  llm = ChatOpenAI(
245
  model="gpt-3.5-turbo",
246
  temperature=0,
 
253
  query = state["messages"][0].content
254
  hits = vector_store.similarity_search(query, k=1)
255
  if hits:
 
256
  return {"messages": [sys_msg, HumanMessage(content=hits[0].page_content)]}
 
257
  resp = llm_with_tools.invoke([sys_msg] + state["messages"])
258
  return {"messages": [resp]}
259
 
 
274
 
275
  return builder.compile()
276
 
277
+ # Optional test
278
+ if __name__ == "__main__":
279
+ graph = build_graph()
280
+ msgs = graph.invoke({"messages": [HumanMessage(content="Hello world")]})
281
+ for m in msgs["messages"]:
282
+ print(m.content)