Spaces:
Running
Running
| """ | |
| Generate synthetic SWE-bench style tasks for testing. | |
| This creates tasks that mimic the SWE-bench format: | |
| - instance_id/buggy.py - the buggy code | |
| - instance_id/test.py - test file | |
| - instance_id/metadata.json - metadata | |
| Usage: | |
| python -m dataset.generate_swebench_tasks [--count N] | |
| """ | |
| import argparse | |
| import json | |
| import random | |
| from pathlib import Path | |
# Sample SWE-bench style problems.
# Each entry carries the buggy source the solver must repair ("buggy_code")
# and the unit tests that define success ("test_code"); both strings are
# written verbatim into the generated task directory by generate_tasks().
SWE_BENCH_PROBLEMS = [
    {
        "instance_id": "django__django-11098",
        "repo": "django/django",
        "problem": "Fix the user creation form validation error",
        "buggy_code": '''from django import forms
from django.contrib.auth.models import User


class UserCreationForm(forms.ModelForm):
    """Form for creating new users."""
    password1 = forms.CharField(widget=forms.PasswordInput)
    password2 = forms.CharField(widget=forms.PasswordInput)

    class Meta:
        model = User
        fields = ('username', 'email')

    def clean(self):
        cleaned_data = super().clean()
        password1 = cleaned_data.get('password1')
        password2 = cleaned_data.get('password2')
        # BUG: case-insensitive comparison accepts passwords that differ only in case
        if password1 and password2 and password1.lower() != password2.lower():
            raise forms.ValidationError("Passwords don't match")
        return cleaned_data

    def save(self, commit=True):
        user = super().save(commit=False)
        user.set_password(self.cleaned_data['password1'])
        if commit:
            user.save()
        return user
''',
        "test_code": '''import unittest

from buggy import UserCreationForm


class TestUserCreationForm(unittest.TestCase):
    def test_password_matching(self):
        """Test that matching passwords pass validation."""
        form = UserCreationForm(data={
            'username': 'testuser',
            'email': 'test@example.com',
            'password1': 'TestPass123',
            'password2': 'TestPass123',
        })
        self.assertTrue(form.is_valid())

    def test_password_mismatch(self):
        """Test that mismatched passwords fail validation."""
        form = UserCreationForm(data={
            'username': 'testuser',
            'email': 'test@example.com',
            'password1': 'TestPass123',
            'password2': 'testpass123',  # Different case
        })
        self.assertFalse(form.is_valid())
        self.assertIn('passwords', str(form.errors).lower())
''',
    },
    {
        "instance_id": "flask__flask-1048",
        "repo": "pallets/flask",
        "problem": "Fix JSON encoding for datetime objects",
        # NOTE: the check below is an exact-type test on purpose — datetime is a
        # subclass of date, so isinstance(obj, date) would already handle it and
        # the seeded "missing datetime handling" bug would never manifest.
        "buggy_code": '''import json
from datetime import datetime, date


class JSONEncoder(json.JSONEncoder):
    """Custom JSON encoder for Flask."""

    def default(self, obj):
        # BUG: Missing handling for datetime objects
        if type(obj) is date:
            return obj.isoformat()
        return super().default(obj)


def to_json(obj):
    """Convert object to JSON string."""
    return json.dumps(obj, cls=JSONEncoder)
''',
        # FIX: test previously imported only datetime but used date() — NameError.
        "test_code": '''import unittest
from datetime import datetime, date

from buggy import to_json


class TestJSONEncoding(unittest.TestCase):
    def test_encode_datetime(self):
        """Test that datetime objects are properly encoded."""
        dt = datetime(2024, 1, 15, 10, 30, 0)
        result = to_json({'timestamp': dt})
        self.assertIn('2024-01-15', result)
        self.assertIn('10:30:00', result)

    def test_encode_date(self):
        """Test that date objects are properly encoded."""
        d = date(2024, 1, 15)
        result = to_json({'date': d})
        self.assertIn('2024-01-15', result)
''',
    },
    {
        "instance_id": "requests__requests-2875",
        "repo": "psf/requests",
        "problem": "Fix cookie domain matching",
        "buggy_code": '''import re
from urllib.parse import urlparse


def match_cookie_domain(cookie_domain, request_domain):
    """Check if cookie domain matches request domain."""
    # BUG: Should handle leading dots differently
    # .example.com should match sub.example.com but not example.com
    cookie_domain = cookie_domain.lower()
    request_domain = request_domain.lower()
    if cookie_domain.startswith('.'):
        return request_domain.endswith(cookie_domain)
    return cookie_domain == request_domain
''',
        "test_code": '''import unittest

from buggy import match_cookie_domain


class TestCookieDomain(unittest.TestCase):
    def test_exact_match(self):
        """Test exact domain matching."""
        self.assertTrue(match_cookie_domain('example.com', 'example.com'))

    def test_subdomain_with_dot(self):
        """Test subdomain matching with leading dot."""
        # .example.com should match sub.example.com
        self.assertTrue(match_cookie_domain('.example.com', 'sub.example.com'))
        self.assertFalse(match_cookie_domain('.example.com', 'example.com'))

    def test_different_domains(self):
        """Test different domains don't match."""
        self.assertFalse(match_cookie_domain('example.com', 'other.com'))
''',
    },
    {
        "instance_id": "numpy__numpy-10825",
        "repo": "numpy/numpy",
        "problem": "Fix array concatenation edge case",
        "buggy_code": '''import numpy as np


def concatenate_arrays(*arrays):
    """Concatenate multiple arrays along axis 0."""
    if not arrays:
        return np.array([])
    # BUG: Should handle None arrays gracefully
    result = arrays[0]
    for arr in arrays[1:]:
        result = np.concatenate([result, arr])
    return result
''',
        "test_code": '''import unittest

import numpy as np

from buggy import concatenate_arrays


class TestArrayConcatenation(unittest.TestCase):
    def test_basic_concatenation(self):
        """Test basic array concatenation."""
        a = np.array([1, 2, 3])
        b = np.array([4, 5, 6])
        result = concatenate_arrays(a, b)
        np.testing.assert_array_equal(result, np.array([1, 2, 3, 4, 5, 6]))

    def test_empty_input(self):
        """Test empty input returns empty array."""
        result = concatenate_arrays()
        self.assertEqual(len(result), 0)

    def test_single_array(self):
        """Test single array passes through."""
        a = np.array([1, 2, 3])
        result = concatenate_arrays(a)
        np.testing.assert_array_equal(result, a)
''',
    },
    {
        "instance_id": "pandas__pandas-15230",
        "repo": "pandas-dev/pandas",
        "problem": "Fix DataFrame groupby aggregation",
        "buggy_code": '''import pandas as pd


def group_and_aggregate(df, group_col, agg_col, agg_func='mean'):
    """Group DataFrame and aggregate."""
    # BUG: Should handle non-numeric columns gracefully
    if agg_func == 'mean':
        return df.groupby(group_col)[agg_col].mean()
    elif agg_func == 'sum':
        return df.groupby(group_col)[agg_col].sum()
    elif agg_func == 'count':
        return df.groupby(group_col)[agg_col].count()
    else:
        raise ValueError(f"Unknown aggregation function: {agg_func}")
''',
        "test_code": '''import unittest

import pandas as pd

from buggy import group_and_aggregate


class TestGroupBy(unittest.TestCase):
    def test_mean_aggregation(self):
        """Test mean aggregation."""
        df = pd.DataFrame({
            'category': ['A', 'A', 'B', 'B'],
            'value': [1, 2, 3, 4]
        })
        result = group_and_aggregate(df, 'category', 'value', 'mean')
        self.assertEqual(result['A'], 1.5)
        self.assertEqual(result['B'], 3.5)

    def test_sum_aggregation(self):
        """Test sum aggregation."""
        df = pd.DataFrame({
            'category': ['A', 'A', 'B'],
            'value': [1, 2, 3]
        })
        result = group_and_aggregate(df, 'category', 'value', 'sum')
        self.assertEqual(result['A'], 3)
        self.assertEqual(result['B'], 3)
''',
    },
    {
        "instance_id": "scipy__scipy-1925",
        "repo": "scipy/scipy",
        "problem": "Fix signal filtering edge case",
        "buggy_code": '''import numpy as np
from scipy import signal


def apply_lowpass_filter(data, cutoff, fs, order=5):
    """Apply lowpass filter to data."""
    # BUG: Should validate cutoff frequency
    nyquist = fs / 2
    normalized_cutoff = cutoff / nyquist
    # BUG: Using invalid cutoff can cause filter design failure
    b, a = signal.butter(order, normalized_cutoff, btype='low')
    filtered = signal.filtfilt(b, a, data)
    return filtered
''',
        "test_code": '''import unittest

import numpy as np

from buggy import apply_lowpass_filter


class TestSignalFiltering(unittest.TestCase):
    def test_valid_filter(self):
        """Test filtering with valid parameters."""
        fs = 1000  # Sampling frequency
        cutoff = 100  # Cutoff frequency
        t = np.linspace(0, 1, fs)
        data = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 200 * t)
        result = apply_lowpass_filter(data, cutoff, fs)
        self.assertEqual(len(result), len(data))
        # Low frequency component should be preserved
        self.assertTrue(np.abs(result[100]) > 0.5)

    def test_invalid_cutoff(self):
        """Test that invalid cutoff raises error."""
        fs = 1000
        cutoff = 2000  # Above Nyquist frequency - should fail
        data = np.array([1, 2, 3, 4, 5])
        with self.assertRaises(ValueError):
            apply_lowpass_filter(data, cutoff, fs)
''',
    },
    {
        "instance_id": "sklearn__sklearn-12345",
        "repo": "scikit-learn/scikit-learn",
        "problem": "Fix cross-validation split",
        "buggy_code": '''import numpy as np
from sklearn.model_selection import KFold


def get_cv_splits(X, n_splits=5, shuffle=True, random_state=42):
    """Get cross-validation splits."""
    # BUG: random_state should be used for reproducibility
    kf = KFold(n_splits=n_splits, shuffle=shuffle)
    splits = []
    for train_idx, test_idx in kf.split(X):
        splits.append((train_idx, test_idx))
    return splits
''',
        "test_code": '''import unittest

import numpy as np

from buggy import get_cv_splits


class TestCVSplits(unittest.TestCase):
    def test_split_count(self):
        """Test that correct number of splits is generated."""
        X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
        splits = get_cv_splits(X, n_splits=3)
        self.assertEqual(len(splits), 3)

    def test_reproducibility(self):
        """Test that splits are reproducible with same random_state."""
        X = np.random.rand(100, 5)
        splits1 = get_cv_splits(X, n_splits=5, random_state=42)
        splits2 = get_cv_splits(X, n_splits=5, random_state=42)
        for (train1, test1), (train2, test2) in zip(splits1, splits2):
            np.testing.assert_array_equal(train1, train2)
            np.testing.assert_array_equal(test1, test2)
''',
    },
    {
        "instance_id": "pytest__pytest-7426",
        "repo": "pytest-dev/pytest",
        "problem": "Fix test collection order",
        "buggy_code": '''import os
import re


def collect_tests(directory, pattern='test_*.py'):
    """Collect test files from directory."""
    # BUG: Should sort files for consistent ordering
    test_files = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            if re.match(pattern, file):
                test_files.append(os.path.join(root, file))
    return test_files
''',
        "test_code": '''import unittest
import os
import tempfile

from buggy import collect_tests


class TestCollection(unittest.TestCase):
    def test_collect_pattern(self):
        """Test that correct pattern is matched."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create test files
            open(os.path.join(tmpdir, 'test_a.py'), 'w').close()
            open(os.path.join(tmpdir, 'test_b.py'), 'w').close()
            open(os.path.join(tmpdir, 'not_a_test.py'), 'w').close()
            tests = collect_tests(tmpdir, 'test_*.py')
            self.assertEqual(len(tests), 2)

    def test_consistent_order(self):
        """Test that file order is consistent."""
        with tempfile.TemporaryDirectory() as tmpdir:
            for name in ['test_c.py', 'test_a.py', 'test_b.py']:
                open(os.path.join(tmpdir, name), 'w').close()
            tests1 = collect_tests(tmpdir)
            tests2 = collect_tests(tmpdir)
            self.assertEqual(tests1, tests2)
''',
    },
    {
        "instance_id": "transformers__transformers-12345",
        "repo": "huggingface/transformers",
        "problem": "Fix tokenization padding",
        "buggy_code": '''from typing import List


def tokenize_and_pad(tokenizer, texts: List[str], max_length: int = 512):
    """Tokenize texts and pad to max length."""
    # BUG: Should handle padding correctly
    encoded = tokenizer(
        texts,
        padding=True,  # This pads to longest in batch, not max_length
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )
    return encoded
''',
        "test_code": '''import unittest

from buggy import tokenize_and_pad


class MockTokenizer:
    def __call__(self, texts, padding=True, truncation=True, max_length=512, return_tensors=None):
        # Simplified mock
        return {
            'input_ids': [[1, 2, 3]] if isinstance(texts, list) else [1, 2, 3],
            'attention_mask': [[1, 1, 1]] if isinstance(texts, list) else [1, 1, 1]
        }


class TestTokenization(unittest.TestCase):
    def test_single_text(self):
        """Test tokenizing single text."""
        tokenizer = MockTokenizer()
        result = tokenize_and_pad(tokenizer, ["hello world"])
        self.assertIn('input_ids', result)

    def test_max_length_respected(self):
        """Test that max_length is respected."""
        tokenizer = MockTokenizer()
        # Should not raise even with long text
        result = tokenize_and_pad(tokenizer, ["short"], max_length=10)
        self.assertIn('input_ids', result)
''',
    },
]
# Difficulty buckets: the first three problems are easy, the next three
# medium, and the remainder hard.
DIFFICULTY_TASKS = {
    level: SWE_BENCH_PROBLEMS[start:stop]
    for level, start, stop in (
        ("easy", 0, 3),
        ("medium", 3, 6),
        ("hard", 6, None),
    )
}
def generate_tasks(output_dir: Path, count_per_difficulty: int = 3, tasks=None):
    """Write SWE-bench style task directories under *output_dir*.

    For each difficulty bucket, up to ``count_per_difficulty`` problems are
    materialized as ``<instance_id>_<difficulty>_<i>/`` directories, each
    containing ``buggy.py``, ``test.py`` and ``metadata.json``.

    Args:
        output_dir: Destination root; created if missing.
        count_per_difficulty: Maximum number of tasks to emit per difficulty.
        tasks: Optional mapping of difficulty name -> list of problem dicts.
            Defaults to the module-level ``DIFFICULTY_TASKS``.

    Returns:
        The number of task directories created.
    """
    tasks = DIFFICULTY_TASKS if tasks is None else tasks
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total_created = 0
    for difficulty, problems in tasks.items():
        for i, problem in enumerate(problems[:count_per_difficulty]):
            instance_id = f"{problem['instance_id']}_{difficulty}_{i}"
            instance_dir = output_dir / instance_id
            instance_dir.mkdir(parents=True, exist_ok=True)

            # The solver edits buggy.py; test.py is the acceptance criterion.
            (instance_dir / "buggy.py").write_text(
                problem["buggy_code"], encoding="utf-8"
            )
            (instance_dir / "test.py").write_text(
                problem["test_code"], encoding="utf-8"
            )

            metadata = {
                "instance_id": instance_id,
                "repo": problem["repo"],
                "problem_statement": problem["problem"],
                "difficulty": difficulty,
            }
            (instance_dir / "metadata.json").write_text(
                json.dumps(metadata, indent=2), encoding="utf-8"
            )
            total_created += 1

    print(f"Created {total_created} tasks in {output_dir}")
    print(f"Set environment variable: SWEBENCH_TASKS_ROOT={output_dir.absolute()}")
    # Plain string literal: no interpolation here (was a pointless f-string).
    print("Or run with: TASK_SOURCE=swebench python inference.py")
    return total_created
def main():
    """CLI entry point: parse arguments and generate the task tree."""
    parser = argparse.ArgumentParser(description="Generate SWE-bench style tasks")
    parser.add_argument(
        "--count", type=int, default=3,
        help="Number of tasks per difficulty (default: 3)",
    )
    parser.add_argument(
        "--output-dir", type=str, default=None,
        help="Output directory (default: dataset/swebench_lite_tasks)",
    )
    args = parser.parse_args()

    # Fall back to a directory next to this script when none was given.
    target = (
        Path(args.output_dir)
        if args.output_dir
        else Path(__file__).parent / "swebench_lite_tasks"
    )
    generate_tasks(target, args.count)
# Script entry point: `python -m dataset.generate_swebench_tasks [--count N]`.
if __name__ == "__main__":
    main()