mabelwang21 committed on
Commit
96229ca
·
1 Parent(s): fec32f4

add download_file tool, update add_files

Browse files
Files changed (1) hide show
  1. agent.py +91 -16
agent.py CHANGED
@@ -5,6 +5,9 @@ import json
5
  import operator as op
6
  from pathlib import Path
7
  from typing import List, TypedDict, Annotated, Optional
 
 
 
8
 
9
  from langchain.tools import tool, StructuredTool
10
  from langchain_community.document_loaders import (
@@ -206,12 +209,53 @@ def python_interpreter(code: str) -> str:
206
  except Exception as e:
207
  return f"Error executing Python code: {e}"
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  # Update tools list
210
  tools: List[StructuredTool] = [
211
  calculate, web_search, wikipedia_search, image_recognition,
212
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
213
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
214
- python_interpreter # Add python_interpreter
215
  ]
216
 
217
  class AgentState(TypedDict):
@@ -237,23 +281,45 @@ class MyAgent:
237
  def add_files(self, file_paths: List[str]):
238
  """
239
  Load and index documents for RAG based on file extensions or URLs.
240
- Supports: PDF, CSV, audio (mp3/wav), and YouTube URLs.
241
  """
242
  for path in file_paths:
243
  ext = Path(path).suffix.lower()
244
- if ext == ".csv":
245
- loader = CSVLoader(path)
246
- self.docs.extend(loader.load())
247
- elif ext == ".pdf":
248
- loader = PyPDFLoader(path)
249
- self.docs.extend(loader.load())
250
- elif ext in [".mp3", ".wav"]:
251
- loader = AssemblyAIAudioTranscriptLoader(file_path=path)
252
- self.docs.extend(loader.load())
253
- elif "youtube" in path:
254
- loader = YoutubeLoader.from_youtube_url(path)
255
- self.docs.extend(loader.load())
256
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  continue
258
 
259
  def build_retriever(self):
@@ -286,7 +352,16 @@ class MyAgent:
286
 
287
  # Use structured tool attributes
288
  tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
289
- sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\nTools:\n{tool_desc}")
 
 
 
 
 
 
 
 
 
290
  state["messages"].append(sys_msg)
291
 
292
  # Optionally load RAG docs
 
5
  import operator as op
6
  from pathlib import Path
7
  from typing import List, TypedDict, Annotated, Optional
8
+ import requests
9
+ from urllib.parse import urlparse
10
+ import shutil
11
 
12
  from langchain.tools import tool, StructuredTool
13
  from langchain_community.document_loaders import (
 
209
  except Exception as e:
210
  return f"Error executing Python code: {e}"
211
 
212
# Upper bound (seconds) for connecting and for each read while downloading;
# without it a stalled server would block the agent forever.
_DOWNLOAD_TIMEOUT_SECONDS = 30

# Name used when neither the response headers nor the URL yield a file name.
_DEFAULT_DOWNLOAD_NAME = "downloaded_file"


def _safe_download_name(content_disposition: Optional[str], url: str) -> str:
    """Return a non-empty, directory-free file name for a downloaded resource.

    The name is taken from the ``filename=`` parameter of the
    Content-Disposition header when present, otherwise from the last
    component of the URL path.
    """
    name = ""
    if content_disposition and "filename=" in content_disposition:
        value = content_disposition.split("filename=", 1)[1]
        # Drop any parameters that follow the file name (e.g. "; size=10").
        name = value.split(";", 1)[0].strip().strip('"')
    if not name:
        name = urlparse(url).path
    # Keep only the final path component so a server-supplied name such as
    # "../../x" cannot escape the target directory.
    name = Path(name.replace("\\", "/")).name
    if name in ("", ".", ".."):
        return _DEFAULT_DOWNLOAD_NAME
    return name


@tool
def download_file(url_or_path: str, save_dir: str = "./downloads") -> str:
    """Download a file from URL or copy from local path to the downloads directory.

    Args:
        url_or_path: An http(s) URL to download, or the path of a local file to copy.
        save_dir: Directory that receives the file; created if it does not exist.

    Returns:
        A message with the saved path on success, or a message starting with
        "Error" describing what went wrong (errors are reported, not raised).
    """
    try:
        # Create downloads directory if it doesn't exist
        target_dir = Path(save_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        # Check if input is URL or local path
        if url_or_path.startswith(('http://', 'https://')):
            # The context manager releases the connection even if writing fails.
            with requests.get(
                url_or_path, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS
            ) as response:
                response.raise_for_status()

                filename = _safe_download_name(
                    response.headers.get('Content-Disposition'), url_or_path
                )
                save_path = target_dir / filename

                # iter_content transparently undoes gzip/deflate transfer
                # encoding, which copying from response.raw would not.
                with open(save_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=64 * 1024):
                        if chunk:
                            f.write(chunk)

        else:
            # Handle local file copy
            src_path = Path(url_or_path)
            if not src_path.exists():
                return f"Error: Source file {url_or_path} not found"
            if not src_path.is_file():
                return f"Error: Source path {url_or_path} is not a file"

            save_path = target_dir / src_path.name
            shutil.copy2(src_path, save_path)

        return f"File successfully saved to {save_path}"

    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"Error downloading/copying file: {e}"
252
+
253
  # Update tools list
254
  tools: List[StructuredTool] = [
255
  calculate, web_search, wikipedia_search, image_recognition,
256
  read_pdf, read_csv, read_spreadsheet, transcribe_audio,
257
  youtube_transcript_tool, youtube_transcript_api, read_jsonl,
258
+ python_interpreter, download_file # Add download_file
259
  ]
260
 
261
  class AgentState(TypedDict):
 
281
  def add_files(self, file_paths: List[str]):
282
  """
283
  Load and index documents for RAG based on file extensions or URLs.
284
+ Supports: PDF, CSV, Excel, JSONL, images, audio (mp3/wav), and YouTube URLs.
285
  """
286
  for path in file_paths:
287
  ext = Path(path).suffix.lower()
288
+ try:
289
+ if ext == ".csv":
290
+ loader = CSVLoader(path)
291
+ self.docs.extend(loader.load())
292
+ elif ext == ".pdf":
293
+ loader = PyPDFLoader(path)
294
+ self.docs.extend(loader.load())
295
+ elif ext in [".xlsx", ".xls"]:
296
+ # Handle spreadsheets
297
+ import pandas as pd
298
+ df = pd.read_excel(path)
299
+ text_content = df.to_string()
300
+ self.docs.append(Document(page_content=text_content))
301
+ elif ext == ".jsonl":
302
+ # Handle JSONL files
303
+ with open(path, 'r', encoding='utf-8') as file:
304
+ content = [json.loads(line) for line in file]
305
+ text_content = json.dumps(content, indent=2)
306
+ self.docs.append(Document(page_content=text_content))
307
+ elif ext in [".png", ".jpg", ".jpeg"]:
308
+ # Handle images
309
+ text = pytesseract.image_to_string(Image.open(path))
310
+ if text.strip():
311
+ self.docs.append(Document(page_content=text))
312
+ elif ext in [".mp3", ".wav"]:
313
+ loader = AssemblyAIAudioTranscriptLoader(file_path=path)
314
+ self.docs.extend(loader.load())
315
+ elif "youtube" in path:
316
+ loader = YoutubeLoader.from_youtube_url(path)
317
+ self.docs.extend(loader.load())
318
+ else:
319
+ print(f"Unsupported file type: {ext}")
320
+ continue
321
+ except Exception as e:
322
+ print(f"Error loading {path}: {e}")
323
  continue
324
 
325
  def build_retriever(self):
 
352
 
353
  # Use structured tool attributes
354
  tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
355
+
356
+ # Enhanced system prompt with RAG guidance
357
+ rag_prompt = """
358
+ If the question seems to be about any loaded documents, ALWAYS:
359
+ 1. Use the rag_search tool first to find relevant information
360
+ 2. Base your answer on the retrieved content
361
+ 3. If no relevant content is found, say so
362
+ """
363
+
364
+ sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\n{rag_prompt if file_paths else ''}\n\nTools:\n{tool_desc}")
365
  state["messages"].append(sys_msg)
366
 
367
  # Optionally load RAG docs