# Shardul Dhekane
# Normalize provider-agnostic API config and add Space server endpoints
# edaad73
from typing import Dict, Any, List
from environment.models import CodeContext, TaskMetadata
class TaskDefinitions:
    """Static catalogue of code-review tasks, grouped by difficulty.

    Every task is a plain dict with the keys ``task_id``, ``task_name``,
    ``description``, ``code_diff``, ``surrounding_code``, ``file_path``,
    ``language``, ``line_count`` and ``expected_issues``.  Each entry of
    ``expected_issues`` is a dict with ``line``, ``type``, ``severity`` and
    ``description``; ``line`` appears to be a 1-based line number within
    ``code_diff`` (it matches the offending statement in every task below).

    The lookup helpers return the dicts and lists stored on the class
    itself, not copies, so callers should treat them as read-only.
    """

    # Short, difficulty-only names accepted by ``get_task`` and the
    # canonical task id each one resolves to.
    TASK_ALIASES: Dict[str, str] = {
        "bug_detection_easy": "bug_detection_easy_1",
        "bug_detection_medium": "memory_leak_medium_1",
        "bug_detection_hard": "security_hard_1",
    }

    EASY_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "bug_detection_easy_1",
            "task_name": "Division by Zero",
            "description": "Find the division by zero vulnerability in the calculate_average function",
            "code_diff": """def calculate_average(numbers):
    total = sum(numbers)
    return total / len(numbers)""",
            "surrounding_code": """class StatisticsCalculator:
    def __init__(self):
        self.results = []
    def calculate_average(self, numbers):
        total = sum(numbers)
        return total / len(numbers)
    def add_result(self, value):
        self.results.append(value)""",
            "file_path": "statistics.py",
            "language": "python",
            "line_count": 3,
            "expected_issues": [
                {
                    "line": 3,
                    "type": "division_by_zero",
                    "severity": "high",
                    "description": "No check for empty list before division",
                },
            ],
        },
        {
            "task_id": "bug_detection_easy_2",
            "task_name": "Off-by-One Error",
            "description": "Find the off-by-one error in the array iteration",
            "code_diff": """def process_items(items):
    for i in range(len(items)):
        item = items[i]
        next_item = items[i + 1]
        process_pair(item, next_item)""",
            "surrounding_code": """def process_items(items):
    for i in range(len(items)):
        item = items[i]
        next_item = items[i + 1]
        process_pair(item, next_item)
    return True""",
            "file_path": "processor.py",
            "language": "python",
            # NOTE(review): code_diff above spans 5 lines; confirm whether
            # 4 is the intended value.
            "line_count": 4,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "index_error",
                    "severity": "medium",
                    "description": "Index out of bounds when i is the last element",
                },
            ],
        },
    ]

    MEDIUM_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "memory_leak_medium_1",
            "task_name": "File Handle Leak",
            "description": "Find the memory leak where file handles are not properly closed",
            "code_diff": """def read_files(file_list):
    contents = []
    for filename in file_list:
        f = open(filename, 'r')
        data = f.read()
        contents.append(data)
    return contents""",
            "surrounding_code": """import os
def read_files(file_list):
    contents = []
    for filename in file_list:
        f = open(filename, 'r')
        data = f.read()
        contents.append(data)
    return contents
def write_output(data, filename):
    with open(filename, 'w') as f:
        f.write(data)""",
            "file_path": "file_handler.py",
            "language": "python",
            # NOTE(review): code_diff above spans 7 lines; confirm whether
            # 6 is the intended value.
            "line_count": 6,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "resource_leak",
                    "severity": "high",
                    "description": "File not closed after reading",
                },
            ],
        },
        {
            "task_id": "performance_medium_2",
            "task_name": "Inefficient String Concatenation",
            "description": "Find the performance issue with string concatenation in a loop",
            "code_diff": """def build_string(items):
    result = ""
    for item in items:
        result = result + item + ","
    return result[:-1]""",
            "surrounding_code": """def build_string(items):
    result = ""
    for item in items:
        result = result + item + ","
    return result[:-1]
def format_output(data):
    return build_string(data)""",
            "file_path": "string_builder.py",
            "language": "python",
            # NOTE(review): code_diff above spans 5 lines; confirm whether
            # 4 is the intended value.
            "line_count": 4,
            "expected_issues": [
                {
                    "line": 4,
                    "type": "performance",
                    "severity": "medium",
                    "description": "Inefficient string concatenation in loop",
                },
            ],
        },
    ]

    HARD_TASKS: List[Dict[str, Any]] = [
        {
            "task_id": "security_hard_1",
            "task_name": "SQL Injection Vulnerability",
            "description": "Find the SQL injection vulnerability in the database query",
            "code_diff": """def get_user_data(user_id):
    query = f"SELECT * FROM users WHERE id = {user_id}"
    return database.execute(query)""",
            "surrounding_code": """import database
def get_user_data(user_id):
    query = f"SELECT * FROM users WHERE id = {user_id}"
    return database.execute(query)
def get_all_users():
    return database.execute("SELECT * FROM users")""",
            "file_path": "user_repository.py",
            "language": "python",
            "line_count": 3,
            "expected_issues": [
                {
                    "line": 2,
                    "type": "sql_injection",
                    "severity": "critical",
                    "description": "SQL injection vulnerability from string interpolation",
                },
            ],
        },
        {
            "task_id": "race_condition_hard_2",
            "task_name": "Race Condition",
            "description": "Find the race condition in the thread-safe counter",
            "code_diff": """class Counter:
    def __init__(self):
        self.count = 0
    def increment(self):
        current = self.count
        self.count = current + 1
        return self.count""",
            "surrounding_code": """import threading
class Counter:
    def __init__(self):
        self.count = 0
    def increment(self):
        current = self.count
        self.count = current + 1
        return self.count
    def get_count(self):
        return self.count""",
            "file_path": "counter.py",
            "language": "python",
            "line_count": 7,
            "expected_issues": [
                {
                    "line": 6,
                    "type": "race_condition",
                    "severity": "high",
                    "description": "Non-atomic increment operation",
                },
            ],
        },
    ]

    @classmethod
    def get_task(cls, task_id: str) -> Dict[str, Any]:
        """Return the task dict for *task_id*.

        Args:
            task_id: A canonical task id or one of the keys of
                ``TASK_ALIASES``.

        Returns:
            The matching task dict.  An id that matches no task does NOT
            raise: the first easy task is returned as a fallback.
        """
        canonical_task_id = cls.TASK_ALIASES.get(task_id, task_id)
        for task in cls.get_all_tasks():
            if task["task_id"] == canonical_task_id:
                return task
        # Unknown id: fall back to the first easy task instead of failing.
        return cls.EASY_TASKS[0]

    @classmethod
    def get_all_tasks(cls) -> List[Dict[str, Any]]:
        """Return a new list of every task, ordered easy, medium, hard."""
        return cls.EASY_TASKS + cls.MEDIUM_TASKS + cls.HARD_TASKS

    @classmethod
    def get_tasks_by_difficulty(cls, difficulty: str) -> List[Dict[str, Any]]:
        """Return the task list for ``"easy"``, ``"medium"`` or ``"hard"``.

        The match is case-sensitive.  Any other value yields an empty list.
        For a known difficulty the class-level list itself is returned.
        """
        tasks_by_difficulty = {
            "easy": cls.EASY_TASKS,
            "medium": cls.MEDIUM_TASKS,
            "hard": cls.HARD_TASKS,
        }
        return tasks_by_difficulty.get(difficulty, [])

    @staticmethod
    def _file_extension(file_path: str) -> str:
        """Return the extension of *file_path* without its leading dot.

        A path whose final component has no extension yields ``""``.
        """
        # Function-scope import so the class does not depend on the
        # module's import block.
        import os

        return os.path.splitext(file_path)[1][1:]

    @classmethod
    def create_code_context(cls, task_data: Dict[str, Any]) -> "CodeContext":
        """Build a ``CodeContext`` from a task dict.

        Args:
            task_data: A task dict providing ``file_path``, ``code_diff``,
                ``surrounding_code``, ``language`` and ``line_count``.

        Returns:
            A ``CodeContext`` whose ``file_extension`` is derived from
            ``file_path``; all other fields are copied from *task_data*.

        Raises:
            KeyError: If a required key is missing from *task_data*.
        """
        file_path = task_data["file_path"]
        return CodeContext(
            file_path=file_path,
            file_extension=cls._file_extension(file_path),
            code_diff=task_data["code_diff"],
            surrounding_code=task_data["surrounding_code"],
            language=task_data["language"],
            line_count=task_data["line_count"],
        )

    @classmethod
    def create_task_metadata(cls, task_data: Dict[str, Any]) -> "TaskMetadata":
        """Build a ``TaskMetadata`` from a task dict.

        The difficulty is inferred from the task id: ``"medium"`` if the id
        contains that word, otherwise ``"hard"`` if it contains that word,
        otherwise ``"easy"``.

        Args:
            task_data: A task dict providing ``task_id``, ``task_name`` and
                ``description``; ``expected_issues`` is optional and
                defaults to an empty list.

        Raises:
            KeyError: If a required key is missing from *task_data*.
        """
        task_id = task_data["task_id"]
        difficulty = "easy"
        # Order matters: "medium" takes precedence over "hard".
        for level in ("medium", "hard"):
            if level in task_id:
                difficulty = level
                break
        return TaskMetadata(
            task_id=task_id,
            task_name=task_data["task_name"],
            difficulty=difficulty,
            description=task_data["description"],
            expected_issues=task_data.get("expected_issues", []),
        )