import logging
import os
from typing import Dict, Any, List

from environment.models import CodeContext, TaskMetadata

logger = logging.getLogger(__name__)


class TaskDefinitions:
    """Static catalogue of code-review tasks, grouped by difficulty.

    Each task is a plain dict with the keys ``task_id``, ``task_name``,
    ``description``, ``code_diff``, ``surrounding_code``, ``file_path``,
    ``language``, ``line_count`` and ``expected_issues``.  The ``line``
    of every expected issue is a 1-based line number inside ``code_diff``.

    The task dicts are shared class-level objects; callers should treat
    them as read-only.
    """

    # Short, difficulty-only names mapped to the canonical task they select.
    TASK_ALIASES: Dict[str, str] = {
        "bug_detection_easy": "bug_detection_easy_1",
        "bug_detection_medium": "memory_leak_medium_1",
        "bug_detection_hard": "security_hard_1",
    }

    EASY_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "bug_detection_easy_1",
            "task_name": "Division by Zero",
            "description": "Find the division by zero vulnerability in the calculate_average function",
            "code_diff": """def calculate_average(numbers):
    total = sum(numbers)
    return total / len(numbers)""",
            "surrounding_code": """class StatisticsCalculator:
    def __init__(self):
        self.results = []

    def calculate_average(self, numbers):
        total = sum(numbers)
        return total / len(numbers)

    def add_result(self, value):
        self.results.append(value)""",
            "file_path": "statistics.py",
            "language": "python",
            "line_count": 3,
            "expected_issues": [
                {
                    "line": 3,
                    "type": "division_by_zero",
                    "severity": "high",
                    "description": "No check for empty list before division",
                }
            ],
        },
        {
            "task_id": "bug_detection_easy_2",
            "task_name": "Off-by-One Error",
            "description": "Find the off-by-one error in the array iteration",
            "code_diff": """def process_items(items):
    for i in range(len(items)):
        item = items[i]
        next_item = items[i + 1]
        process_pair(item, next_item)""",
            "surrounding_code": """def process_items(items):
    for i in range(len(items)):
        item = items[i]
        next_item = items[i + 1]
        process_pair(item, next_item)
    return True""",
            "file_path": "processor.py",
            "language": "python",
            # code_diff spans 5 lines (was declared as 4).
            "line_count": 5,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "index_error",
                    "severity": "medium",
                    "description": "Index out of bounds when i is the last element",
                }
            ],
        },
    ]

    MEDIUM_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "memory_leak_medium_1",
            "task_name": "File Handle Leak",
            "description": "Find the memory leak where file handles are not properly closed",
            "code_diff": """def read_files(file_list):
    contents = []
    for filename in file_list:
        f = open(filename, 'r')
        data = f.read()
        contents.append(data)
    return contents""",
            "surrounding_code": """import os

def read_files(file_list):
    contents = []
    for filename in file_list:
        f = open(filename, 'r')
        data = f.read()
        contents.append(data)
    return contents

def write_output(data, filename):
    with open(filename, 'w') as f:
        f.write(data)""",
            "file_path": "file_handler.py",
            "language": "python",
            # code_diff spans 7 lines (was declared as 6).
            "line_count": 7,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "resource_leak",
                    "severity": "high",
                    "description": "File not closed after reading",
                }
            ],
        },
        {
            "task_id": "performance_medium_2",
            "task_name": "Inefficient String Concatenation",
            "description": "Find the performance issue with string concatenation in a loop",
            "code_diff": """def build_string(items):
    result = ""
    for item in items:
        result = result + item + ","
    return result[:-1]""",
            "surrounding_code": """def build_string(items):
    result = ""
    for item in items:
        result = result + item + ","
    return result[:-1]

def format_output(data):
    return build_string(data)""",
            "file_path": "string_builder.py",
            "language": "python",
            # code_diff spans 5 lines (was declared as 4).
            "line_count": 5,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "performance",
                    "severity": "medium",
                    "description": "Inefficient string concatenation in loop",
                }
            ],
        },
    ]

    HARD_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "security_hard_1",
            "task_name": "SQL Injection Vulnerability",
            "description": "Find the SQL injection vulnerability in the database query",
            "code_diff": """def get_user_data(user_id):
    query = f"SELECT * FROM users WHERE id = {user_id}"
    return database.execute(query)""",
            "surrounding_code": """import database

def get_user_data(user_id):
    query = f"SELECT * FROM users WHERE id = {user_id}"
    return database.execute(query)

def get_all_users():
    return database.execute("SELECT * FROM users")""",
            "file_path": "user_repository.py",
            "language": "python",
            "line_count": 3,
            "expected_issues": [
                {
                    "line": 2,
                    "type": "sql_injection",
                    "severity": "critical",
                    "description": "SQL injection vulnerability from string interpolation",
                }
            ],
        },
        {
            "task_id": "race_condition_hard_2",
            "task_name": "Race Condition",
            "description": "Find the race condition in the thread-safe counter",
            "code_diff": """class Counter:
    def __init__(self):
        self.count = 0
    def increment(self):
        current = self.count
        self.count = current + 1
        return self.count""",
            "surrounding_code": """import threading

class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        current = self.count
        self.count = current + 1
        return self.count

    def get_count(self):
        return self.count""",
            "file_path": "counter.py",
            "language": "python",
            "line_count": 7,
            "expected_issues": [
                {
                    "line": 6,
                    "type": "race_condition",
                    "severity": "high",
                    "description": "Non-atomic increment operation",
                }
            ],
        },
    ]

    @classmethod
    def get_task(cls, task_id: str) -> Dict[str, Any]:
        """Return the task registered under ``task_id`` (or its alias).

        Args:
            task_id: A canonical task id or a key of ``TASK_ALIASES``.

        Returns:
            The shared task dict.  If no task matches, the first easy task
            is returned as a default and a warning is logged, so that a
            mistyped id does not go unnoticed.
        """
        canonical_task_id = cls.TASK_ALIASES.get(task_id, task_id)
        for task in cls.get_all_tasks():
            if task["task_id"] == canonical_task_id:
                return task

        default_task = cls.EASY_TASKS[0]
        logger.warning(
            "Unknown task_id %r; falling back to default task %r",
            task_id,
            default_task["task_id"],
        )
        return default_task

    @classmethod
    def get_all_tasks(cls) -> List[Dict[str, Any]]:
        """Return a new list with every task, ordered easy, medium, hard."""
        return cls.EASY_TASKS + cls.MEDIUM_TASKS + cls.HARD_TASKS

    @classmethod
    def get_tasks_by_difficulty(cls, difficulty: str) -> List[Dict[str, Any]]:
        """Return the tasks of one difficulty tier.

        Args:
            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``.

        Returns:
            A new list of the matching task dicts, or an empty list for
            any other value.  The list is a shallow copy so that callers
            can reorder or filter it without altering the class constants.
        """
        tiers = {
            "easy": cls.EASY_TASKS,
            "medium": cls.MEDIUM_TASKS,
            "hard": cls.HARD_TASKS,
        }
        return list(tiers.get(difficulty, []))

    @classmethod
    def create_code_context(cls, task_data: Dict[str, Any]) -> CodeContext:
        """Build the ``CodeContext`` describing a task's code snippet.

        Args:
            task_data: A task dict providing ``file_path``, ``code_diff``,
                ``surrounding_code``, ``language`` and ``line_count``.

        Returns:
            A ``CodeContext`` whose ``file_extension`` is the suffix of
            ``file_path`` without the leading dot (empty when the file
            name has no suffix).
        """
        file_path = task_data["file_path"]
        # splitext only inspects the final path component, so dots in
        # directory names or extension-less files are handled correctly.
        file_extension = os.path.splitext(file_path)[1].lstrip(".")
        return CodeContext(
            file_path=file_path,
            file_extension=file_extension,
            code_diff=task_data["code_diff"],
            surrounding_code=task_data["surrounding_code"],
            language=task_data["language"],
            line_count=task_data["line_count"],
        )

    @classmethod
    def _difficulty_for(cls, task_id: str) -> str:
        """Return the difficulty tier name for ``task_id``."""
        tiers = (
            ("easy", cls.EASY_TASKS),
            ("medium", cls.MEDIUM_TASKS),
            ("hard", cls.HARD_TASKS),
        )
        # Registered tasks: the list that holds the task is authoritative.
        for difficulty, tasks in tiers:
            if any(task["task_id"] == task_id for task in tasks):
                return difficulty

        # Unregistered tasks: keep the original id-substring heuristic.
        if "medium" in task_id:
            return "medium"
        if "hard" in task_id:
            return "hard"
        return "easy"

    @classmethod
    def create_task_metadata(cls, task_data: Dict[str, Any]) -> TaskMetadata:
        """Build the ``TaskMetadata`` for a task dict.

        Args:
            task_data: A task dict providing ``task_id``, ``task_name``
                and ``description``; ``expected_issues`` is optional and
                defaults to an empty list.

        Returns:
            A ``TaskMetadata`` whose difficulty is taken from the tier the
            task is registered in, or inferred from the task id when the
            task is not part of this catalogue.
        """
        task_id = task_data["task_id"]
        return TaskMetadata(
            task_id=task_id,
            task_name=task_data["task_name"],
            difficulty=cls._difficulty_for(task_id),
            description=task_data["description"],
            expected_issues=task_data.get("expected_issues", []),
        )