File size: 6,640 Bytes
03a907a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""

Unified Task Manager: load tasks from the local dataset or from SWE-bench through one interface.



This module provides a single interface to load tasks from:

1. Local hardcoded dataset (dataset/problem_1, problem_10, etc.)

2. SWE-bench Lite (if available and configured)



Configuration via environment variables:

  TASK_SOURCE         "local" | "swebench" | "auto" (default: "auto")

  SWEBENCH_FALLBACK   "1" | "true" | "yes" enables fallback to local when SWE-bench fails in "auto" mode (default: "1")

  SWEBENCH_TASKS_ROOT Path to SWE-bench tasks directory

  SWEBENCH_INDEX      Preferred task index within difficulty band

"""

import os
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Literal

from rl_code_fix_env.dataset.loader import get_hardcoded_task
from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Accepted values for the task-source selector (argument or TASK_SOURCE env var).
TaskSource = Literal["local", "swebench", "auto"]
# Accepted difficulty levels for load_task().
Difficulty = Literal["easy", "medium", "hard"]


class TaskLoadError(Exception):
    """Error signalling that a task could not be loaded."""


class TaskManager:
    """
    Unified interface for loading tasks from any dataset.

    Dispatches to the local hardcoded dataset or to SWE-bench depending on
    ``source``. In "auto" mode SWE-bench is tried first and, when fallback is
    enabled, the local dataset is tried if SWE-bench loading fails.
    """

    def __init__(self, source: Optional[TaskSource] = None):
        """
        Initialize TaskManager.

        Args:
            source: "local", "swebench", or "auto" (tries swebench first, falls
                back to local). If None, reads from the TASK_SOURCE env var
                (default: "auto").

        Raises:
            ValueError: If the resolved source is not a string or is not one of
                "local", "swebench", "auto".
        """
        raw_source = source or os.getenv("TASK_SOURCE", "auto")
        # A truthy non-string (e.g. an int) would otherwise crash on .strip()
        # with AttributeError; report it as the same ValueError as any other
        # invalid source.
        if not isinstance(raw_source, str):
            raise ValueError(
                f"Invalid TASK_SOURCE={raw_source!r}. "
                f"Must be one of: local, swebench, auto"
            )

        self.source = raw_source.strip().lower()
        self.enable_fallback = (
            os.getenv("SWEBENCH_FALLBACK", "1").strip().lower() in {"1", "true", "yes"}
        )

        if self.source not in {"local", "swebench", "auto"}:
            raise ValueError(
                f"Invalid TASK_SOURCE='{self.source}'. "
                f"Must be one of: local, swebench, auto"
            )

        logger.info(
            "TaskManager initialized: source=%s, fallback_enabled=%s",
            self.source,
            self.enable_fallback,
        )

    def load_task(self, difficulty: Difficulty) -> Dict[str, Any]:
        """
        Load a task by difficulty level.

        Args:
            difficulty: "easy", "medium", or "hard" (case-insensitive,
                surrounding whitespace is ignored).

        Returns:
            Task dict with structure:
            {
                "code": str,           # buggy Python code
                "tests": str,          # path to test.py
                "metadata": dict,      # source, repo, problem_statement, etc.
                "problem_dir": str,    # directory containing buggy.py and test.py
                "problem_id": str,     # unique identifier for this task
            }

        Raises:
            ValueError: If ``difficulty`` is not a string naming a known level.
            TaskLoadError: If no task can be loaded from any source.
        """
        # Reject non-string input up front so callers get the same ValueError
        # as for an unknown level instead of an AttributeError from .strip().
        if difficulty is not None and not isinstance(difficulty, str):
            raise ValueError(
                f"Invalid difficulty={difficulty!r}. Must be one of: easy, medium, hard"
            )

        difficulty = (difficulty or "").strip().lower()
        if difficulty not in {"easy", "medium", "hard"}:
            raise ValueError(
                f"Invalid difficulty='{difficulty}'. Must be one of: easy, medium, hard"
            )

        # Explicit single-source modes never fall back.
        if self.source == "local":
            return self._load_local(difficulty)
        if self.source == "swebench":
            return self._load_swebench(difficulty)

        # "auto" mode: SWE-bench first, then (optionally) the local dataset.
        logger.debug("Auto mode: trying SWE-bench first...")
        try:
            return self._load_swebench(difficulty)
        except Exception as e:
            # Keep the exception object: the name bound by `except ... as` is
            # cleared when the handler exits, and it is needed below for
            # chaining.
            swebench_exc = e
            logger.debug("SWE-bench failed: %s", e)

        if not self.enable_fallback:
            # Chain the original failure so its traceback is not lost.
            raise TaskLoadError(
                f"SWE-bench loading failed and fallback disabled: {swebench_exc}"
            ) from swebench_exc

        logger.info("SWE-bench unavailable, falling back to local dataset")
        try:
            return self._load_local(difficulty)
        except Exception as local_error:
            raise TaskLoadError(
                f"Both SWE-bench and local fallback failed:\n"
                f"  SWE-bench: {swebench_exc}\n"
                f"  Local: {local_error}"
            ) from local_error

    def _load_local(self, difficulty: Difficulty) -> Dict[str, Any]:
        """
        Load from local hardcoded dataset.

        Tags the returned task's metadata with ``source="local"``.

        Raises:
            TaskLoadError: If the loader raises or the task has no usable
                "metadata" mapping; the underlying exception is chained.
        """
        try:
            task = get_hardcoded_task(difficulty)
            task["metadata"]["source"] = "local"
            logger.info("Loaded task from local dataset: %s", task.get("problem_id"))
            return task
        except Exception as e:
            error_msg = f"Failed to load from local dataset: {e}"
            logger.warning(error_msg)
            raise TaskLoadError(error_msg) from e

    def _load_swebench(self, difficulty: Difficulty) -> Dict[str, Any]:
        """
        Load from SWE-bench Lite dataset.

        Tags the returned task's metadata with ``source="swebench"``.

        Raises:
            TaskLoadError: If the adapter raises or the task has no usable
                "metadata" mapping; the underlying exception is chained.
        """
        try:
            task = get_swebench_task(difficulty)
            task["metadata"]["source"] = "swebench"
            logger.info(
                "Loaded task from SWE-bench: %s (repo: %s)",
                task.get("problem_id"),
                task["metadata"].get("repo", "?"),
            )
            return task
        except Exception as e:
            error_msg = f"Failed to load from SWE-bench: {e}"
            # Logged at debug level only: in auto mode this failure is an
            # expected trigger for the local fallback.
            logger.debug(error_msg)
            raise TaskLoadError(error_msg) from e


# Process-wide default manager, created lazily by get_task_manager().
_default_manager: Optional[TaskManager] = None


def get_task_manager(source: Optional[TaskSource] = None) -> TaskManager:
    """
    Return the shared default TaskManager, building it when required.

    A new manager is built (and stored as the new default) when none exists
    yet or whenever an explicit ``source`` is passed; otherwise the cached
    instance is returned unchanged.

    Args:
        source: Source override. When None, the manager resolves its source
            from the TASK_SOURCE env var.

    Returns:
        The current default TaskManager instance.
    """
    global _default_manager

    # Fast path: nothing requested and a cached manager is already available.
    if source is None and _default_manager is not None:
        return _default_manager

    _default_manager = TaskManager(source=source)
    return _default_manager


def load_task(difficulty: Difficulty, source: Optional[TaskSource] = None) -> Dict[str, Any]:
    """
    One-call shortcut: obtain the task manager and load a task from it.

    Args:
        difficulty: "easy", "medium", or "hard"
        source: Optional override for task source; forwarded to
            get_task_manager().

    Returns:
        Task dict as returned by TaskManager.load_task().
    """
    return get_task_manager(source=source).load_task(difficulty)