| from typing import Dict, Any, List |
| from environment.models import CodeContext, TaskMetadata |
|
|
|
|
class TaskDefinitions:
    """Static catalogue of code-review tasks, grouped into three difficulty tiers.

    Every task record is a plain ``dict`` with these keys:

    * ``task_id`` / ``task_name`` / ``description`` -- identification and prompt text.
    * ``code_diff`` -- the snippet under review.
    * ``surrounding_code`` -- the same snippet embedded in a larger file excerpt.
    * ``file_path`` / ``language`` / ``line_count`` -- values copied into the
      ``CodeContext`` built by :meth:`create_code_context`.
    * ``expected_issues`` -- list of dicts with ``line``, ``type``, ``severity``
      and ``description``.  In every record below, ``line`` is the 1-based line
      of ``code_diff`` that holds the flagged statement.

    NOTE(review): ``line_count`` equals the number of lines in ``code_diff`` for
    some records (e.g. ``bug_detection_easy_1``) but is one less for others
    (e.g. ``bug_detection_easy_2``) -- confirm the intended definition before
    relying on it.

    The class is never instantiated; everything is exposed through class
    attributes and classmethods.
    """

    # Difficulty-level shorthand ids accepted by ``get_task`` and the canonical
    # ``task_id`` each one resolves to.
    TASK_ALIASES: Dict[str, str] = {
        "bug_detection_easy": "bug_detection_easy_1",
        "bug_detection_medium": "memory_leak_medium_1",
        "bug_detection_hard": "security_hard_1",
    }

    EASY_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "bug_detection_easy_1",
            "task_name": "Division by Zero",
            "description": "Find the division by zero vulnerability in the calculate_average function",
            "code_diff": """def calculate_average(numbers):
    total = sum(numbers)
    return total / len(numbers)""",
            "surrounding_code": """class StatisticsCalculator:
    def __init__(self):
        self.results = []

    def calculate_average(self, numbers):
        total = sum(numbers)
        return total / len(numbers)

    def add_result(self, value):
        self.results.append(value)""",
            "file_path": "statistics.py",
            "language": "python",
            "line_count": 3,
            "expected_issues": [
                {
                    "line": 3,
                    "type": "division_by_zero",
                    "severity": "high",
                    "description": "No check for empty list before division",
                },
            ],
        },
        {
            "task_id": "bug_detection_easy_2",
            "task_name": "Off-by-One Error",
            "description": "Find the off-by-one error in the array iteration",
            "code_diff": """def process_items(items):
    for i in range(len(items)):
        item = items[i]
        next_item = items[i + 1]
        process_pair(item, next_item)""",
            "surrounding_code": """def process_items(items):
    for i in range(len(items)):
        item = items[i]
        next_item = items[i + 1]
        process_pair(item, next_item)
    return True""",
            "file_path": "processor.py",
            "language": "python",
            "line_count": 4,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "index_error",
                    "severity": "medium",
                    "description": "Index out of bounds when i is the last element",
                },
            ],
        },
    ]

    MEDIUM_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "memory_leak_medium_1",
            "task_name": "File Handle Leak",
            "description": "Find the memory leak where file handles are not properly closed",
            "code_diff": """def read_files(file_list):
    contents = []
    for filename in file_list:
        f = open(filename, 'r')
        data = f.read()
        contents.append(data)
    return contents""",
            "surrounding_code": """import os

def read_files(file_list):
    contents = []
    for filename in file_list:
        f = open(filename, 'r')
        data = f.read()
        contents.append(data)
    return contents

def write_output(data, filename):
    with open(filename, 'w') as f:
        f.write(data)""",
            "file_path": "file_handler.py",
            "language": "python",
            "line_count": 6,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "resource_leak",
                    "severity": "high",
                    "description": "File not closed after reading",
                },
            ],
        },
        {
            "task_id": "performance_medium_2",
            "task_name": "Inefficient String Concatenation",
            "description": "Find the performance issue with string concatenation in a loop",
            "code_diff": """def build_string(items):
    result = ""
    for item in items:
        result = result + item + ","
    return result[:-1]""",
            "surrounding_code": """def build_string(items):
    result = ""
    for item in items:
        result = result + item + ","
    return result[:-1]

def format_output(data):
    return build_string(data)""",
            "file_path": "string_builder.py",
            "language": "python",
            "line_count": 4,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "performance",
                    "severity": "medium",
                    "description": "Inefficient string concatenation in loop",
                },
            ],
        },
    ]

    HARD_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "security_hard_1",
            "task_name": "SQL Injection Vulnerability",
            "description": "Find the SQL injection vulnerability in the database query",
            "code_diff": """def get_user_data(user_id):
    query = f"SELECT * FROM users WHERE id = {user_id}"
    return database.execute(query)""",
            "surrounding_code": """import database

def get_user_data(user_id):
    query = f"SELECT * FROM users WHERE id = {user_id}"
    return database.execute(query)

def get_all_users():
    return database.execute("SELECT * FROM users")""",
            "file_path": "user_repository.py",
            "language": "python",
            "line_count": 3,
            "expected_issues": [
                {
                    "line": 2,
                    "type": "sql_injection",
                    "severity": "critical",
                    "description": "SQL injection vulnerability from string interpolation",
                },
            ],
        },
        {
            "task_id": "race_condition_hard_2",
            "task_name": "Race Condition",
            "description": "Find the race condition in the thread-safe counter",
            "code_diff": """class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        current = self.count
        self.count = current + 1
        return self.count""",
            "surrounding_code": """import threading

class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        current = self.count
        self.count = current + 1
        return self.count

    def get_count(self):
        return self.count""",
            "file_path": "counter.py",
            "language": "python",
            "line_count": 7,
            "expected_issues": [
                {
                    "line": 6,
                    "type": "race_condition",
                    "severity": "high",
                    "description": "Non-atomic increment operation",
                },
            ],
        },
    ]

    @classmethod
    def _tiers(cls) -> Dict[str, List[Dict[str, Any]]]:
        """Map each difficulty name to its task list, in easy/medium/hard order."""
        return {
            "easy": cls.EASY_TASKS,
            "medium": cls.MEDIUM_TASKS,
            "hard": cls.HARD_TASKS,
        }

    @classmethod
    def get_task(cls, task_id: str) -> Dict[str, Any]:
        """Return the task record for ``task_id``.

        ``task_id`` may be a canonical id or one of the ``TASK_ALIASES`` keys.
        The returned dict is the catalogue's own object, not a copy.

        An id that matches nothing does NOT raise: the first easy task is
        returned instead.  That fallback is kept for backward compatibility,
        so callers that must detect unknown ids should compare the returned
        record's ``"task_id"`` with what they asked for.
        """
        canonical_task_id = cls.TASK_ALIASES.get(task_id, task_id)
        for task in cls.get_all_tasks():
            if task["task_id"] == canonical_task_id:
                return task
        return cls.EASY_TASKS[0]

    @classmethod
    def get_all_tasks(cls) -> List[Dict[str, Any]]:
        """Return a new list holding every task, ordered easy, medium, hard."""
        return cls.EASY_TASKS + cls.MEDIUM_TASKS + cls.HARD_TASKS

    @classmethod
    def get_tasks_by_difficulty(cls, difficulty: str) -> List[Dict[str, Any]]:
        """Return the tasks of one tier (``"easy"``, ``"medium"`` or ``"hard"``).

        Any other value yields an empty list.  The result is a fresh list, as
        with :meth:`get_all_tasks`, so a caller that reorders or extends it
        cannot alter the class-level catalogue.  The task dicts inside are
        still shared.
        """
        return list(cls._tiers().get(difficulty, []))

    @classmethod
    def get_task_difficulty(cls, task_id: str) -> str:
        """Return ``"easy"``, ``"medium"`` or ``"hard"`` for ``task_id``.

        The tier list that actually contains the task (after alias resolution)
        decides the answer.  Only for ids that are not in the catalogue does
        this fall back to looking for ``"medium"`` and then ``"hard"`` inside
        the id, defaulting to ``"easy"``.
        """
        canonical_task_id = cls.TASK_ALIASES.get(task_id, task_id)
        for difficulty, tasks in cls._tiers().items():
            if any(task["task_id"] == canonical_task_id for task in tasks):
                return difficulty

        # Substring heuristic for ad-hoc records that were never registered.
        if "medium" in task_id:
            return "medium"
        if "hard" in task_id:
            return "hard"
        return "easy"

    # The return annotations below are quoted so that they are not evaluated
    # while the class body executes.
    @classmethod
    def create_code_context(cls, task_data: Dict[str, Any]) -> "CodeContext":
        """Build a ``CodeContext`` from a task record.

        ``file_extension`` is whatever follows the last ``"."`` of
        ``file_path``; a path without any dot therefore yields the whole path.

        Raises:
            KeyError: if ``task_data`` lacks ``file_path``, ``code_diff``,
                ``surrounding_code``, ``language`` or ``line_count``.
        """
        file_path = task_data["file_path"]
        return CodeContext(
            file_path=file_path,
            file_extension=file_path.split(".")[-1],
            code_diff=task_data["code_diff"],
            surrounding_code=task_data["surrounding_code"],
            language=task_data["language"],
            line_count=task_data["line_count"],
        )

    @classmethod
    def create_task_metadata(cls, task_data: Dict[str, Any]) -> "TaskMetadata":
        """Build a ``TaskMetadata`` from a task record.

        The difficulty comes from :meth:`get_task_difficulty`.
        ``expected_issues`` defaults to an empty list when the key is absent.

        Raises:
            KeyError: if ``task_data`` lacks ``task_id``, ``task_name`` or
                ``description``.
        """
        task_id = task_data["task_id"]
        return TaskMetadata(
            task_id=task_id,
            task_name=task_data["task_name"],
            difficulty=cls.get_task_difficulty(task_id),
            description=task_data["description"],
            expected_issues=task_data.get("expected_issues", []),
        )