rl_code_fix_env / dataset /task_manager.py
Viraj0112's picture
Upload folder using huggingface_hub
03a907a verified
"""
Unified Task Manager: Abstractly load tasks from both local and SWE-bench datasets.
This module provides a single interface to load tasks from:
1. Local hardcoded dataset (dataset/problem_1, problem_10, etc.)
2. SWE-bench Lite (if available and configured)
Configuration via environment variables:
    TASK_SOURCE          "local" | "swebench" | "auto" (default: "auto")
    SWEBENCH_FALLBACK    "1", "true" or "yes" lets "auto" mode fall back to the
                         local dataset when SWE-bench fails (default: "1")
    SWEBENCH_TASKS_ROOT  Path to SWE-bench tasks directory
    SWEBENCH_INDEX       Preferred task index within difficulty band
"""
import os
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Literal
from rl_code_fix_env.dataset.loader import get_hardcoded_task
from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Accepted values for the task source selector (see TaskManager.__init__).
TaskSource = Literal["local", "swebench", "auto"]
# Accepted difficulty levels (see TaskManager.load_task).
Difficulty = Literal["easy", "medium", "hard"]
class TaskLoadError(Exception):
    """Signal that a task could not be loaded from the requested source(s)."""
class TaskManager:
    """
    Unified interface for loading tasks from any dataset.

    Dispatches to the local hardcoded dataset or to SWE-bench depending on
    ``source``. In "auto" mode SWE-bench is tried first and, when fallback is
    enabled, the local dataset is used if SWE-bench fails.

    Attributes:
        source: Normalized source selector ("local", "swebench" or "auto").
        enable_fallback: Whether "auto" mode may fall back to the local
            dataset; read from the SWEBENCH_FALLBACK env var at construction.
    """

    def __init__(self, source: Optional[TaskSource] = None):
        """
        Initialize TaskManager.

        Args:
            source: "local", "swebench", or "auto" (tries swebench first,
                falls back to local). If None or blank, reads the TASK_SOURCE
                env var; if that is unset or blank too, defaults to "auto".

        Raises:
            ValueError: If the resolved source is not a recognized value.
        """
        self.source = self._resolve_source(source)
        self.enable_fallback = (
            os.getenv("SWEBENCH_FALLBACK", "1").strip().lower() in {"1", "true", "yes"}
        )
        logger.info(
            "TaskManager initialized: source=%s, fallback_enabled=%s",
            self.source,
            self.enable_fallback,
        )

    @staticmethod
    def _resolve_source(source: Optional[str]) -> str:
        """Normalize the source selector, falling back to TASK_SOURCE, then "auto"."""
        explicit = (source or "").strip().lower()
        from_env = os.getenv("TASK_SOURCE", "").strip().lower()
        # A blank value (e.g. `TASK_SOURCE=` in a .env file) means "not
        # configured", so it falls through to the next candidate instead of
        # being rejected as invalid.
        resolved = explicit or from_env or "auto"
        if resolved not in {"local", "swebench", "auto"}:
            raise ValueError(
                f"Invalid TASK_SOURCE='{resolved}'. "
                f"Must be one of: local, swebench, auto"
            )
        return resolved

    @staticmethod
    def _tag_source(task: Dict[str, Any], source: str) -> Dict[str, Any]:
        """Record which dataset produced ``task`` in its metadata and return it."""
        metadata = task.get("metadata")
        if metadata is None:
            # A loader that omits metadata should not turn a successfully
            # loaded task into a load failure.
            metadata = task["metadata"] = {}
        metadata["source"] = source
        return task

    def load_task(self, difficulty: Difficulty) -> Dict[str, Any]:
        """
        Load a task by difficulty level.

        Args:
            difficulty: "easy", "medium", or "hard" (case and surrounding
                whitespace are ignored).

        Returns:
            Task dict with structure:
            {
                "code": str,         # buggy Python code
                "tests": str,        # path to test.py
                "metadata": dict,    # source, repo, problem_statement, etc.
                "problem_dir": str,  # directory containing buggy.py and test.py
                "problem_id": str,   # unique identifier for this task
            }

        Raises:
            ValueError: If ``difficulty`` is not a recognized level.
            TaskLoadError: If no task can be loaded from any source.
        """
        # str() keeps non-string input on the ValueError path below instead of
        # failing with AttributeError on .strip().
        level = str(difficulty or "").strip().lower()
        if level not in {"easy", "medium", "hard"}:
            raise ValueError(
                f"Invalid difficulty='{level}'. Must be one of: easy, medium, hard"
            )
        if self.source == "local":
            return self._load_local(level)
        if self.source == "swebench":
            return self._load_swebench(level)
        return self._load_auto(level)

    def _load_auto(self, difficulty: Difficulty) -> Dict[str, Any]:
        """Try SWE-bench first, then the local dataset if fallback is enabled."""
        logger.debug("Auto mode: trying SWE-bench first...")
        try:
            return self._load_swebench(difficulty)
        except Exception as exc:
            # Keep the exception object (not just its text) so it can be
            # chained as the cause below; `exc` itself is unbound when the
            # except block ends.
            swebench_error = exc
            logger.debug("SWE-bench failed: %s", exc)

        if not self.enable_fallback:
            raise TaskLoadError(
                f"SWE-bench loading failed and fallback disabled: {swebench_error}"
            ) from swebench_error

        logger.info("SWE-bench unavailable, falling back to local dataset")
        try:
            return self._load_local(difficulty)
        except Exception as local_error:
            raise TaskLoadError(
                f"Both SWE-bench and local fallback failed:\n"
                f" SWE-bench: {swebench_error}\n"
                f" Local: {local_error}"
            ) from local_error

    def _load_local(self, difficulty: Difficulty) -> Dict[str, Any]:
        """Load from local hardcoded dataset.

        Raises:
            TaskLoadError: If the local loader fails for any reason.
        """
        try:
            task = self._tag_source(get_hardcoded_task(difficulty), "local")
        except Exception as e:
            error_msg = f"Failed to load from local dataset: {e}"
            logger.warning(error_msg)
            raise TaskLoadError(error_msg) from e
        logger.info("Loaded task from local dataset: %s", task.get("problem_id"))
        return task

    def _load_swebench(self, difficulty: Difficulty) -> Dict[str, Any]:
        """Load from SWE-bench Lite dataset.

        Raises:
            TaskLoadError: If the SWE-bench adapter fails for any reason.
        """
        try:
            task = self._tag_source(get_swebench_task(difficulty), "swebench")
        except Exception as e:
            error_msg = f"Failed to load from SWE-bench: {e}"
            # Logged at debug level only: in "auto" mode this failure is
            # expected and is followed by the local fallback.
            logger.debug(error_msg)
            raise TaskLoadError(error_msg) from e
        logger.info(
            "Loaded task from SWE-bench: %s (repo: %s)",
            task.get("problem_id"),
            task["metadata"].get("repo", "?"),
        )
        return task
# Shared module-wide TaskManager, kept for backward compatibility; it is
# created on first use by get_task_manager().
_default_manager: Optional[TaskManager] = None


def get_task_manager(source: Optional[TaskSource] = None) -> TaskManager:
    """
    Return the shared TaskManager, building it when required.

    A new manager is built (and stored as the shared default) when none
    exists yet or whenever ``source`` is not None; otherwise the existing
    shared instance is returned unchanged.

    Args:
        source: Override the source selection. If None, the manager is
            configured from the TASK_SOURCE env var.

    Returns:
        TaskManager instance
    """
    global _default_manager
    needs_new_manager = source is not None or _default_manager is None
    if needs_new_manager:
        _default_manager = TaskManager(source=source)
    return _default_manager
def load_task(difficulty: Difficulty, source: Optional[TaskSource] = None) -> Dict[str, Any]:
    """
    One-call shortcut: obtain the shared manager and load a task from it.

    Args:
        difficulty: "easy", "medium", or "hard"
        source: Optional override for task source; forwarded to
            get_task_manager().

    Returns:
        Task dict
    """
    return get_task_manager(source=source).load_task(difficulty)