# rl_code_fix_env/dataset/generate_swebench_tasks.py
# Uploaded by Viraj0112 using huggingface_hub (commit 03a907a, verified)
"""
Generate synthetic SWE-bench style tasks for testing.
This creates tasks that mimic the SWE-bench format:
- instance_id/buggy.py - the buggy code
- instance_id/test.py - test file
- instance_id/metadata.json - metadata
Usage:
python -m dataset.generate_swebench_tasks [--count N]
"""
import argparse
import json
import random
from pathlib import Path
# Sample SWE-bench style problems.
#
# Each entry is one task:
#   instance_id - unique id in the upstream "repo__repo-NNNN" convention
#   repo        - the GitHub repository the task mimics
#   problem     - one-line problem statement shown to the fixer
#   buggy_code  - contents written verbatim to <task>/buggy.py (intentionally buggy)
#   test_code   - contents written verbatim to <task>/test.py (must be runnable;
#                 the fixer's patch to buggy.py is judged against these tests)
SWE_BENCH_PROBLEMS = [
    {
        "instance_id": "django__django-11098",
        "repo": "django/django",
        "problem": "Fix the user creation form validation error",
        "buggy_code": '''from django import forms
from django.contrib.auth.models import User
class UserCreationForm(forms.ModelForm):
    """Form for creating new users."""
    password1 = forms.CharField(widget=forms.PasswordInput)
    password2 = forms.CharField(widget=forms.PasswordInput)
    class Meta:
        model = User
        fields = ('username', 'email')
    def clean(self):
        cleaned_data = super().clean()
        password1 = cleaned_data.get('password1')
        password2 = cleaned_data.get('password2')
        # BUG: This comparison is case-sensitive but should be case-insensitive
        if password1 != password2:
            raise forms.ValidationError("Passwords don't match")
        return cleaned_data
    def save(self, commit=True):
        user = super().save(commit=False)
        user.set_password(self.cleaned_data['password1'])
        if commit:
            user.save()
        return user
''',
        "test_code": '''import unittest
from buggy import UserCreationForm
class TestUserCreationForm(unittest.TestCase):
    def test_password_matching(self):
        """Test that matching passwords pass validation."""
        form = UserCreationForm(data={
            'username': 'testuser',
            'email': 'test@example.com',
            'password1': 'TestPass123',
            'password2': 'TestPass123',
        })
        self.assertTrue(form.is_valid())
    def test_password_mismatch(self):
        """Test that mismatched passwords fail validation."""
        form = UserCreationForm(data={
            'username': 'testuser',
            'email': 'test@example.com',
            'password1': 'TestPass123',
            'password2': 'testpass123', # Different case
        })
        self.assertFalse(form.is_valid())
        self.assertIn('passwords', str(form.errors).lower())
''',
    },
    {
        "instance_id": "flask__flask-1048",
        "repo": "pallets/flask",
        "problem": "Fix JSON encoding for datetime objects",
        "buggy_code": '''import json
from datetime import datetime, date
class JSONEncoder(json.JSONEncoder):
    """Custom JSON encoder for Flask."""
    def default(self, obj):
        # BUG: Missing handling for datetime objects
        if isinstance(obj, date):
            return obj.isoformat()
        return super().default(obj)
def to_json(obj):
    """Convert object to JSON string."""
    return json.dumps(obj, cls=JSONEncoder)
''',
        # FIX: test_code previously imported only `datetime` while
        # test_encode_date uses `date`, so the test file itself raised
        # NameError and the task was unsolvable. Import both names.
        "test_code": '''import unittest
from datetime import datetime, date
from buggy import to_json
class TestJSONEncoding(unittest.TestCase):
    def test_encode_datetime(self):
        """Test that datetime objects are properly encoded."""
        dt = datetime(2024, 1, 15, 10, 30, 0)
        result = to_json({'timestamp': dt})
        self.assertIn('2024-01-15', result)
        self.assertIn('10:30:00', result)
    def test_encode_date(self):
        """Test that date objects are properly encoded."""
        d = date(2024, 1, 15)
        result = to_json({'date': d})
        self.assertIn('2024-01-15', result)
''',
    },
    {
        "instance_id": "requests__requests-2875",
        "repo": "psf/requests",
        "problem": "Fix cookie domain matching",
        "buggy_code": '''import re
from urllib.parse import urlparse
def match_cookie_domain(cookie_domain, request_domain):
    """Check if cookie domain matches request domain."""
    # BUG: Should handle leading dots differently
    # .example.com should match sub.example.com but not example.com
    cookie_domain = cookie_domain.lower()
    request_domain = request_domain.lower()
    if cookie_domain.startswith('.'):
        return request_domain.endswith(cookie_domain)
    return cookie_domain == request_domain
''',
        "test_code": '''import unittest
from buggy import match_cookie_domain
class TestCookieDomain(unittest.TestCase):
    def test_exact_match(self):
        """Test exact domain matching."""
        self.assertTrue(match_cookie_domain('example.com', 'example.com'))
    def test_subdomain_with_dot(self):
        """Test subdomain matching with leading dot."""
        # .example.com should match sub.example.com
        self.assertTrue(match_cookie_domain('.example.com', 'sub.example.com'))
        self.assertFalse(match_cookie_domain('.example.com', 'example.com'))
    def test_different_domains(self):
        """Test different domains don't match."""
        self.assertFalse(match_cookie_domain('example.com', 'other.com'))
''',
    },
    {
        "instance_id": "numpy__numpy-10825",
        "repo": "numpy/numpy",
        "problem": "Fix array concatenation edge case",
        "buggy_code": '''import numpy as np
def concatenate_arrays(*arrays):
    """Concatenate multiple arrays along axis 0."""
    if not arrays:
        return np.array([])
    # BUG: Should handle None arrays gracefully
    result = arrays[0]
    for arr in arrays[1:]:
        result = np.concatenate([result, arr])
    return result
''',
        "test_code": '''import unittest
import numpy as np
from buggy import concatenate_arrays
class TestArrayConcatenation(unittest.TestCase):
    def test_basic_concatenation(self):
        """Test basic array concatenation."""
        a = np.array([1, 2, 3])
        b = np.array([4, 5, 6])
        result = concatenate_arrays(a, b)
        np.testing.assert_array_equal(result, np.array([1, 2, 3, 4, 5, 6]))
    def test_empty_input(self):
        """Test empty input returns empty array."""
        result = concatenate_arrays()
        self.assertEqual(len(result), 0)
    def test_single_array(self):
        """Test single array passes through."""
        a = np.array([1, 2, 3])
        result = concatenate_arrays(a)
        np.testing.assert_array_equal(result, a)
''',
    },
    {
        "instance_id": "pandas__pandas-15230",
        "repo": "pandas-dev/pandas",
        "problem": "Fix DataFrame groupby aggregation",
        "buggy_code": '''import pandas as pd
def group_and_aggregate(df, group_col, agg_col, agg_func='mean'):
    """Group DataFrame and aggregate."""
    # BUG: Should handle non-numeric columns gracefully
    if agg_func == 'mean':
        return df.groupby(group_col)[agg_col].mean()
    elif agg_func == 'sum':
        return df.groupby(group_col)[agg_col].sum()
    elif agg_func == 'count':
        return df.groupby(group_col)[agg_col].count()
    else:
        raise ValueError(f"Unknown aggregation function: {agg_func}")
''',
        "test_code": '''import unittest
import pandas as pd
from buggy import group_and_aggregate
class TestGroupBy(unittest.TestCase):
    def test_mean_aggregation(self):
        """Test mean aggregation."""
        df = pd.DataFrame({
            'category': ['A', 'A', 'B', 'B'],
            'value': [1, 2, 3, 4]
        })
        result = group_and_aggregate(df, 'category', 'value', 'mean')
        self.assertEqual(result['A'], 1.5)
        self.assertEqual(result['B'], 3.5)
    def test_sum_aggregation(self):
        """Test sum aggregation."""
        df = pd.DataFrame({
            'category': ['A', 'A', 'B'],
            'value': [1, 2, 3]
        })
        result = group_and_aggregate(df, 'category', 'value', 'sum')
        self.assertEqual(result['A'], 3)
        self.assertEqual(result['B'], 3)
''',
    },
    {
        "instance_id": "scipy__scipy-1925",
        "repo": "scipy/scipy",
        "problem": "Fix signal filtering edge case",
        "buggy_code": '''import numpy as np
from scipy import signal
def apply_lowpass_filter(data, cutoff, fs, order=5):
    """Apply lowpass filter to data."""
    # BUG: Should validate cutoff frequency
    nyquist = fs / 2
    normalized_cutoff = cutoff / nyquist
    # BUG: Using invalid cutoff can cause filter design failure
    b, a = signal.butter(order, normalized_cutoff, btype='low')
    filtered = signal.filtfilt(b, a, data)
    return filtered
''',
        "test_code": '''import unittest
import numpy as np
from buggy import apply_lowpass_filter
class TestSignalFiltering(unittest.TestCase):
    def test_valid_filter(self):
        """Test filtering with valid parameters."""
        fs = 1000 # Sampling frequency
        cutoff = 100 # Cutoff frequency
        t = np.linspace(0, 1, fs)
        data = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 200 * t)
        result = apply_lowpass_filter(data, cutoff, fs)
        self.assertEqual(len(result), len(data))
        # Low frequency component should be preserved
        self.assertTrue(np.abs(result[100]) > 0.5)
    def test_invalid_cutoff(self):
        """Test that invalid cutoff raises error."""
        fs = 1000
        cutoff = 2000 # Above Nyquist frequency - should fail
        data = np.array([1, 2, 3, 4, 5])
        with self.assertRaises(ValueError):
            apply_lowpass_filter(data, cutoff, fs)
''',
    },
    {
        "instance_id": "sklearn__sklearn-12345",
        "repo": "scikit-learn/scikit-learn",
        "problem": "Fix cross-validation split",
        "buggy_code": '''import numpy as np
from sklearn.model_selection import KFold
def get_cv_splits(X, n_splits=5, shuffle=True, random_state=42):
    """Get cross-validation splits."""
    # BUG: random_state should be used for reproducibility
    kf = KFold(n_splits=n_splits, shuffle=shuffle)
    splits = []
    for train_idx, test_idx in kf.split(X):
        splits.append((train_idx, test_idx))
    return splits
''',
        "test_code": '''import unittest
import numpy as np
from buggy import get_cv_splits
class TestCVSplits(unittest.TestCase):
    def test_split_count(self):
        """Test that correct number of splits is generated."""
        X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
        splits = get_cv_splits(X, n_splits=3)
        self.assertEqual(len(splits), 3)
    def test_reproducibility(self):
        """Test that splits are reproducible with same random_state."""
        X = np.random.rand(100, 5)
        splits1 = get_cv_splits(X, n_splits=5, random_state=42)
        splits2 = get_cv_splits(X, n_splits=5, random_state=42)
        for (train1, test1), (train2, test2) in zip(splits1, splits2):
            np.testing.assert_array_equal(train1, train2)
            np.testing.assert_array_equal(test1, test2)
''',
    },
    {
        "instance_id": "pytest__pytest-7426",
        "repo": "pytest-dev/pytest",
        "problem": "Fix test collection order",
        "buggy_code": '''import os
import re
def collect_tests(directory, pattern='test_*.py'):
    """Collect test files from directory."""
    # BUG: Should sort files for consistent ordering
    test_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if re.match(pattern, file):
                test_files.append(os.path.join(root, file))
    return test_files
''',
        "test_code": '''import unittest
import os
import tempfile
from buggy import collect_tests
class TestCollection(unittest.TestCase):
    def test_collect_pattern(self):
        """Test that correct pattern is matched."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create test files
            open(os.path.join(tmpdir, 'test_a.py'), 'w').close()
            open(os.path.join(tmpdir, 'test_b.py'), 'w').close()
            open(os.path.join(tmpdir, 'not_a_test.py'), 'w').close()
            tests = collect_tests(tmpdir, 'test_*.py')
            self.assertEqual(len(tests), 2)
    def test_consistent_order(self):
        """Test that file order is consistent."""
        with tempfile.TemporaryDirectory() as tmpdir:
            for name in ['test_c.py', 'test_a.py', 'test_b.py']:
                open(os.path.join(tmpdir, name), 'w').close()
            tests1 = collect_tests(tmpdir)
            tests2 = collect_tests(tmpdir)
            self.assertEqual(tests1, tests2)
''',
    },
    {
        "instance_id": "transformers__transformers-12345",
        "repo": "huggingface/transformers",
        "problem": "Fix tokenization padding",
        "buggy_code": '''from typing import List
def tokenize_and_pad(tokenizer, texts: List[str], max_length: int = 512):
    """Tokenize texts and pad to max length."""
    # BUG: Should handle padding correctly
    encoded = tokenizer(
        texts,
        padding=True, # This pads to longest in batch, not max_length
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )
    return encoded
''',
        "test_code": '''import unittest
from buggy import tokenize_and_pad
class MockTokenizer:
    def __call__(self, texts, padding=True, truncation=True, max_length=512, return_tensors=None):
        # Simplified mock
        return {
            'input_ids': [[1, 2, 3]] if isinstance(texts, list) else [1, 2, 3],
            'attention_mask': [[1, 1, 1]] if isinstance(texts, list) else [1, 1, 1]
        }
class TestTokenization(unittest.TestCase):
    def test_single_text(self):
        """Test tokenizing single text."""
        tokenizer = MockTokenizer()
        result = tokenize_and_pad(tokenizer, ["hello world"])
        self.assertIn('input_ids', result)
    def test_max_length_respected(self):
        """Test that max_length is respected."""
        tokenizer = MockTokenizer()
        # Should not raise even with long text
        result = tokenize_and_pad(tokenizer, ["short"], max_length=10)
        self.assertIn('input_ids', result)
''',
    },
]
# Difficulty tiers: three problems each for easy and medium, the remainder hard.
DIFFICULTY_TASKS = {
    "easy": SWE_BENCH_PROBLEMS[0:3],
    "medium": SWE_BENCH_PROBLEMS[3:6],
    "hard": SWE_BENCH_PROBLEMS[6:],
}
def generate_tasks(output_dir: Path, count_per_difficulty: int = 3):
    """Generate SWE-bench style task directories on disk.

    For each difficulty tier in DIFFICULTY_TASKS, writes up to
    ``count_per_difficulty`` task directories under ``output_dir``, each
    containing:
      - ``buggy.py``       the intentionally buggy implementation
      - ``test.py``        the unit tests the fix must satisfy
      - ``metadata.json``  instance id, repo, problem statement, difficulty

    Args:
        output_dir: Root directory for the generated tasks (created if missing).
        count_per_difficulty: Maximum number of tasks to emit per tier.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    total_created = 0
    for difficulty, problems in DIFFICULTY_TASKS.items():
        for i, problem in enumerate(problems[:count_per_difficulty]):
            # Suffix with difficulty and index so ids stay unique even if the
            # same base problem appears in more than one tier.
            instance_id = f"{problem['instance_id']}_{difficulty}_{i}"
            instance_dir = output_dir / instance_id
            instance_dir.mkdir(parents=True, exist_ok=True)
            # Write buggy.py
            buggy_file = instance_dir / "buggy.py"
            buggy_file.write_text(problem["buggy_code"], encoding="utf-8")
            # Write test.py
            test_file = instance_dir / "test.py"
            test_file.write_text(problem["test_code"], encoding="utf-8")
            # Write metadata.json
            metadata = {
                "instance_id": instance_id,
                "repo": problem["repo"],
                "problem_statement": problem["problem"],
                "difficulty": difficulty,
            }
            metadata_file = instance_dir / "metadata.json"
            metadata_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
            total_created += 1
    print(f"Created {total_created} tasks in {output_dir}")
    print(f"Set environment variable: SWEBENCH_TASKS_ROOT={output_dir.absolute()}")
    # Plain string literal: the original used an f-string with no placeholders.
    print("Or run with: TASK_SOURCE=swebench python inference.py")
def main():
    """CLI entry point: parse options and generate the task set."""
    parser = argparse.ArgumentParser(description="Generate SWE-bench style tasks")
    parser.add_argument(
        "--count",
        type=int,
        default=3,
        help="Number of tasks per difficulty (default: 3)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: dataset/swebench_lite_tasks)",
    )
    args = parser.parse_args()
    # Default output lives next to this script when --output-dir is omitted.
    if args.output_dir:
        target = Path(args.output_dir)
    else:
        target = Path(__file__).parent / "swebench_lite_tasks"
    generate_tasks(target, args.count)


if __name__ == "__main__":
    main()