"""
Generate synthetic SWE-bench style tasks for testing.

This creates tasks that mimic the SWE-bench format:
- instance_id/buggy.py - the buggy code
- instance_id/test.py - test file
- instance_id/metadata.json - metadata

Usage:
    python -m dataset.generate_swebench_tasks [--count N] [--output-dir DIR]
"""

import argparse
import json
import random
import sys
from pathlib import Path

# Sample SWE-bench style problems.
# Each entry provides: instance_id, repo, a one-line problem statement,
# the intentionally buggy module (written to buggy.py) and a unittest
# module (written to test.py) that imports from `buggy`.
SWE_BENCH_PROBLEMS = [
    # NOTE(review): the "BUG" comment below claims the comparison should be
    # case-insensitive, but test_password_mismatch expects 'TestPass123' vs
    # 'testpass123' to be rejected, which the case-sensitive `!=` already does.
    {
        "instance_id": "django__django-11098",
        "repo": "django/django",
        "problem": "Fix the user creation form validation error",
        "buggy_code": '''from django import forms
from django.contrib.auth.models import User


class UserCreationForm(forms.ModelForm):
    """Form for creating new users."""

    password1 = forms.CharField(widget=forms.PasswordInput)
    password2 = forms.CharField(widget=forms.PasswordInput)

    class Meta:
        model = User
        fields = ('username', 'email')

    def clean(self):
        cleaned_data = super().clean()
        password1 = cleaned_data.get('password1')
        password2 = cleaned_data.get('password2')

        # BUG: This comparison is case-sensitive but should be case-insensitive
        if password1 != password2:
            raise forms.ValidationError("Passwords don't match")

        return cleaned_data

    def save(self, commit=True):
        user = super().save(commit=False)
        user.set_password(self.cleaned_data['password1'])
        if commit:
            user.save()
        return user
''',
        "test_code": '''import unittest
from buggy import UserCreationForm


class TestUserCreationForm(unittest.TestCase):
    def test_password_matching(self):
        """Test that matching passwords pass validation."""
        form = UserCreationForm(data={
            'username': 'testuser',
            'email': 'test@example.com',
            'password1': 'TestPass123',
            'password2': 'TestPass123',
        })
        self.assertTrue(form.is_valid())

    def test_password_mismatch(self):
        """Test that mismatched passwords fail validation."""
        form = UserCreationForm(data={
            'username': 'testuser',
            'email': 'test@example.com',
            'password1': 'TestPass123',
            'password2': 'testpass123',  # Different case
        })
        self.assertFalse(form.is_valid())
        self.assertIn('passwords', str(form.errors).lower())
''',
    },
    # The test module must import `date` as well as `datetime`; without it
    # test_encode_date fails with NameError regardless of buggy.py.
    # NOTE(review): datetime is a subclass of date, so the isinstance(obj, date)
    # check in buggy.py already handles datetime values.
    {
        "instance_id": "flask__flask-1048",
        "repo": "pallets/flask",
        "problem": "Fix JSON encoding for datetime objects",
        "buggy_code": '''import json
from datetime import datetime, date


class JSONEncoder(json.JSONEncoder):
    """Custom JSON encoder for Flask."""

    def default(self, obj):
        # BUG: Missing handling for datetime objects
        if isinstance(obj, date):
            return obj.isoformat()
        return super().default(obj)


def to_json(obj):
    """Convert object to JSON string."""
    return json.dumps(obj, cls=JSONEncoder)
''',
        "test_code": '''import unittest
from datetime import date, datetime
from buggy import to_json


class TestJSONEncoding(unittest.TestCase):
    def test_encode_datetime(self):
        """Test that datetime objects are properly encoded."""
        dt = datetime(2024, 1, 15, 10, 30, 0)
        result = to_json({'timestamp': dt})
        self.assertIn('2024-01-15', result)
        self.assertIn('10:30:00', result)

    def test_encode_date(self):
        """Test that date objects are properly encoded."""
        d = date(2024, 1, 15)
        result = to_json({'date': d})
        self.assertIn('2024-01-15', result)
''',
    },
    {
        "instance_id": "requests__requests-2875",
        "repo": "psf/requests",
        "problem": "Fix cookie domain matching",
        "buggy_code": '''import re
from urllib.parse import urlparse


def match_cookie_domain(cookie_domain, request_domain):
    """Check if cookie domain matches request domain."""
    # BUG: Should handle leading dots differently
    # .example.com should match sub.example.com but not example.com
    cookie_domain = cookie_domain.lower()
    request_domain = request_domain.lower()

    if cookie_domain.startswith('.'):
        return request_domain.endswith(cookie_domain)
    return cookie_domain == request_domain
''',
        "test_code": '''import unittest
from buggy import match_cookie_domain


class TestCookieDomain(unittest.TestCase):
    def test_exact_match(self):
        """Test exact domain matching."""
        self.assertTrue(match_cookie_domain('example.com', 'example.com'))

    def test_subdomain_with_dot(self):
        """Test subdomain matching with leading dot."""
        # .example.com should match sub.example.com
        self.assertTrue(match_cookie_domain('.example.com', 'sub.example.com'))
        self.assertFalse(match_cookie_domain('.example.com', 'example.com'))

    def test_different_domains(self):
        """Test different domains don't match."""
        self.assertFalse(match_cookie_domain('example.com', 'other.com'))
''',
    },
    {
        "instance_id": "numpy__numpy-10825",
        "repo": "numpy/numpy",
        "problem": "Fix array concatenation edge case",
        "buggy_code": '''import numpy as np


def concatenate_arrays(*arrays):
    """Concatenate multiple arrays along axis 0."""
    if not arrays:
        return np.array([])

    # BUG: Should handle None arrays gracefully
    result = arrays[0]
    for arr in arrays[1:]:
        result = np.concatenate([result, arr])
    return result
''',
        "test_code": '''import unittest
import numpy as np
from buggy import concatenate_arrays


class TestArrayConcatenation(unittest.TestCase):
    def test_basic_concatenation(self):
        """Test basic array concatenation."""
        a = np.array([1, 2, 3])
        b = np.array([4, 5, 6])
        result = concatenate_arrays(a, b)
        np.testing.assert_array_equal(result, np.array([1, 2, 3, 4, 5, 6]))

    def test_empty_input(self):
        """Test empty input returns empty array."""
        result = concatenate_arrays()
        self.assertEqual(len(result), 0)

    def test_single_array(self):
        """Test single array passes through."""
        a = np.array([1, 2, 3])
        result = concatenate_arrays(a)
        np.testing.assert_array_equal(result, a)
''',
    },
    {
        "instance_id": "pandas__pandas-15230",
        "repo": "pandas-dev/pandas",
        "problem": "Fix DataFrame groupby aggregation",
        "buggy_code": '''import pandas as pd


def group_and_aggregate(df, group_col, agg_col, agg_func='mean'):
    """Group DataFrame and aggregate."""
    # BUG: Should handle non-numeric columns gracefully
    if agg_func == 'mean':
        return df.groupby(group_col)[agg_col].mean()
    elif agg_func == 'sum':
        return df.groupby(group_col)[agg_col].sum()
    elif agg_func == 'count':
        return df.groupby(group_col)[agg_col].count()
    else:
        raise ValueError(f"Unknown aggregation function: {agg_func}")
''',
        "test_code": '''import unittest
import pandas as pd
from buggy import group_and_aggregate


class TestGroupBy(unittest.TestCase):
    def test_mean_aggregation(self):
        """Test mean aggregation."""
        df = pd.DataFrame({
            'category': ['A', 'A', 'B', 'B'],
            'value': [1, 2, 3, 4]
        })
        result = group_and_aggregate(df, 'category', 'value', 'mean')
        self.assertEqual(result['A'], 1.5)
        self.assertEqual(result['B'], 3.5)

    def test_sum_aggregation(self):
        """Test sum aggregation."""
        df = pd.DataFrame({
            'category': ['A', 'A', 'B'],
            'value': [1, 2, 3]
        })
        result = group_and_aggregate(df, 'category', 'value', 'sum')
        self.assertEqual(result['A'], 3)
        self.assertEqual(result['B'], 3)
''',
    },
    # test_valid_filter checks the peak of the filtered signal away from the
    # edges. The previous single-sample check at index 100 could never pass:
    # t[100] = 100/999, so the 50 Hz phase is 10.01*pi (a near zero crossing)
    # and |result[100]| is about 0.03.
    {
        "instance_id": "scipy__scipy-1925",
        "repo": "scipy/scipy",
        "problem": "Fix signal filtering edge case",
        "buggy_code": '''import numpy as np
from scipy import signal


def apply_lowpass_filter(data, cutoff, fs, order=5):
    """Apply lowpass filter to data."""
    # BUG: Should validate cutoff frequency
    nyquist = fs / 2
    normalized_cutoff = cutoff / nyquist

    # BUG: Using invalid cutoff can cause filter design failure
    b, a = signal.butter(order, normalized_cutoff, btype='low')
    filtered = signal.filtfilt(b, a, data)
    return filtered
''',
        "test_code": '''import unittest
import numpy as np
from buggy import apply_lowpass_filter


class TestSignalFiltering(unittest.TestCase):
    def test_valid_filter(self):
        """Test filtering with valid parameters."""
        fs = 1000  # Sampling frequency
        cutoff = 100  # Cutoff frequency
        t = np.linspace(0, 1, fs)
        data = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 200 * t)

        result = apply_lowpass_filter(data, cutoff, fs)
        self.assertEqual(len(result), len(data))
        # Low frequency (50 Hz, amplitude 1) component should be preserved;
        # check its peak away from the edges rather than a single sample.
        self.assertTrue(np.max(np.abs(result[100:-100])) > 0.5)

    def test_invalid_cutoff(self):
        """Test that invalid cutoff raises error."""
        fs = 1000
        cutoff = 2000  # Above Nyquist frequency - should fail
        data = np.array([1, 2, 3, 4, 5])

        with self.assertRaises(ValueError):
            apply_lowpass_filter(data, cutoff, fs)
''',
    },
    {
        "instance_id": "sklearn__sklearn-12345",
        "repo": "scikit-learn/scikit-learn",
        "problem": "Fix cross-validation split",
        "buggy_code": '''import numpy as np
from sklearn.model_selection import KFold


def get_cv_splits(X, n_splits=5, shuffle=True, random_state=42):
    """Get cross-validation splits."""
    # BUG: random_state should be used for reproducibility
    kf = KFold(n_splits=n_splits, shuffle=shuffle)

    splits = []
    for train_idx, test_idx in kf.split(X):
        splits.append((train_idx, test_idx))
    return splits
''',
        "test_code": '''import unittest
import numpy as np
from buggy import get_cv_splits


class TestCVSplits(unittest.TestCase):
    def test_split_count(self):
        """Test that correct number of splits is generated."""
        X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
        splits = get_cv_splits(X, n_splits=3)
        self.assertEqual(len(splits), 3)

    def test_reproducibility(self):
        """Test that splits are reproducible with same random_state."""
        X = np.random.rand(100, 5)
        splits1 = get_cv_splits(X, n_splits=5, random_state=42)
        splits2 = get_cv_splits(X, n_splits=5, random_state=42)

        for (train1, test1), (train2, test2) in zip(splits1, splits2):
            np.testing.assert_array_equal(train1, train2)
            np.testing.assert_array_equal(test1, test2)
''',
    },
    # NOTE(review): `pattern` is a glob but buggy.py passes it to re.match, so
    # 'test_*.py' does not match 'test_a.py'. test_collect_pattern therefore
    # also fails on the buggy code, independently of the ordering bug.
    {
        "instance_id": "pytest__pytest-7426",
        "repo": "pytest-dev/pytest",
        "problem": "Fix test collection order",
        "buggy_code": '''import os
import re


def collect_tests(directory, pattern='test_*.py'):
    """Collect test files from directory."""
    # BUG: Should sort files for consistent ordering
    test_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if re.match(pattern, file):
                test_files.append(os.path.join(root, file))
    return test_files
''',
        "test_code": '''import unittest
import os
import tempfile
from buggy import collect_tests


class TestCollection(unittest.TestCase):
    def test_collect_pattern(self):
        """Test that correct pattern is matched."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create test files
            open(os.path.join(tmpdir, 'test_a.py'), 'w').close()
            open(os.path.join(tmpdir, 'test_b.py'), 'w').close()
            open(os.path.join(tmpdir, 'not_a_test.py'), 'w').close()

            tests = collect_tests(tmpdir, 'test_*.py')
            self.assertEqual(len(tests), 2)

    def test_consistent_order(self):
        """Test that file order is consistent."""
        with tempfile.TemporaryDirectory() as tmpdir:
            for name in ['test_c.py', 'test_a.py', 'test_b.py']:
                open(os.path.join(tmpdir, name), 'w').close()

            tests1 = collect_tests(tmpdir)
            tests2 = collect_tests(tmpdir)
            self.assertEqual(tests1, tests2)
''',
    },
    {
        "instance_id": "transformers__transformers-12345",
        "repo": "huggingface/transformers",
        "problem": "Fix tokenization padding",
        "buggy_code": '''from typing import List


def tokenize_and_pad(tokenizer, texts: List[str], max_length: int = 512):
    """Tokenize texts and pad to max length."""
    # BUG: Should handle padding correctly
    encoded = tokenizer(
        texts,
        padding=True,  # This pads to longest in batch, not max_length
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )
    return encoded
''',
        "test_code": '''import unittest
from buggy import tokenize_and_pad


class MockTokenizer:
    def __call__(self, texts, padding=True, truncation=True, max_length=512, return_tensors=None):
        # Simplified mock
        return {
            'input_ids': [[1, 2, 3]] if isinstance(texts, list) else [1, 2, 3],
            'attention_mask': [[1, 1, 1]] if isinstance(texts, list) else [1, 1, 1]
        }


class TestTokenization(unittest.TestCase):
    def test_single_text(self):
        """Test tokenizing single text."""
        tokenizer = MockTokenizer()
        result = tokenize_and_pad(tokenizer, ["hello world"])
        self.assertIn('input_ids', result)

    def test_max_length_respected(self):
        """Test that max_length is respected."""
        tokenizer = MockTokenizer()
        # Should not raise even with long text
        result = tokenize_and_pad(tokenizer, ["short"], max_length=10)
        self.assertIn('input_ids', result)
''',
    },
]

