File size: 15,702 Bytes
#!/usr/bin/env python3
"""
Enhanced Base Task Manager with Common Task Discovery Logic
===========================================================
This module provides an improved base class for task managers that consolidates
common task discovery patterns while maintaining flexibility for service-specific needs.
"""
import json
import subprocess
import sys
from abc import ABC
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
from src.logger import get_logger
from src.results_reporter import TaskResult
# Module-level logger used by the task manager for discovery and verification messages.
logger = get_logger(__name__)
@dataclass
class BaseTask:
    """One evaluation task: the files that define it plus its identifiers."""

    # File whose text is returned by get_task_instruction().
    task_instruction_path: Path
    # File holding the task's verification script.
    task_verification_path: Path
    # Name of the service this task belongs to.
    service: str
    # Taken from meta.json when available, otherwise the directory name.
    category_id: str
    # Taken from meta.json when available, otherwise the directory name.
    task_id: str

    @property
    def name(self) -> str:
        """Task name in the form 'category_id__task_id' (double-underscore separator)."""
        return f"{self.category_id}__{self.task_id}"

    def get_task_instruction(self) -> str:
        """Read the instruction file and return its complete text (UTF-8).

        Raises:
            FileNotFoundError: If the instruction file does not exist.
        """
        instruction_file = self.task_instruction_path
        if instruction_file.exists():
            return instruction_file.read_text(encoding="utf-8")
        raise FileNotFoundError(
            f"Task instruction file not found: {self.task_instruction_path}"
        )
