Spaces:
Sleeping
Sleeping
Update src/agent.py
Browse files- src/agent.py +38 -27
src/agent.py
CHANGED
|
@@ -89,73 +89,84 @@ class BasicAgent():
|
|
| 89 |
self.expected_tokens_per_step = 10000
|
| 90 |
self.max_retries = 3
|
| 91 |
self.base_delay = 5
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
final_answer = None
|
| 95 |
retry_count = 0
|
| 96 |
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
if file_content:
|
| 99 |
context = f"Story content:\n{file_content}"
|
| 100 |
elif file_path:
|
| 101 |
context = f"File path: {file_path}"
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
while True:
|
| 106 |
try:
|
| 107 |
-
final_input = f"{question}\n\n{context}"
|
| 108 |
-
# Run the agent
|
| 109 |
steps = self.agent.run(final_input)
|
| 110 |
|
| 111 |
-
#
|
| 112 |
if isinstance(steps, str):
|
| 113 |
steps = [steps]
|
| 114 |
|
| 115 |
for step in steps:
|
| 116 |
-
# Handle string steps
|
| 117 |
if isinstance(step, str):
|
| 118 |
final_answer = step
|
| 119 |
-
print(f"Step: String Output: {final_answer}")
|
| 120 |
continue
|
| 121 |
|
| 122 |
-
# Handle object steps
|
| 123 |
step_name = step.__class__.__name__
|
| 124 |
output = getattr(step, "output", None)
|
| 125 |
-
|
| 126 |
if output:
|
| 127 |
-
|
| 128 |
|
| 129 |
self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
|
| 130 |
tokens_used = getattr(step, "token_usage", None)
|
| 131 |
if tokens_used:
|
| 132 |
self.token_rate_limiter.add_tokens(tokens_used.input_tokens)
|
| 133 |
|
| 134 |
-
# Capture the final answer from the final answer step
|
| 135 |
-
if step_name == "FinalAnswerStep":
|
| 136 |
-
final_answer = output
|
| 137 |
-
print(f"Captured Final Answer from step: {final_answer}")
|
| 138 |
-
|
| 139 |
break # Exit retry loop if successful
|
| 140 |
|
| 141 |
except Exception as e:
|
| 142 |
-
# Handle API overload/rate limits
|
| 143 |
if "overload" in str(e).lower() or "rate limit" in str(e).lower():
|
| 144 |
-
print("Rate limit exceeded. Retrying...")
|
| 145 |
if retry_count >= self.max_retries:
|
| 146 |
print("Max retries reached. Exiting...")
|
| 147 |
break
|
| 148 |
delay = self.base_delay * (2 ** retry_count) + random.random()
|
| 149 |
-
print(f"Retrying in {delay:.1f}s ... ({e})")
|
| 150 |
time.sleep(delay)
|
| 151 |
retry_count += 1
|
| 152 |
else:
|
| 153 |
-
print(f"Error: {e}")
|
| 154 |
break
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
| 158 |
return final_answer
|
| 159 |
|
| 160 |
-
def __call__(self, question: str, file_content: str = "", file_path: str = ""):
|
| 161 |
-
return self.run(question, file_content)
|
|
|
|
| 89 |
self.expected_tokens_per_step = 10000
|
| 90 |
self.max_retries = 3
|
| 91 |
self.base_delay = 5
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _read_file(self, file_path: str) -> str:
|
| 95 |
+
if not os.path.exists(file_path):
|
| 96 |
+
print(f"File not found: {file_path}")
|
| 97 |
+
return ""
|
| 98 |
+
if file_path.endswith(".txt"):
|
| 99 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 100 |
+
return f.read()
|
| 101 |
+
elif file_path.endswith(".docx"):
|
| 102 |
+
doc = docx.Document(file_path)
|
| 103 |
+
return "\n".join([p.text for p in doc.paragraphs])
|
| 104 |
+
else:
|
| 105 |
+
# For unsupported formats, return empty string
|
| 106 |
+
print(f"Unsupported file type: {file_path}")
|
| 107 |
+
return ""
|
| 108 |
+
|
| 109 |
+
def run(self, question: str, file_content: str = "", file_path: str = "") -> str:
|
| 110 |
final_answer = None
|
| 111 |
retry_count = 0
|
| 112 |
|
| 113 |
+
# If file content is empty but file_path exists, read the file
|
| 114 |
+
if not file_content and file_path:
|
| 115 |
+
file_content = self._read_file(file_path)
|
| 116 |
+
|
| 117 |
+
context = ""
|
| 118 |
if file_content:
|
| 119 |
context = f"Story content:\n{file_content}"
|
| 120 |
elif file_path:
|
| 121 |
context = f"File path: {file_path}"
|
| 122 |
+
|
| 123 |
+
print(f"Starting Agent with question: {question}\nContext length: {len(context)} chars")
|
| 124 |
+
|
| 125 |
while True:
|
| 126 |
try:
|
| 127 |
+
final_input = f"Question: {question}\n\n{context}"
|
|
|
|
| 128 |
steps = self.agent.run(final_input)
|
| 129 |
|
| 130 |
+
# Convert string steps to list
|
| 131 |
if isinstance(steps, str):
|
| 132 |
steps = [steps]
|
| 133 |
|
| 134 |
for step in steps:
|
|
|
|
| 135 |
if isinstance(step, str):
|
| 136 |
final_answer = step
|
|
|
|
| 137 |
continue
|
| 138 |
|
|
|
|
| 139 |
step_name = step.__class__.__name__
|
| 140 |
output = getattr(step, "output", None)
|
|
|
|
| 141 |
if output:
|
| 142 |
+
final_answer = output
|
| 143 |
|
| 144 |
self.token_rate_limiter.maybe_wait(self.expected_tokens_per_step)
|
| 145 |
tokens_used = getattr(step, "token_usage", None)
|
| 146 |
if tokens_used:
|
| 147 |
self.token_rate_limiter.add_tokens(tokens_used.input_tokens)
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
break # Exit retry loop if successful
|
| 150 |
|
| 151 |
except Exception as e:
|
|
|
|
| 152 |
if "overload" in str(e).lower() or "rate limit" in str(e).lower():
|
|
|
|
| 153 |
if retry_count >= self.max_retries:
|
| 154 |
print("Max retries reached. Exiting...")
|
| 155 |
break
|
| 156 |
delay = self.base_delay * (2 ** retry_count) + random.random()
|
| 157 |
+
print(f"Retrying in {delay:.1f}s due to rate limit... ({e})")
|
| 158 |
time.sleep(delay)
|
| 159 |
retry_count += 1
|
| 160 |
else:
|
| 161 |
+
print(f"Error during agent run: {e}")
|
| 162 |
break
|
| 163 |
|
| 164 |
+
# Ensure a valid answer is always returned
|
| 165 |
+
if not final_answer:
|
| 166 |
+
final_answer = "I am unable to answer"
|
| 167 |
+
|
| 168 |
+
print(f"Finished agent run. Final Answer: {final_answer}\n{'='*50}")
|
| 169 |
return final_answer
|
| 170 |
|
| 171 |
+
def __call__(self, question: str, file_content: str = "", file_path: str = "") -> str:
|
| 172 |
+
return self.run(question, file_content, file_path)
|