# Easy, Medium, Hard difficulty assignments, by position in SWE_BENCH_PROBLEMS
# (3 problems each).
DIFFICULTY_TASKS = {
    "easy": SWE_BENCH_PROBLEMS[:3],
    "medium": SWE_BENCH_PROBLEMS[3:6],
    "hard": SWE_BENCH_PROBLEMS[6:],
}


def _write_task(instance_dir: Path, instance_id: str, difficulty: str, problem: dict) -> None:
    """Write buggy.py, test.py and metadata.json for one task into *instance_dir*."""
    instance_dir.mkdir(parents=True, exist_ok=True)

    (instance_dir / "buggy.py").write_text(problem["buggy_code"], encoding="utf-8")
    (instance_dir / "test.py").write_text(problem["test_code"], encoding="utf-8")

    metadata = {
        "instance_id": instance_id,
        "repo": problem["repo"],
        "problem_statement": problem["problem"],
        "difficulty": difficulty,
    }
    (instance_dir / "metadata.json").write_text(
        json.dumps(metadata, indent=2), encoding="utf-8"
    )


def generate_tasks(output_dir: Path, count_per_difficulty: int = 3) -> int:
    """Generate SWE-bench style task directories under *output_dir*.

    For each difficulty in DIFFICULTY_TASKS, the first ``count_per_difficulty``
    problems are written to ``<output_dir>/<instance_id>_<difficulty>_<i>/``.

    Args:
        output_dir: Destination directory (created if missing); str is accepted.
        count_per_difficulty: Maximum number of tasks per difficulty. If it
            exceeds the number of available problems, all of them are written
            and a warning is printed to stderr.

    Returns:
        The total number of task directories written.

    Raises:
        ValueError: If ``count_per_difficulty`` is negative (a negative slice
            bound would otherwise silently drop problems from the end).
    """
    if count_per_difficulty < 0:
        raise ValueError(
            f"count_per_difficulty must be non-negative, got {count_per_difficulty}"
        )

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total_created = 0
    for difficulty, problems in DIFFICULTY_TASKS.items():
        if count_per_difficulty > len(problems):
            print(
                f"Warning: requested {count_per_difficulty} {difficulty} tasks, "
                f"but only {len(problems)} are available",
                file=sys.stderr,
            )
        for i, problem in enumerate(problems[:count_per_difficulty]):
            # Suffix keeps IDs unique even if the same problem is reused.
            instance_id = f"{problem['instance_id']}_{difficulty}_{i}"
            _write_task(output_dir / instance_id, instance_id, difficulty, problem)
            total_created += 1

    print(f"Created {total_created} tasks in {output_dir}")
    print(f"Set environment variable: SWEBENCH_TASKS_ROOT={output_dir.absolute()}")
    print("Or run with: TASK_SOURCE=swebench python inference.py")
    return total_created


def main():
    """Parse command-line arguments and generate the tasks."""
    parser = argparse.ArgumentParser(description="Generate SWE-bench style tasks")
    parser.add_argument(
        "--count",
        type=int,
        default=3,
        help="Number of tasks per difficulty (default: 3)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: dataset/swebench_lite_tasks)",
    )
    args = parser.parse_args()

    if args.count < 0:
        # parser.error prints usage and exits with status 2.
        parser.error("--count must be non-negative")

    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        # Default lives next to this script.
        script_dir = Path(__file__).parent
        output_dir = script_dir / "swebench_lite_tasks"

    generate_tasks(output_dir, args.count)


if __name__ == "__main__":
    main()