iammartian0 committed on
Commit
cde9caa
·
verified ·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +236 -0
agent.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cerebras-powered Research Agent for GAIA-style questions
3
+ """
4
+ import os
5
+ from cerebras.cloud.sdk import Cerebras
6
+ from tavily import TavilyClient
7
+
8
+
9
class WebSearchTool:
    """Search the web using Tavily.

    Wraps a ``TavilyClient`` and renders its response as plain text.
    """

    def __init__(self, api_key: str):
        """Create the underlying Tavily client.

        Args:
            api_key: Tavily API key, passed straight to ``TavilyClient``.
        """
        self.client = TavilyClient(api_key=api_key)

    def search(self, query: str, max_results: int = 5) -> str:
        """Run an advanced-depth search and format the response as text.

        Args:
            query: Search query.
            max_results: Maximum number of results to request.

        Returns:
            An optional "Quick Answer" line followed by a numbered list of
            result titles, each with the first 200 characters of its content.
            If the request or the formatting raises, the string
            ``"Search error: <message>"`` is returned instead of raising.
        """
        try:
            response = self.client.search(
                query=query,
                search_depth="advanced",
                max_results=max_results,
                include_answer=True,
            )

            output = []
            if response.get("answer"):
                output.append(f"Quick Answer: {response['answer']}\n")

            output.append("Search Results:")
            # `or []` also covers a response whose "results" value is None.
            for i, result in enumerate(response.get("results") or [], 1):
                # Tolerate results with a missing/None title or content so a
                # single malformed entry does not discard the whole search.
                title = result.get("title") or "(untitled)"
                snippet = (result.get("content") or "")[:200]
                output.append(f"\n{i}. {title}")
                # NOTE(review): the snippet indent is a single space as seen
                # in the reviewed source; confirm the intended width.
                output.append(f" {snippet}...")

            return "\n".join(output)
        except Exception as e:
            # Deliberate best-effort: callers receive the error as text.
            return f"Search error: {str(e)}"
36
+
37
+
38
class FileReaderTool:
    """Read various file formats and return their content as text."""

    def read(self, file_path: str) -> str:
        """Return the textual content of ``file_path``.

        The reader is chosen from the lower-cased file extension:

        - ``.docx``: non-empty paragraphs, then table rows joined with " | "
        - ``.pdf``: extracted text of every page that yields any
        - ``.csv`` / ``.xlsx`` / ``.xls``: ``DataFrame.to_string()``
        - ``.txt`` / ``.md`` / ``.json``: raw UTF-8 text

        Format-specific libraries are imported lazily inside each branch, so
        they are only needed when that format is actually read.

        Args:
            file_path: Path of the file to read.

        Returns:
            The extracted text, or a message starting with "Error" /
            "Unsupported file type" when the file is missing, the extension
            is not handled, or reading raises. This method does not raise.
        """
        if not os.path.exists(file_path):
            return "Error: File not found"

        ext = os.path.splitext(file_path)[1].lower()

        try:
            if ext == '.docx':
                from docx import Document

                doc = Document(file_path)
                text = [para.text for para in doc.paragraphs if para.text.strip()]
                for table in doc.tables:
                    for row in table.rows:
                        cells = [cell.text.strip() for cell in row.cells]
                        text.append(" | ".join(cells))
                return "\n".join(text)

            elif ext == '.pdf':
                import pdfplumber

                with pdfplumber.open(file_path) as pdf:
                    text = []
                    for page in pdf.pages:
                        # Extract once per page; extraction is the expensive
                        # step and was previously performed twice.
                        page_text = page.extract_text()
                        if page_text:
                            text.append(page_text)
                    return "\n".join(text)

            elif ext in ['.xlsx', '.xls', '.csv']:
                import pandas as pd

                df = pd.read_csv(file_path) if ext == '.csv' else pd.read_excel(file_path)
                return df.to_string()

            elif ext in ['.txt', '.md', '.json']:
                with open(file_path, 'r', encoding='utf-8') as f:
                    return f.read()

            else:
                return f"Unsupported file type: {ext}"
        except Exception as e:
            # Deliberate best-effort: callers receive the error as text.
            return f"Error reading file: {str(e)}"
83
+
84
+
85
class ImageAnalysisTool:
    """Analyze images using OCR"""

    def analyze(self, image_path: str) -> str:
        """Extract text from an image with pytesseract.

        ``pytesseract`` and ``PIL`` are imported lazily so the rest of the
        module works without them.

        Args:
            image_path: Path of the image file.

        Returns:
            ``"OCR text:\\n<text>"`` when text was recognised,
            ``"No text found"`` when the OCR output is blank, or a message
            starting with ``"Error: "`` when the file is missing, a required
            package cannot be imported, or OCR raises. This method does not
            raise.
        """
        if not os.path.exists(image_path):
            return "Error: Image not found"

        try:
            import pytesseract
            from PIL import Image

            # The context manager closes the underlying file handle once OCR
            # is done instead of leaving it to garbage collection.
            with Image.open(image_path) as img:
                text = pytesseract.image_to_string(img)
            return f"OCR text:\n{text}" if text.strip() else "No text found"
        except ImportError as e:
            # Report the module that is actually missing (e.g. PIL); falls
            # back to the historical message when the name is unavailable.
            missing = e.name or "pytesseract"
            return f"Error: {missing} not installed"
        except Exception as e:
            return f"Error: {str(e)}"
103
+
104
+
105
class ResearchAgent:
    """
    Cerebras-powered research agent

    Features:
    - Web search via Tavily
    - File reading (PDF, DOCX, CSV, Excel, TXT)
    - Image OCR
    - Fast inference via Cerebras
    """

    # File extensions routed to OCR instead of the text file reader.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
    # A question containing any of these substrings is answered without a
    # web search.
    _LOGIC_KEYWORDS = ('opposite', 'backwards', 'reversed')
    # Lead-ins removed (case-insensitively) from the start of the LLM reply.
    _ANSWER_PREFIXES = ("Answer:", "The answer is:", "Final answer:")

    def __init__(
        self,
        cerebras_api_key: str = None,
        tavily_api_key: str = None,
        model: str = "llama3.1-70b"
    ):
        """
        Initialize agent

        Args:
            cerebras_api_key: Cerebras API key (or from env CEREBRAS_API_KEY)
            tavily_api_key: Tavily API key (or from env TAVILY_API_KEY)
            model: Cerebras model to use

        Raises:
            ValueError: If either API key is neither passed nor set in the
                environment.
        """
        print("🤖 Initializing Research Agent...")

        # Explicit arguments win over environment variables.
        self.cerebras_key = cerebras_api_key or os.getenv("CEREBRAS_API_KEY")
        self.tavily_key = tavily_api_key or os.getenv("TAVILY_API_KEY")

        if not self.cerebras_key:
            raise ValueError("CEREBRAS_API_KEY not found")
        if not self.tavily_key:
            raise ValueError("TAVILY_API_KEY not found")

        # Initialize LLM
        self.llm = Cerebras(api_key=self.cerebras_key)
        self.model = model

        # Initialize tools
        self.web_search = WebSearchTool(self.tavily_key)
        self.file_reader = FileReaderTool()
        self.image_analyzer = ImageAnalysisTool()

        print("✅ Agent ready")

    def _call_llm(self, messages: list) -> str:
        """Send ``messages`` to the chat-completions endpoint.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The first choice's message content, stripped of surrounding
            whitespace.

        Raises:
            RuntimeError: If the request fails (the original exception is
                chained as ``__cause__``) or the reply carries no content.
        """
        try:
            response = self.llm.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.1,
                max_tokens=2000
            )
            content = response.choices[0].message.content
        except Exception as e:
            raise RuntimeError(f"LLM error: {str(e)}") from e

        # Checked outside the try block so this error is not re-wrapped.
        if content is None:
            raise RuntimeError("LLM error: response contained no content")
        return content.strip()

    @classmethod
    def _strip_answer_prefixes(cls, answer: str) -> str:
        """Remove leading "Answer:"-style lead-ins, however they are stacked."""
        cleaned = answer.strip()
        found = True
        # Repeat until stable so that e.g. "Final answer: Answer: 42" is
        # fully cleaned regardless of the order of the prefix list.
        while found:
            found = False
            lowered = cleaned.lower()
            for prefix in cls._ANSWER_PREFIXES:
                if lowered.startswith(prefix.lower()):
                    cleaned = cleaned[len(prefix):].strip()
                    found = True
                    break
        return cleaned

    def answer(self, question: str, file_path: str = None) -> str:
        """
        Answer a question

        Context is gathered as follows: when ``file_path`` is given, the file
        is OCR'd (image extensions) or read as text (everything else) and no
        web search is made. Without a file, a web search is made unless the
        question contains one of the logic keywords. With no context at all
        the model is told to use its own knowledge.

        Args:
            question: The question
            file_path: Optional file to analyze

        Returns:
            Answer string with common "Answer:" lead-ins removed

        Raises:
            RuntimeError: Propagated from the LLM call.
        """
        print(f"📝 Question: {question[:80]}...")

        # Detect question type
        lowered_question = question.lower()
        is_logic = any(kw in lowered_question for kw in self._LOGIC_KEYWORDS)

        # Gather context
        context_parts = []

        if file_path:
            ext = os.path.splitext(file_path)[1].lower()
            if ext in self._IMAGE_EXTENSIONS:
                content = self.image_analyzer.analyze(file_path)
            else:
                content = self.file_reader.read(file_path)
            context_parts.append(f"File:\n{content}")

        if not is_logic and not file_path:
            # NOTE(review): leading whitespace of the progress messages is a
            # single space as seen in the reviewed source; confirm the width.
            print(" 🔍 Searching web...")
            search = self.web_search.search(question)
            context_parts.append(f"Search:\n{search}")

        context = "\n\n".join(context_parts) if context_parts else "Use knowledge."

        # Create prompt
        messages = [
            {
                "role": "system",
                "content": (
                    "You are an expert researcher. "
                    "Think step-by-step. "
                    "Provide ONLY the exact answer - no explanations."
                )
            },
            {
                "role": "user",
                "content": (
                    f"Context:\n{context}\n\n"
                    f"Question: {question}\n\n"
                    "Analyze and provide only the final answer:"
                )
            }
        ]

        # Get answer
        answer = self._call_llm(messages)

        # Clean answer
        answer = self._strip_answer_prefixes(answer)

        print(f" ✅ Answer: {answer[:80]}...")
        return answer

    def __call__(self, question: str, file_path: str = None) -> str:
        """Allow agent(question) syntax"""
        return self.answer(question, file_path)