class BaseTaskManager(ABC):
    """Base class for service-specific task managers with shared discovery logic.

    Tasks are looked up under ``tasks_root/<service>[/<task_suite>]/<category>/<task>/``.
    A task directory must contain ``description.md`` and ``verify.py``; an
    optional ``meta.json`` may override the category and task identifiers.
    Subclasses customise behaviour by overriding the underscore-prefixed hooks.
    """

    # Seconds a verification script may run before subprocess.run raises
    # TimeoutExpired. Subclasses may override this attribute.
    VERIFICATION_TIMEOUT_SECONDS = 300

    def __init__(
        self,
        tasks_root: Path,
        mcp_service: Optional[str] = None,
        task_class: Optional[type] = None,
        task_organization: Optional[str] = None,
        task_suite: Optional[str] = "standard",
    ):
        """Initialize the base task manager.

        Args:
            tasks_root: Root directory containing all tasks
            mcp_service: MCP service name (e.g., 'notion', 'github', 'filesystem').
                When omitted it is derived from the class name by lower-casing it
                and removing 'taskmanager'.
            task_class: Custom task class to use (defaults to BaseTask)
            task_organization: 'file' or 'directory' based task organization
            task_suite: Logical task suite (e.g., 'standard', 'easy'); a falsy
                value means tasks live directly under the service directory
        """
        self.tasks_root = tasks_root
        self.mcp_service = mcp_service or self.__class__.__name__.lower().replace(
            "taskmanager", ""
        )
        self.task_class = task_class or BaseTask
        self.task_organization = task_organization
        self.task_suite = task_suite
        # Filled by the first successful discover_all_tasks() call.
        self._tasks_cache = None

    # =========================================================================
    # Common Task Discovery Implementation
    # =========================================================================

    def discover_all_tasks(self) -> List["BaseTask"]:
        """Discover all available tasks for this service (common implementation).

        Returns:
            Tasks sorted by ``(category_id, str(task_id))``. The result is cached
            after the first successful scan; an empty list is returned (and not
            cached) when the service directory is missing.
        """
        if self._tasks_cache is not None:
            return self._tasks_cache

        tasks = []
        service_dir = self.tasks_root / (
            self.mcp_service or self._get_service_directory_name()
        )
        if self.task_suite:
            service_dir = service_dir / self.task_suite

        # is_dir() rather than exists(): a regular file at this path would make
        # iterdir() below raise NotADirectoryError.
        if not service_dir.is_dir():
            logger.warning(
                f"{self.mcp_service.title()} tasks directory does not exist: {service_dir}"
            )
            return tasks

        # Scan categories
        for category_dir in service_dir.iterdir():
            if not self._is_valid_category_dir(category_dir):
                continue

            category_id = category_dir.name
            logger.info("Discovering tasks in category: %s", category_id)

            # Find tasks using service-specific logic
            for task_files_info in self._find_task_files(category_dir):
                task = self._create_task_from_files(category_id, task_files_info)
                if task:
                    tasks.append(task)
                    logger.debug("Found task: %s", task.name)

        # Sort and cache. task_id is stringified so numeric IDs and slugs
        # compare uniformly.
        self._tasks_cache = sorted(tasks, key=lambda t: (t.category_id, str(t.task_id)))
        logger.info(
            "Discovered %d %s tasks across all categories (suite=%s)",
            len(self._tasks_cache),
            self.mcp_service.title(),
            self.task_suite or "default",
        )
        return self._tasks_cache

    def get_categories(self) -> List[str]:
        """Return the sorted, de-duplicated category ids of all discovered tasks."""
        return sorted({task.category_id for task in self.discover_all_tasks()})

    def filter_tasks(self, task_filter: str) -> List["BaseTask"]:
        """Filter tasks based on category or specific task pattern.

        Resolution order:
            1. Empty filter or 'all' (case-insensitive): every task.
            2. Exact category id: every task in that category.
            3. 'category_id/task_id': the first task matching both parts.
            4. Otherwise: tasks whose category id or name contains the filter,
               or whose task id equals it.
        """
        all_tasks = self.discover_all_tasks()
        if not task_filter or task_filter.lower() == "all":
            return all_tasks

        # Exact category match
        in_category = [task for task in all_tasks if task.category_id == task_filter]
        if in_category:
            return in_category

        # Specific task pattern (category_id/task_id). The "/" is known to be
        # present, so the split always yields exactly two parts.
        if "/" in task_filter:
            category, task_part = task_filter.split("/", 1)
            for task in all_tasks:
                # task_id may be numeric, so compare its string form
                if task.category_id == category and str(task.task_id) == task_part:
                    return [task]

        # Fallback: partial matches in task names or categories
        return [
            task
            for task in all_tasks
            if task_filter in task.category_id
            or task_filter in task.name
            or task_filter == str(task.task_id)
        ]

    # =========================================================================
    # Common Helper Methods
    # =========================================================================

    def get_task_instruction(self, task: "BaseTask") -> str:
        """Get formatted task instruction (template method)."""
        base_instruction = self._read_task_instruction(task)
        return self._format_task_instruction(base_instruction)

    def execute_task(self, task: "BaseTask", agent_result: Dict[str, Any]) -> "TaskResult":
        """Run verification for a task and package the outcome (template method).

        Verification always runs, even when the agent reported failure. The
        returned result's ``success`` reflects only the verification exit code;
        any agent error is carried separately in ``error_message``.

        Args:
            task: The task to verify.
            agent_result: Agent outcome; the keys read here are 'success',
                'error', 'output', 'token_usage' and 'turn_count'.

        Returns:
            A TaskResult. Exceptions raised while verifying are caught and
            reported through ``verification_error`` instead of propagating.
        """
        logger.info(f"| Verifying task ({self.mcp_service.title()}): {task.name}")

        # Track agent success separately
        agent_success = agent_result.get("success", False)
        agent_error = None
        verification_success = False
        verification_error = None
        verification_output = None

        # Handle agent failure (but still continue to verification)
        if not agent_success:
            agent_error = agent_result.get("error", "Agent execution failed")
            # Standardize MCP network errors
            agent_error = self._standardize_error_message(agent_error)
            logger.error("| ✗ Agent execution failed for task")
            logger.error(f"| ⚠️ Error: {agent_error}")
            logger.info("| - Proceeding with verification despite agent failure")

        try:
            # Always run verification regardless of agent success
            verify_result = self.run_verification(task)

            # Process verification results
            verification_success = verify_result.returncode == 0
            verification_output = verify_result.stdout

            # Echo the verification script's stdout
            if verification_output:
                print(verification_output)

            # Capture verification error if failed
            if not verification_success:
                verification_error = (
                    verify_result.stderr
                    if verify_result.stderr
                    else "Verification failed with no error message"
                )

            if verification_success:
                logger.info("| Verification Result: \033[92m✓ PASSED\033[0m")
            else:
                logger.error("| Verification Result: \033[91m✗ FAILED\033[0m")
                if verification_error:
                    logger.error(f"| Verification Error: {verification_error}")

            return TaskResult(
                task_name=task.name,
                success=verification_success,
                error_message=agent_error,  # Agent execution error
                verification_error=verification_error,  # Verification error
                verification_output=verification_output,  # Verification output
                model_output=agent_result.get("output", ""),
                category_id=task.category_id,
                task_id=task.task_id,
                token_usage=agent_result.get("token_usage", {}),
                turn_count=agent_result.get("turn_count", -1),
            )
        except Exception as e:
            logger.error(f"| Task verification failed: {e}", exc_info=True)
            return TaskResult(
                task_name=task.name,
                success=False,
                error_message=agent_error,  # Keep agent error if any
                verification_error=str(e),  # Verification exception
                verification_output=None,
                category_id=task.category_id,
                task_id=task.task_id,
                model_output=agent_result.get("output", ""),
                token_usage=agent_result.get("token_usage", {}),
                # NOTE(review): the success path defaults turn_count to -1 but
                # this path uses 0 - confirm which default is intended.
                turn_count=agent_result.get("turn_count", 0),
            )

    def run_verification(self, task: "BaseTask") -> subprocess.CompletedProcess:
        """Run the verification script for a task (can be overridden).

        Default implementation runs the command from _get_verification_command
        with stdout/stderr captured as text and a timeout of
        VERIFICATION_TIMEOUT_SECONDS. Services can override this to add
        environment variables or custom logic.
        """
        return subprocess.run(
            self._get_verification_command(task),
            capture_output=True,  # Capture stdout and stderr for logging
            text=True,
            timeout=self.VERIFICATION_TIMEOUT_SECONDS,
        )

    # =========================================================================
    # Configuration Hooks - Supply via Constructor or Override
    # =========================================================================

    def _get_service_directory_name(self) -> str:
        """Return the service directory name (e.g., 'notion', 'github').

        Default implementation uses the service parameter if provided.

        Raises:
            NotImplementedError: If no service name is set and the method is
                not overridden.
        """
        if self.mcp_service:
            return self.mcp_service
        raise NotImplementedError(
            "Must provide service parameter or implement _get_service_directory_name"
        )

    def _get_task_organization(self) -> str:
        """Return task organization type: 'directory' or 'file'.

        - 'directory': Tasks organized as task_X/description.md (Notion)
        - 'file': Tasks organized as task_X.md (GitHub, Filesystem)

        Default implementation uses the task_organization parameter if provided.

        Raises:
            NotImplementedError: If no task organization is set and the method
                is not overridden.
        """
        if self.task_organization:
            return self.task_organization
        raise NotImplementedError(
            "Must provide task_organization parameter or implement _get_task_organization"
        )

    # Note: _create_task_instance is no longer needed - use task_class parameter instead

    # =========================================================================
    # Hook Methods with Smart Defaults
    # =========================================================================

    def _is_valid_category_dir(self, category_dir: Path) -> bool:
        """Check if a path is a category directory.

        Hidden entries, 'utils' and '__pycache__' are excluded.
        """
        return (
            category_dir.is_dir()
            and not category_dir.name.startswith(".")
            and category_dir.name != "utils"
            and category_dir.name != "__pycache__"
        )

    def _find_task_files(self, category_dir: Path) -> List[Dict[str, Any]]:
        """Find task files in a category directory (default implementation).

        Only directory-based organization is handled here: every non-hidden
        sub-directory containing both description.md and verify.py is a task.

        Returns:
            One dict per task with keys 'task_id' (the directory name),
            'instruction_path' and 'verification_path'.
        """
        task_files: List[Dict[str, Any]] = []
        for task_dir in category_dir.iterdir():
            # Skip anything that is not a directory or is hidden
            if not task_dir.is_dir() or task_dir.name.startswith("."):
                continue

            description_path = task_dir / "description.md"
            verify_path = task_dir / "verify.py"

            # A directory is a valid task only if both mandatory files exist
            if not (description_path.exists() and verify_path.exists()):
                logger.warning(
                    "Skipping %s – missing description.md or verify.py", task_dir
                )
                continue

            task_files.append(
                {
                    "task_id": task_dir.name,
                    "instruction_path": description_path,
                    "verification_path": verify_path,
                }
            )
        return task_files

    def _create_task_from_files(
        self, category_id: str, task_files_info: Dict[str, Any]
    ) -> Optional["BaseTask"]:
        """Create a task from file information with meta.json support.

        If a meta.json sits next to the instruction file, its 'category_id' and
        'task_id' values replace the directory-derived defaults. An unreadable,
        malformed or non-object meta.json is logged and ignored, and missing or
        null fields keep the defaults.
        """
        meta_path = task_files_info["instruction_path"].parent / "meta.json"

        # Default to directory names
        task_id = task_files_info["task_id"]
        final_category_id = category_id

        if meta_path.exists():
            try:
                # Read as UTF-8 explicitly so the result does not depend on the
                # platform's locale encoding.
                with open(meta_path, "r", encoding="utf-8") as f:
                    meta_data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers both JSONDecodeError and UnicodeDecodeError
                logger.warning(f"Failed to load meta.json from {meta_path}: {e}")
            else:
                if isinstance(meta_data, dict):
                    # Use values from meta.json if available; a null value
                    # keeps the directory-derived default.
                    meta_category = meta_data.get("category_id")
                    if meta_category is not None:
                        final_category_id = meta_category
                    meta_task_id = meta_data.get("task_id")
                    if meta_task_id is not None:
                        task_id = meta_task_id
                else:
                    logger.warning(
                        f"Ignoring meta.json at {meta_path}: expected a JSON object"
                    )

        return self.task_class(
            task_instruction_path=task_files_info["instruction_path"],
            task_verification_path=task_files_info["verification_path"],
            service=self.mcp_service,
            category_id=final_category_id,
            task_id=task_id,
        )

    def _read_task_instruction(self, task: "BaseTask") -> str:
        """Read and return the task instruction content."""
        return task.get_task_instruction()

    def _format_task_instruction(self, base_instruction: str) -> str:
        """Append the standard 'solve it on your own' note to the instruction."""
        return (
            base_instruction
            + "\n\nNote: Based on your understanding, solve the task all at once by yourself, don't ask for my opinions on anything."
        )

    def _get_verification_command(self, task: "BaseTask") -> List[str]:
        """Get the command to run task verification (default implementation).

        Runs the task's verification script with the current Python interpreter.
        """
        return [sys.executable, str(task.task_verification_path)]

    def _standardize_error_message(self, error_message: str) -> str:
        """Standardize error messages for consistent reporting."""
        # Imported at call time rather than module load
        # (presumably to avoid an import cycle - confirm).
        from src.errors import standardize_error_message

        return standardize_error_message(error_message, mcp_service=self.mcp_service)
|