File size: 4,022 Bytes
a2ec7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
PostgreSQL Task Manager for MCPMark
====================================

Manages PostgreSQL task discovery, execution, and verification.
"""

import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from src.base.task_manager import BaseTask, BaseTaskManager
from src.logger import get_logger

logger = get_logger(__name__)


@dataclass
class PostgresTask(BaseTask):
    """PostgreSQL-specific task with database information."""

    task_name: str = ""
    database_name: Optional[str] = None
    database_url: Optional[str] = None
    expected_queries: Optional[List[str]] = None
    expected_tables: Optional[List[str]] = None


class PostgresTaskManager(BaseTaskManager):
    """Manages PostgreSQL tasks for MCPMark evaluation."""

    def __init__(self, tasks_root: Path = None, task_suite: str = "standard"):
        """Initialize PostgreSQL task manager.

        Args:
            tasks_root: Path to tasks directory
            task_suite: Logical task suite (e.g., 'standard', 'easy')
        """
        if tasks_root is None:
            tasks_root = Path(__file__).resolve().parents[3] / "tasks"

        super().__init__(
            tasks_root,
            mcp_service="postgres",
            task_class=PostgresTask,
            task_organization="file",  # PostgreSQL uses file-based tasks
            task_suite=task_suite,
        )

    def _create_task_from_files(
        self, category_id: str, task_files_info: Dict[str, Any]
    ) -> Optional[PostgresTask]:
        """Instantiate a `PostgresTask` from the dictionary returned by `_find_task_files`."""
        import json
        
        # Check for meta.json
        meta_path = task_files_info["instruction_path"].parent / "meta.json"
        final_category_id = category_id
        task_id = task_files_info["task_id"]
        
        if meta_path.exists():
            try:
                with open(meta_path, 'r') as f:
                    meta_data = json.load(f)
                    # Use values from meta.json if available
                    final_category_id = meta_data.get("category_id", category_id)
                    task_id = meta_data.get("task_id", task_id)
            except Exception as e:
                logger.warning(f"Failed to load meta.json from {meta_path}: {e}")

        return PostgresTask(
            task_instruction_path=task_files_info["instruction_path"],
            task_verification_path=task_files_info["verification_path"],
            service="postgres",
            category_id=final_category_id,
            task_id=task_id,
            task_name=task_files_info["task_id"],
        )

    def _get_verification_command(self, task: PostgresTask) -> List[str]:
        """Get verification command with database info."""
        cmd = [sys.executable, str(task.task_verification_path)]

        # Pass database name as argument if available
        if task.database_name:
            cmd.append(task.database_name)

        return cmd

    def run_verification(self, task: BaseTask) -> subprocess.CompletedProcess:
        """Run verification with PostgreSQL environment."""
        env = os.environ.copy()

        # Pass database connection info to verification script
        if hasattr(task, "database_name") and task.database_name:
            env["POSTGRES_DATABASE"] = task.database_name

        if hasattr(task, "database_url") and task.database_url:
            env["DATABASE_URL"] = task.database_url

        return subprocess.run(
            self._get_verification_command(task),
            capture_output=True,
            text=True,
            timeout=300,
            env=env,
        )

    def _format_task_instruction(self, base_instruction: str) -> str:
        """Add PostgreSQL-specific instructions."""
        return (
            base_instruction
            + "\n\nNote: Use PostgreSQL MCP tools to complete this task. The database connection is already configured."
        )