Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -461,14 +461,34 @@ class MagAgent:
|
|
| 461 |
|
| 462 |
def _build_task_prompt(self, question: str, task_id: str) -> str:
|
| 463 |
"""Constructs task-specific prompts using templates"""
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
subtasks=self._generate_subtasks(question),
|
| 469 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
)
|
| 471 |
-
|
|
|
|
| 472 |
|
| 473 |
async def _execute_agent(self, question: str, task_id: str) -> str:
|
| 474 |
return await asyncio.to_thread(
|
|
@@ -501,7 +521,16 @@ class MagAgent:
|
|
| 501 |
|
| 502 |
except Exception as e:
|
| 503 |
return self._handle_error(e, context)
|
| 504 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
def _handle_error(self, error: Exception, context: dict) -> str:
|
| 506 |
"""Central error handling with context-aware formatting"""
|
| 507 |
error_type = error.__class__.__name__
|
|
|
|
| 461 |
|
| 462 |
def _build_task_prompt(self, question: str, task_id: str) -> str:
|
| 463 |
"""Constructs task-specific prompts using templates"""
|
| 464 |
+
managed_agent = self.prompt_templates["managed_agent"]
|
| 465 |
+
|
| 466 |
+
task_context = managed_agent["task_template"].format(
|
| 467 |
+
task_id=task_id,
|
| 468 |
+
question=question,
|
| 469 |
+
current_date=datetime.now().strftime("%Y-%m-%d"),
|
| 470 |
+
answer_format="concise, accurate, properly punctuated",
|
| 471 |
+
validation_level="strict"
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
analysis = managed_agent["analysis_template"].format(
|
| 475 |
+
question_analysis=self._analyze_question(question),
|
| 476 |
+
entities=self._extract_entities(question),
|
| 477 |
+
temporal_constraints=self._find_temporal_limits(question)
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
subtasks = managed_agent["subtask_template"].format(
|
| 481 |
subtasks=self._generate_subtasks(question),
|
| 482 |
+
tools=", ".join([tool.name for tool in self.tools])
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
validation = managed_agent["validation_template"].format(
|
| 486 |
+
validation_rules=self._get_validation_rules(question),
|
| 487 |
+
max_retries=3,
|
| 488 |
+
fallback_protocol="Multi-source verification"
|
| 489 |
)
|
| 490 |
+
|
| 491 |
+
return "\n\n".join([task_context, analysis, subtasks, validation])
|
| 492 |
|
| 493 |
async def _execute_agent(self, question: str, task_id: str) -> str:
|
| 494 |
return await asyncio.to_thread(
|
|
|
|
| 521 |
|
| 522 |
except Exception as e:
|
| 523 |
return self._handle_error(e, context)
|
| 524 |
+
|
| 525 |
+
def _extract_entities(self, question: str) -> str:
|
| 526 |
+
"""Extract key entities from question"""
|
| 527 |
+
return ", ".join(re.findall(r'\b[A-Z][a-z]+\b', question))
|
| 528 |
+
|
| 529 |
+
def _find_temporal_limits(self, question: str) -> str:
|
| 530 |
+
"""Find date ranges in question"""
|
| 531 |
+
dates = re.findall(r'\b\d{4}\b', question)
|
| 532 |
+
return f"{min(dates)}-{max(dates)}" if dates else "No date constraints"
|
| 533 |
+
|
| 534 |
def _handle_error(self, error: Exception, context: dict) -> str:
|
| 535 |
"""Central error handling with context-aware formatting"""
|
| 536 |
error_type = error.__class__.__name__
|