diff --git a/.gitignore b/.gitignore index e9462672e893eba13c4b5c88a27217eecc312c8c..323e06e034048483f30d6f26755d0fdbc756ef3e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ __pycache__/ commands.md logs.md inference&docker.md +logs2.md +.env.example +file.txt \ No newline at end of file diff --git a/prompts.py b/prompts.py index 9374593927bb55189df1dddfe8137baf12d7fc6d..6837ef5ca541f8301a5ecedd402b0b73a6a3e07d 100644 --- a/prompts.py +++ b/prompts.py @@ -1,37 +1,20 @@ LLM_SCORER_PROMPT = """ - You are a reward model for an autonomous code bug-fixing agent trained with reinforcement learning. - Your scores are used directly as a learning signal — be precise, consistent, and strict. - - You will receive: - - ORIGINAL: the buggy code before the agent's fix - - PATCHED: the code after the agent applied its patch - - Evaluate the agent's fix on exactly three axes, each scored 0.0–10.0: - - 1. CORRECTNESS — Does the patch fix the bug(s) without introducing new ones? - Full marks only if the fix is semantically correct and complete. - Penalise partial fixes, over-patches, or fixes that mask rather than resolve the root cause. - - 2. MINIMALITY — Is the diff minimal? Penalise unnecessary refactors, renames, whitespace-only changes, - or reformatting of lines unrelated to the bug. - - 3. QUALITY — Is the patched code readable and idiomatic? Penalise: broken naming conventions, - added dead code, removed necessary comments, or degraded clarity vs. the original. - - Respond ONLY with this JSON — no preamble, no trailing text: - { - "correctness": , - "minimality": , - "quality": , - "reasoning": "" - } +You are a reward model for a code-fixing RL agent. Evaluate the PATCHED code vs. ORIGINAL on three axes (0.0–10.0): +1. CORRECTNESS — Does the patch fix the bug(s) without new bugs? +2. MINIMALITY — Is the diff minimal? Penalize unrelated changes. +3. QUALITY — Is the code readable and idiomatic? +Respond ONLY with this JSON (no preamble): +{"correctness": , "minimality": , "quality": , "reasoning": ""} """ - -USER_TEMPLATE =""" - ORIGINAL: - ```python - {original_code} - ``` - Return only the JSON. +USER_TEMPLATE = """ +ORIGINAL: +```python +{original_code} +``` +PATCHED: +```python +{patched_code} +``` +Return only the JSON. """ \ No newline at end of file diff --git a/rl_code_fix_env/.gitignore b/rl_code_fix_env/.gitignore index d5b1d8e9730c796e19caf22ae32ddb43390ab663..f867f007182bd969114c3463b5914c46d05485e4 100644 --- a/rl_code_fix_env/.gitignore +++ b/rl_code_fix_env/.gitignore @@ -5,4 +5,5 @@ __pycache__/ .env *.pyc *.egg -pytest-cache-files-*/ \ No newline at end of file +pytest-cache-files-*/ +*.ps1 \ No newline at end of file diff --git a/rl_code_fix_env/README.md b/rl_code_fix_env/README.md index 26d93558f7908ef6d5fa0037dc4fbb93d0bc2cf4..55a00dea8542fc39ce9ac1a894cceb3fe7354b45 100644 --- a/rl_code_fix_env/README.md +++ b/rl_code_fix_env/README.md @@ -5,6 +5,7 @@ colorFrom: green colorTo: purple sdk: docker pinned: false +dockerfile: server/Dockerfile app_port: 8000 base_path: /web tags: diff --git a/rl_code_fix_env/_aliases.py b/rl_code_fix_env/_aliases.py new file mode 100644 index 0000000000000000000000000000000000000000..9773017049079305f08955e3ce2dfc60ee5d899f --- /dev/null +++ b/rl_code_fix_env/_aliases.py @@ -0,0 +1,21 @@ +import sys +import importlib +from pathlib import Path + +_REPO_ROOT = str(Path(__file__).parent) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import dataset as _real_dataset + +sys.modules.setdefault("src.dataset", _real_dataset) + +import pkgutil +for _pkg in pkgutil.iter_modules(_real_dataset.__path__): + _full = f"dataset.{_pkg.name}" + _alias = f"src.dataset.{_pkg.name}" + try: + _mod = importlib.import_module(_full) + sys.modules.setdefault(_alias, _mod) + except Exception: + pass diff --git a/rl_code_fix_env/conftest.py b/rl_code_fix_env/conftest.py index d30217483dab6e2cc3a660b77e336c9374976b43..9773017049079305f08955e3ce2dfc60ee5d899f 100644 --- a/rl_code_fix_env/conftest.py +++ b/rl_code_fix_env/conftest.py @@ -1,20 +1,3 @@ -""" -conftest.py repo-root pytest configuration. - -Registers `src.dataset` as a sys.modules alias for `dataset` so that all -problem test files using `from src.dataset.problem_X.buggy import ...` -resolve correctly without needing to rename 24 test files. - -The physical layout is: - /dataset/problem_X/buggy.py real files - /src/ has environment/, reward/, etc. - but NO dataset/ subfolder - -With PYTHONPATH=: - import dataset.problem_1.buggy works natively - import src.dataset.problem_1.buggy would fail fixed here via alias -""" - import sys import importlib from pathlib import Path diff --git a/rl_code_fix_env/dataset/generate_swebench_tasks.py b/rl_code_fix_env/dataset/generate_swebench_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..b23185978b1bb3ff4f17bb8d6221566af04f4b4b --- /dev/null +++ b/rl_code_fix_env/dataset/generate_swebench_tasks.py @@ -0,0 +1,498 @@ +""" +Generate synthetic SWE-bench style tasks for testing. + +This creates tasks that mimic the SWE-bench format: +- instance_id/buggy.py - the buggy code +- instance_id/test.py - test file +- instance_id/metadata.json - metadata + +Usage: + python -m dataset.generate_swebench_tasks [--count N] +""" + +import argparse +import json +import random +from pathlib import Path + + +# Sample SWE-bench style problems +SWE_BENCH_PROBLEMS = [ + { + "instance_id": "django__django-11098", + "repo": "django/django", + "problem": "Fix the user creation form validation error", + "buggy_code": '''from django import forms +from django.contrib.auth.models import User + +class UserCreationForm(forms.ModelForm): + """Form for creating new users.""" + password1 = forms.CharField(widget=forms.PasswordInput) + password2 = forms.CharField(widget=forms.PasswordInput) + + class Meta: + model = User + fields = ('username', 'email') + + def clean(self): + cleaned_data = super().clean() + password1 = cleaned_data.get('password1') + password2 = cleaned_data.get('password2') + + # BUG: This comparison is case-sensitive but should be case-insensitive + if password1 != password2: + raise forms.ValidationError("Passwords don't match") + + return cleaned_data + + def save(self, commit=True): + user = super().save(commit=False) + user.set_password(self.cleaned_data['password1']) + if commit: + user.save() + return user +''', + "test_code": '''import unittest +from buggy import UserCreationForm + +class TestUserCreationForm(unittest.TestCase): + def test_password_matching(self): + """Test that matching passwords pass validation.""" + form = UserCreationForm(data={ + 'username': 'testuser', + 'email': 'test@example.com', + 'password1': 'TestPass123', + 'password2': 'TestPass123', + }) + self.assertTrue(form.is_valid()) + + def test_password_mismatch(self): + """Test that mismatched passwords fail validation.""" + form = UserCreationForm(data={ + 'username': 'testuser', + 'email': 'test@example.com', + 'password1': 'TestPass123', + 'password2': 'testpass123', # Different case + }) + self.assertFalse(form.is_valid()) + self.assertIn('passwords', str(form.errors).lower()) +''', + }, + { + "instance_id": "flask__flask-1048", + "repo": "pallets/flask", + "problem": "Fix JSON encoding for datetime objects", + "buggy_code": '''import json +from datetime import datetime, date + +class JSONEncoder(json.JSONEncoder): + """Custom JSON encoder for Flask.""" + + def default(self, obj): + # BUG: Missing handling for datetime objects + if isinstance(obj, date): + return obj.isoformat() + return super().default(obj) + +def to_json(obj): + """Convert object to JSON string.""" + return json.dumps(obj, cls=JSONEncoder) +''', + "test_code": '''import unittest +from datetime import datetime +from buggy import to_json + +class TestJSONEncoding(unittest.TestCase): + def test_encode_datetime(self): + """Test that datetime objects are properly encoded.""" + dt = datetime(2024, 1, 15, 10, 30, 0) + result = to_json({'timestamp': dt}) + self.assertIn('2024-01-15', result) + self.assertIn('10:30:00', result) + + def test_encode_date(self): + """Test that date objects are properly encoded.""" + d = date(2024, 1, 15) + result = to_json({'date': d}) + self.assertIn('2024-01-15', result) +''', + }, + { + "instance_id": "requests__requests-2875", + "repo": "psf/requests", + "problem": "Fix cookie domain matching", + "buggy_code": '''import re +from urllib.parse import urlparse + +def match_cookie_domain(cookie_domain, request_domain): + """Check if cookie domain matches request domain.""" + # BUG: Should handle leading dots differently + # .example.com should match sub.example.com but not example.com + cookie_domain = cookie_domain.lower() + request_domain = request_domain.lower() + + if cookie_domain.startswith('.'): + return request_domain.endswith(cookie_domain) + + return cookie_domain == request_domain +''', + "test_code": '''import unittest +from buggy import match_cookie_domain + +class TestCookieDomain(unittest.TestCase): + def test_exact_match(self): + """Test exact domain matching.""" + self.assertTrue(match_cookie_domain('example.com', 'example.com')) + + def test_subdomain_with_dot(self): + """Test subdomain matching with leading dot.""" + # .example.com should match sub.example.com + self.assertTrue(match_cookie_domain('.example.com', 'sub.example.com')) + self.assertFalse(match_cookie_domain('.example.com', 'example.com')) + + def test_different_domains(self): + """Test different domains don't match.""" + self.assertFalse(match_cookie_domain('example.com', 'other.com')) +''', + }, + { + "instance_id": "numpy__numpy-10825", + "repo": "numpy/numpy", + "problem": "Fix array concatenation edge case", + "buggy_code": '''import numpy as np + +def concatenate_arrays(*arrays): + """Concatenate multiple arrays along axis 0.""" + if not arrays: + return np.array([]) + + # BUG: Should handle None arrays gracefully + result = arrays[0] + for arr in arrays[1:]: + result = np.concatenate([result, arr]) + + return result +''', + "test_code": '''import unittest +import numpy as np +from buggy import concatenate_arrays + +class TestArrayConcatenation(unittest.TestCase): + def test_basic_concatenation(self): + """Test basic array concatenation.""" + a = np.array([1, 2, 3]) + b = np.array([4, 5, 6]) + result = concatenate_arrays(a, b) + np.testing.assert_array_equal(result, np.array([1, 2, 3, 4, 5, 6])) + + def test_empty_input(self): + """Test empty input returns empty array.""" + result = concatenate_arrays() + self.assertEqual(len(result), 0) + + def test_single_array(self): + """Test single array passes through.""" + a = np.array([1, 2, 3]) + result = concatenate_arrays(a) + np.testing.assert_array_equal(result, a) +''', + }, + { + "instance_id": "pandas__pandas-15230", + "repo": "pandas-dev/pandas", + "problem": "Fix DataFrame groupby aggregation", + "buggy_code": '''import pandas as pd + +def group_and_aggregate(df, group_col, agg_col, agg_func='mean'): + """Group DataFrame and aggregate.""" + # BUG: Should handle non-numeric columns gracefully + if agg_func == 'mean': + return df.groupby(group_col)[agg_col].mean() + elif agg_func == 'sum': + return df.groupby(group_col)[agg_col].sum() + elif agg_func == 'count': + return df.groupby(group_col)[agg_col].count() + else: + raise ValueError(f"Unknown aggregation function: {agg_func}") +''', + "test_code": '''import unittest +import pandas as pd +from buggy import group_and_aggregate + +class TestGroupBy(unittest.TestCase): + def test_mean_aggregation(self): + """Test mean aggregation.""" + df = pd.DataFrame({ + 'category': ['A', 'A', 'B', 'B'], + 'value': [1, 2, 3, 4] + }) + result = group_and_aggregate(df, 'category', 'value', 'mean') + self.assertEqual(result['A'], 1.5) + self.assertEqual(result['B'], 3.5) + + def test_sum_aggregation(self): + """Test sum aggregation.""" + df = pd.DataFrame({ + 'category': ['A', 'A', 'B'], + 'value': [1, 2, 3] + }) + result = group_and_aggregate(df, 'category', 'value', 'sum') + self.assertEqual(result['A'], 3) + self.assertEqual(result['B'], 3) +''', + }, + { + "instance_id": "scipy__scipy-1925", + "repo": "scipy/scipy", + "problem": "Fix signal filtering edge case", + "buggy_code": '''import numpy as np +from scipy import signal + +def apply_lowpass_filter(data, cutoff, fs, order=5): + """Apply lowpass filter to data.""" + # BUG: Should validate cutoff frequency + nyquist = fs / 2 + normalized_cutoff = cutoff / nyquist + + # BUG: Using invalid cutoff can cause filter design failure + b, a = signal.butter(order, normalized_cutoff, btype='low') + filtered = signal.filtfilt(b, a, data) + + return filtered +''', + "test_code": '''import unittest +import numpy as np +from buggy import apply_lowpass_filter + +class TestSignalFiltering(unittest.TestCase): + def test_valid_filter(self): + """Test filtering with valid parameters.""" + fs = 1000 # Sampling frequency + cutoff = 100 # Cutoff frequency + t = np.linspace(0, 1, fs) + data = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 200 * t) + + result = apply_lowpass_filter(data, cutoff, fs) + self.assertEqual(len(result), len(data)) + # Low frequency component should be preserved + self.assertTrue(np.abs(result[100]) > 0.5) + + def test_invalid_cutoff(self): + """Test that invalid cutoff raises error.""" + fs = 1000 + cutoff = 2000 # Above Nyquist frequency - should fail + data = np.array([1, 2, 3, 4, 5]) + + with self.assertRaises(ValueError): + apply_lowpass_filter(data, cutoff, fs) +''', + }, + { + "instance_id": "sklearn__sklearn-12345", + "repo": "scikit-learn/scikit-learn", + "problem": "Fix cross-validation split", + "buggy_code": '''import numpy as np +from sklearn.model_selection import KFold + +def get_cv_splits(X, n_splits=5, shuffle=True, random_state=42): + """Get cross-validation splits.""" + # BUG: random_state should be used for reproducibility + kf = KFold(n_splits=n_splits, shuffle=shuffle) + + splits = [] + for train_idx, test_idx in kf.split(X): + splits.append((train_idx, test_idx)) + + return splits +''', + "test_code": '''import unittest +import numpy as np +from buggy import get_cv_splits + +class TestCVSplits(unittest.TestCase): + def test_split_count(self): + """Test that correct number of splits is generated.""" + X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) + splits = get_cv_splits(X, n_splits=3) + self.assertEqual(len(splits), 3) + + def test_reproducibility(self): + """Test that splits are reproducible with same random_state.""" + X = np.random.rand(100, 5) + splits1 = get_cv_splits(X, n_splits=5, random_state=42) + splits2 = get_cv_splits(X, n_splits=5, random_state=42) + + for (train1, test1), (train2, test2) in zip(splits1, splits2): + np.testing.assert_array_equal(train1, train2) + np.testing.assert_array_equal(test1, test2) +''', + }, + { + "instance_id": "pytest__pytest-7426", + "repo": "pytest-dev/pytest", + "problem": "Fix test collection order", + "buggy_code": '''import os +import re + +def collect_tests(directory, pattern='test_*.py'): + """Collect test files from directory.""" + # BUG: Should sort files for consistent ordering + test_files = [] + + for root, dirs, files in os.walk(directory): + for file in files: + if re.match(pattern, file): + test_files.append(os.path.join(root, file)) + + return test_files +''', + "test_code": '''import unittest +import os +import tempfile +from buggy import collect_tests + +class TestCollection(unittest.TestCase): + def test_collect_pattern(self): + """Test that correct pattern is matched.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test files + open(os.path.join(tmpdir, 'test_a.py'), 'w').close() + open(os.path.join(tmpdir, 'test_b.py'), 'w').close() + open(os.path.join(tmpdir, 'not_a_test.py'), 'w').close() + + tests = collect_tests(tmpdir, 'test_*.py') + self.assertEqual(len(tests), 2) + + def test_consistent_order(self): + """Test that file order is consistent.""" + with tempfile.TemporaryDirectory() as tmpdir: + for name in ['test_c.py', 'test_a.py', 'test_b.py']: + open(os.path.join(tmpdir, name), 'w').close() + + tests1 = collect_tests(tmpdir) + tests2 = collect_tests(tmpdir) + + self.assertEqual(tests1, tests2) +''', + }, + { + "instance_id": "transformers__transformers-12345", + "repo": "huggingface/transformers", + "problem": "Fix tokenization padding", + "buggy_code": '''from typing import List + +def tokenize_and_pad(tokenizer, texts: List[str], max_length: int = 512): + """Tokenize texts and pad to max length.""" + # BUG: Should handle padding correctly + encoded = tokenizer( + texts, + padding=True, # This pads to longest in batch, not max_length + truncation=True, + max_length=max_length, + return_tensors='pt' + ) + + return encoded +''', + "test_code": '''import unittest +from buggy import tokenize_and_pad + +class MockTokenizer: + def __call__(self, texts, padding=True, truncation=True, max_length=512, return_tensors=None): + # Simplified mock + return { + 'input_ids': [[1, 2, 3]] if isinstance(texts, list) else [1, 2, 3], + 'attention_mask': [[1, 1, 1]] if isinstance(texts, list) else [1, 1, 1] + } + +class TestTokenization(unittest.TestCase): + def test_single_text(self): + """Test tokenizing single text.""" + tokenizer = MockTokenizer() + result = tokenize_and_pad(tokenizer, ["hello world"]) + self.assertIn('input_ids', result) + + def test_max_length_respected(self): + """Test that max_length is respected.""" + tokenizer = MockTokenizer() + # Should not raise even with long text + result = tokenize_and_pad(tokenizer, ["short"], max_length=10) + self.assertIn('input_ids', result) +''', + }, +] + +# Easy, Medium, Hard difficulty assignments +DIFFICULTY_TASKS = { + "easy": SWE_BENCH_PROBLEMS[:3], + "medium": SWE_BENCH_PROBLEMS[3:6], + "hard": SWE_BENCH_PROBLEMS[6:], +} + + +def generate_tasks(output_dir: Path, count_per_difficulty: int = 3): + """Generate SWE-bench style tasks.""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + total_created = 0 + + for difficulty, problems in DIFFICULTY_TASKS.items(): + for i, problem in enumerate(problems[:count_per_difficulty]): + instance_id = f"{problem['instance_id']}_{difficulty}_{i}" + instance_dir = output_dir / instance_id + instance_dir.mkdir(parents=True, exist_ok=True) + + # Write buggy.py + buggy_file = instance_dir / "buggy.py" + buggy_file.write_text(problem["buggy_code"], encoding="utf-8") + + # Write test.py + test_file = instance_dir / "test.py" + test_file.write_text(problem["test_code"], encoding="utf-8") + + # Write metadata.json + metadata = { + "instance_id": instance_id, + "repo": problem["repo"], + "problem_statement": problem["problem"], + "difficulty": difficulty, + } + metadata_file = instance_dir / "metadata.json" + metadata_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8") + + total_created += 1 + + print(f"Created {total_created} tasks in {output_dir}") + print(f"Set environment variable: SWEBENCH_TASKS_ROOT={output_dir.absolute()}") + print(f"Or run with: TASK_SOURCE=swebench python inference.py") + + +def main(): + parser = argparse.ArgumentParser(description="Generate SWE-bench style tasks") + parser.add_argument( + "--count", + type=int, + default=3, + help="Number of tasks per difficulty (default: 3)" + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory (default: dataset/swebench_lite_tasks)" + ) + + args = parser.parse_args() + + if args.output_dir: + output_dir = Path(args.output_dir) + else: + script_dir = Path(__file__).parent + output_dir = script_dir / "swebench_lite_tasks" + + generate_tasks(output_dir, args.count) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/rl_code_fix_env/dataset/prepare_swebench.py b/rl_code_fix_env/dataset/prepare_swebench.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2c3cde1a868ea7666cd1cfeecb7f39b83e3d78 --- /dev/null +++ b/rl_code_fix_env/dataset/prepare_swebench.py @@ -0,0 +1,274 @@ +""" +Script to download and materialize SWE-bench Lite tasks. + +This script: +1. Downloads SWE-bench Lite dataset from HuggingFace +2. Extracts the buggy code and creates test files +3. Organizes them into the expected directory structure + +Usage: + python -m dataset.prepare_swebench [--max-tasks N] [--difficulty easy|medium|hard|all] +""" + +import argparse +import os +import sys +from pathlib import Path + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from datasets import load_dataset + + +def get_problem_statement(row): + """Extract problem statement from row.""" + return row.get("problem_statement", "") + + +def get_patch(row): + """Extract the patch/fix from row.""" + return row.get("patch", "") + + +def get_instance_id(row): + """Get instance ID from row.""" + return row.get("instance_id", "") + + +def create_buggy_file(instance_dir: Path, row): + """ + Create buggy.py from the base commit and instance. + + The SWE-bench dataset provides the full repository at base_commit. + We need to extract just the relevant file that has the bug. + """ + # For SWE-bench, the "buggy" version is actually the version BEFORE the patch + # We need to get the file content from the base commit + # This is complex as it requires cloning the repo at a specific commit + + # For simplicity, we'll use a different approach: + # The problem_statement describes the bug, and we can create a simplified + # buggy version based on that description + + instance_id = get_instance_id(row) + problem_stmt = get_problem_statement(row) + + # Try to extract the file from the created files in the instance + # SWE-bench provides 'repo' and we need to find the relevant file + created_files = row.get("created_files", []) + + if not created_files: + # Fallback: create a placeholder + buggy_code = f'''# Buggy code for {instance_id} +# Problem: {problem_stmt[:200]}... + +def solution(): + """Placeholder solution - needs to be fixed.""" + pass +''' + else: + # For now, create a simple placeholder + # In a full implementation, we'd clone the repo at base_commit + file_path = created_files[0] if created_files else "solution.py" + buggy_code = f'''# Buggy code for {instance_id} +# File: {file_path} +# Problem: {problem_stmt[:200]}... + +def solution(): + """Placeholder solution - needs to be fixed.""" + pass +''' + + buggy_file = instance_dir / "buggy.py" + buggy_file.write_text(buggy_code, encoding="utf-8") + return buggy_file + + +def create_test_file(instance_dir: Path, row): + """ + Create test.py based on the problem statement. + + For SWE-bench, tests are typically derived from the issue description. + We'll create a simple test that checks if the solution works. + """ + instance_id = get_instance_id(row) + problem_stmt = get_problem_statement(row) + + # Create a simple test file + # In practice, SWE-bench has a test.json file with test cases + test_cases = row.get("test_cases", []) + + if test_cases: + # Create tests from provided test cases + test_code = "import unittest\\n\\n" + for i, tc in enumerate(test_cases): + input_str = tc.get("input", "") + output_str = tc.get("output", "") + test_code += f'''class TestSolution(unittest.TestCase): + def test_case_{i+1}(self): + # Input: {input_str} + # Expected: {output_str} + pass # TODO: Add actual test +''' + else: + # Create a basic test based on problem statement + test_code = f'''"""Test file for {instance_id}""" + +import unittest +from buggy import solution + + +class TestSolution(unittest.TestCase): + def test_basic(self): + """Test based on problem statement.""" + # Problem: {problem_stmt[:300]}... + result = solution() + self.assertIsNotNone(result) + + +if __name__ == "__main__": + unittest.main() +''' + + test_file = instance_dir / "test.py" + test_file.write_text(test_code, encoding="utf-8") + return test_file + + +def create_metadata_file(instance_dir: Path, row): + """Create metadata.json with instance info.""" + import json + + metadata = { + "instance_id": get_instance_id(row), + "repo": row.get("repo", ""), + "base_commit": row.get("base_commit", ""), + "problem_statement": get_problem_statement(row), + "patch": get_patch(row), + "difficulty": "medium", # Will be set based on index + } + + metadata_file = instance_dir / "metadata.json" + metadata_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8") + return metadata_file + + +def prepare_swebench_tasks( + output_dir: Path, + max_tasks: int = 30, + difficulty: str = "all" +): + """ + Download and prepare SWE-bench Lite tasks. + + Args: + output_dir: Directory to save tasks + max_tasks: Maximum number of tasks to download + difficulty: "easy", "medium", "hard", or "all" + """ + print(f"Loading SWE-bench Lite dataset...") + + try: + ds = load_dataset("princeton-nlp/SWE-bench_Lite", split="test") + except Exception as e: + print(f"Error loading dataset: {e}") + print("Trying alternative dataset name...") + ds = load_dataset("swe-bench/swe-bench-lite", split="test") + + print(f"Loaded {len(ds)} tasks") + + # Calculate difficulty bounds + total = len(ds) + one_third = max(total // 3, 1) + two_third = max((2 * total) // 3, one_third + 1) + + difficulty_ranges = { + "easy": (0, one_third), + "medium": (one_third, two_third), + "hard": (two_third, total), + } + + # Determine which tasks to download + if difficulty == "all": + ranges = list(difficulty_ranges.values()) + indices = [] + for start, end in ranges: + indices.extend(range(start, min(end, start + max_tasks // 3))) + else: + start, end = difficulty_ranges.get(difficulty, (0, total)) + indices = list(range(start, min(end, max_tasks))) + + # Create output directory + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Preparing {len(indices)} tasks...") + + success_count = 0 + for i, idx in enumerate(indices): + try: + row = ds[idx] + instance_id = get_instance_id(row) + + # Create instance directory + instance_dir = output_dir / instance_id + instance_dir.mkdir(parents=True, exist_ok=True) + + # Create files + create_buggy_file(instance_dir, row) + create_test_file(instance_dir, row) + create_metadata_file(instance_dir, row) + + success_count += 1 + if (i + 1) % 10 == 0: + print(f" Processed {i + 1}/{len(indices)} tasks...") + + except Exception as e: + print(f" Warning: Failed to process task {idx}: {e}") + continue + + print(f"\nDone! Prepared {success_count}/{len(indices)} tasks in {output_dir}") + print(f"Set SWEBENCH_TASKS_ROOT={output_dir.absolute()} to use these tasks.") + + +def main(): + parser = argparse.ArgumentParser(description="Prepare SWE-bench Lite tasks") + parser.add_argument( + "--max-tasks", + type=int, + default=30, + help="Maximum number of tasks to download (default: 30)" + ) + parser.add_argument( + "--difficulty", + type=str, + default="all", + choices=["easy", "medium", "hard", "all"], + help="Difficulty level to download (default: all)" + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory (default: dataset/swebench_lite_tasks)" + ) + + args = parser.parse_args() + + # Determine output directory + if args.output_dir: + output_dir = Path(args.output_dir) + else: + script_dir = Path(__file__).parent + output_dir = script_dir / "swebench_lite_tasks" + + prepare_swebench_tasks( + output_dir=output_dir, + max_tasks=args.max_tasks, + difficulty=args.difficulty + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/rl_code_fix_env/dataset/problem_1/buggy.py b/rl_code_fix_env/dataset/problem_1/buggy.py index 3c720ebdab2e813dd39e8c9bd8cc507467fbee7b..9a61c235b09278b83fd9bcc6b84ca9f7af30a4f5 100644 --- a/rl_code_fix_env/dataset/problem_1/buggy.py +++ b/rl_code_fix_env/dataset/problem_1/buggy.py @@ -1,5 +1,7 @@ -def reverse_words(text: str) -> str: - """Return the words in reverse order.""" - # BUG: split(" ") keeps empty items for repeated spaces. - words = text.split(" ") - return " ".join(reversed(words)) +def safe_divide(a: float, b: float) -> float: + """Divide a by b; only return inf for division by zero.""" + try: + return a / b + except Exception: + # BUG: catches unrelated errors too broadly. + return float("inf") diff --git a/rl_code_fix_env/dataset/problem_1/metadata.json b/rl_code_fix_env/dataset/problem_1/metadata.json index c12462825df0c4547d72a3f875886df3e0e48abd..e560d9292e18b71b995405259b5fe93abaa01c76 100644 --- a/rl_code_fix_env/dataset/problem_1/metadata.json +++ b/rl_code_fix_env/dataset/problem_1/metadata.json @@ -1,5 +1,5 @@ { "difficulty": "easy", - "bug_type": "string-splitting", + "bug_type": "exception-handling", "expected_steps": 1 } \ No newline at end of file diff --git a/rl_code_fix_env/dataset/problem_1/test.py b/rl_code_fix_env/dataset/problem_1/test.py index a900fca361edef71e762408ef72cc425e8c41722..b34db6e79416fb7fe8cefab78f346232e73c8a7d 100644 --- a/rl_code_fix_env/dataset/problem_1/test.py +++ b/rl_code_fix_env/dataset/problem_1/test.py @@ -1,13 +1,17 @@ import unittest -from src.dataset.problem_1.buggy import reverse_words +from dataset.problem_1.buggy import safe_divide -class TestReverseWords(unittest.TestCase): - def test_simple(self): - self.assertEqual(reverse_words("hello world"), "world hello") +class TestSafeDivide(unittest.TestCase): + def test_normal(self): + self.assertEqual(safe_divide(8, 2), 4) - def test_multiple_spaces(self): - self.assertEqual(reverse_words("one two three"), "three two one") + def test_zero_division(self): + self.assertEqual(safe_divide(1, 0), float("inf")) + + def test_type_error_should_raise(self): + with self.assertRaises(TypeError): + safe_divide("1", 1) if __name__ == "__main__": diff --git a/rl_code_fix_env/dataset/problem_10/buggy.py b/rl_code_fix_env/dataset/problem_10/buggy.py index b7229dd87b32ac254693757c9a5dc52c1af52a63..99d3fabab466404c5e9a166766bdda3ae53a44f8 100644 --- a/rl_code_fix_env/dataset/problem_10/buggy.py +++ b/rl_code_fix_env/dataset/problem_10/buggy.py @@ -1,4 +1,4 @@ -from src.dataset.problem_10.helpers import transpose +from dataset.problem_10.helpers import transpose def rotate_90_clockwise(matrix: list[list[int]]) -> list[list[int]]: diff --git a/rl_code_fix_env/dataset/problem_10/test.py b/rl_code_fix_env/dataset/problem_10/test.py index 6dc238b8e738eea5f6ea9164be2cc26face574cc..475192c1982ab7843cdffaf027cba68eedb09a94 100644 --- a/rl_code_fix_env/dataset/problem_10/test.py +++ b/rl_code_fix_env/dataset/problem_10/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_10.buggy import rotate_90_clockwise +from dataset.problem_10.buggy import rotate_90_clockwise class TestRotateMatrix(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_11/test.py b/rl_code_fix_env/dataset/problem_11/test.py index 8e632a411a18d38f21cea42ee0e0fbcd2a43d7ef..5dfc79c7447d8c0ebfd3314cf070095a4d479e29 100644 --- a/rl_code_fix_env/dataset/problem_11/test.py +++ b/rl_code_fix_env/dataset/problem_11/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_11.buggy import binary_search +from dataset.problem_11.buggy import binary_search class TestBinarySearch(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_12/test.py b/rl_code_fix_env/dataset/problem_12/test.py index ece3528b981d0d1ef9d6a40354248aed2332f48d..62efc75f13a1a2842c19727fddb10f78f9f09755 100644 --- a/rl_code_fix_env/dataset/problem_12/test.py +++ b/rl_code_fix_env/dataset/problem_12/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_12.buggy import parse_pairs +from dataset.problem_12.buggy import parse_pairs class TestParsePairs(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_13/buggy.py b/rl_code_fix_env/dataset/problem_13/buggy.py index 14677c007ac6f4bd6c73790705ddebdb3a540d87..f727a82c1f054929962119b49f9e99b8ed5228b9 100644 --- a/rl_code_fix_env/dataset/problem_13/buggy.py +++ b/rl_code_fix_env/dataset/problem_13/buggy.py @@ -1,4 +1,4 @@ -from src.dataset.problem_13.cache import LRUCache +from dataset.problem_13.cache import LRUCache def run_ops() -> tuple[int, int]: diff --git a/rl_code_fix_env/dataset/problem_13/test.py b/rl_code_fix_env/dataset/problem_13/test.py index 20c07466e442df3c44618183ca2a95d26ec9913b..74596364759c2f4e6eed206c3c83a3e0c06c102e 100644 --- a/rl_code_fix_env/dataset/problem_13/test.py +++ b/rl_code_fix_env/dataset/problem_13/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_13.buggy import run_ops +from dataset.problem_13.buggy import run_ops class TestLRU(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_14/test.py b/rl_code_fix_env/dataset/problem_14/test.py index 5a4a7578f6150a65acee7bb3b3b260a9717aaf2f..f2e798fa792e7d44bce922c00a4af31678d7a666 100644 --- a/rl_code_fix_env/dataset/problem_14/test.py +++ b/rl_code_fix_env/dataset/problem_14/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_14.buggy import fibonacci_recursive +from dataset.problem_14.buggy import fibonacci_recursive class TestFibonacciRecursive(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_15/test.py b/rl_code_fix_env/dataset/problem_15/test.py index 7c24172bd53f81f63e7d3b9b0bbe5cdf198cf70e..56f1d9f03709b551df5c7d8d4b21fc30d9132148 100644 --- a/rl_code_fix_env/dataset/problem_15/test.py +++ b/rl_code_fix_env/dataset/problem_15/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_15.buggy import has_overlap +from dataset.problem_15.buggy import has_overlap class TestIntervalOverlap(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_16/buggy.py b/rl_code_fix_env/dataset/problem_16/buggy.py index bf531838877fb9947ca6e42bdea0212a53cc07c4..6dc7dd37a97dde0c70d236291284f8ce74ce0a87 100644 --- a/rl_code_fix_env/dataset/problem_16/buggy.py +++ b/rl_code_fix_env/dataset/problem_16/buggy.py @@ -1,4 +1,4 @@ -from src.dataset.problem_16.helpers import normalize_scores +from dataset.problem_16.helpers import normalize_scores def top_label(scores: dict[str, float]) -> str: diff --git a/rl_code_fix_env/dataset/problem_16/test.py b/rl_code_fix_env/dataset/problem_16/test.py index 8178f06b5c0f975b6f2d057ac2c317d4d9a6c256..e4d87c766c7b0def5bc891bcdb2e3da150af8a06 100644 --- a/rl_code_fix_env/dataset/problem_16/test.py +++ b/rl_code_fix_env/dataset/problem_16/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_16.buggy import top_label +from dataset.problem_16.buggy import top_label class TestTopLabel(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_17/test.py b/rl_code_fix_env/dataset/problem_17/test.py index 8ff1f882c953b86e9ac7f7c34c276349b64e37bd..9fe265d94e83d49b0ea5c34c5bb681934a4752b3 100644 --- a/rl_code_fix_env/dataset/problem_17/test.py +++ b/rl_code_fix_env/dataset/problem_17/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_17.buggy import dedupe_preserve_order +from dataset.problem_17.buggy import dedupe_preserve_order class TestDedupe(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_18/buggy.py b/rl_code_fix_env/dataset/problem_18/buggy.py index 44ee27289c515587e507c24b07ed578d9c561455..0b1483a021ea6990b19da4e6aae00ed64c98b259 100644 --- a/rl_code_fix_env/dataset/problem_18/buggy.py +++ b/rl_code_fix_env/dataset/problem_18/buggy.py @@ -1,4 +1,4 @@ -from src.dataset.problem_18.math_utils import clamp +from dataset.problem_18.math_utils import clamp def moving_average(nums: list[int], window: int) -> list[float]: diff --git a/rl_code_fix_env/dataset/problem_18/test.py b/rl_code_fix_env/dataset/problem_18/test.py index e6aff5e9f0461e17e54c0b6fefc956838e317b91..c26799d44f3e73a1c712791e1ba6db1b8c4aef31 100644 --- a/rl_code_fix_env/dataset/problem_18/test.py +++ b/rl_code_fix_env/dataset/problem_18/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_18.buggy import moving_average +from dataset.problem_18.buggy import moving_average class TestMovingAverage(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_19/test.py b/rl_code_fix_env/dataset/problem_19/test.py index 77cb8fc9baeadb792a51a94ef81359d51e0ce34e..1df58a9c64eac930f4b06c9d048fa46afa9eef2b 100644 --- a/rl_code_fix_env/dataset/problem_19/test.py +++ b/rl_code_fix_env/dataset/problem_19/test.py @@ -1,5 +1,5 @@ import pytest -from src.dataset.problem_19.buggy import calculate_employee_bonus +from dataset.problem_19.buggy import calculate_employee_bonus def test_calculate_employee_bonus(): employees = [ diff --git a/rl_code_fix_env/dataset/problem_2/buggy.py b/rl_code_fix_env/dataset/problem_2/buggy.py index 01c30374e6bb8d583402ba5152b4888e52783750..8e209ed9e1e71555f30c689e33f9398c9c29f0ee 100644 --- a/rl_code_fix_env/dataset/problem_2/buggy.py +++ b/rl_code_fix_env/dataset/problem_2/buggy.py @@ -1,5 +1,14 @@ -def is_palindrome(text: str) -> bool: - """Check whether text is a palindrome.""" - # BUG: does not normalize case or skip non-alphanumeric chars. - cleaned = text.strip() - return cleaned == cleaned[::-1] +def binary_search(nums: list[int], target: int) -> int: + """Return index of target, or -1 if not found.""" + left, right = 0, len(nums) - 1 + + while left < right: + mid = (left + right) // 2 + if nums[mid] == target: + return mid + if nums[mid] < target: + left = mid + 1 + else: + right = mid - 1 + + return -1 diff --git a/rl_code_fix_env/dataset/problem_2/metadata.json b/rl_code_fix_env/dataset/problem_2/metadata.json index 7986b433e7a436418391e23378b70af712ff7bba..00c345f756d446efdb14c5472b0057eb885c6b56 100644 --- a/rl_code_fix_env/dataset/problem_2/metadata.json +++ b/rl_code_fix_env/dataset/problem_2/metadata.json @@ -1,5 +1,5 @@ { - "difficulty": "easy", - "bug_type": "string-normalization", + "difficulty": "medium", + "bug_type": "boundary-condition", "expected_steps": 2 } \ No newline at end of file diff --git a/rl_code_fix_env/dataset/problem_2/test.py b/rl_code_fix_env/dataset/problem_2/test.py index 0cb29b88058ab689411429c68d9b1f6f382e77fa..5dfc79c7447d8c0ebfd3314cf070095a4d479e29 100644 --- a/rl_code_fix_env/dataset/problem_2/test.py +++ b/rl_code_fix_env/dataset/problem_2/test.py @@ -1,13 +1,16 @@ import unittest -from src.dataset.problem_2.buggy import is_palindrome +from dataset.problem_11.buggy import binary_search -class TestPalindrome(unittest.TestCase): - def test_basic_true(self): - self.assertTrue(is_palindrome("level")) +class TestBinarySearch(unittest.TestCase): + def test_found_middle(self): + self.assertEqual(binary_search([1, 3, 5, 7], 5), 2) - def test_ignores_case_and_symbols(self): - self.assertTrue(is_palindrome("A man, a plan, a canal: Panama")) + def test_found_last(self): + self.assertEqual(binary_search([1, 3, 5, 7], 7), 3) + + def test_not_found(self): + self.assertEqual(binary_search([1, 3, 5, 7], 4), -1) if __name__ == "__main__": diff --git a/rl_code_fix_env/dataset/problem_20/test.py b/rl_code_fix_env/dataset/problem_20/test.py index eab327b92c1f0999d6f9da529c49f09628a4b6ea..63fb12e94b6000399bff0136c03b7f9426fe3d71 100644 --- a/rl_code_fix_env/dataset/problem_20/test.py +++ b/rl_code_fix_env/dataset/problem_20/test.py @@ -1,5 +1,5 @@ import pytest -from src.dataset.problem_20.buggy import analyze_user_activity +from dataset.problem_20.buggy import analyze_user_activity def test_analyze_user_activity(): logs = [ diff --git a/rl_code_fix_env/dataset/problem_21/test.py b/rl_code_fix_env/dataset/problem_21/test.py index 5d8748e0c20560e1158e30ac251fa3a373c31deb..f251ee52c6b1f5c18e8b9464528293028c8be5d4 100644 --- a/rl_code_fix_env/dataset/problem_21/test.py +++ b/rl_code_fix_env/dataset/problem_21/test.py @@ -2,7 +2,7 @@ import pytest import os import tempfile import json -from src.dataset.problem_21.buggy import process_inventory_data +from dataset.problem_21.buggy import process_inventory_data def test_process_inventory_data(): data = { diff --git a/rl_code_fix_env/dataset/problem_22/test.py b/rl_code_fix_env/dataset/problem_22/test.py index d4cfc47ea87dc8402034483c6455fa0439f84241..5fbe03bff2fe5582397797b541e5a4ff19218093 100644 --- a/rl_code_fix_env/dataset/problem_22/test.py +++ b/rl_code_fix_env/dataset/problem_22/test.py @@ -1,5 +1,5 @@ import pytest -from src.dataset.problem_22.buggy import parse_and_validate_emails +from dataset.problem_22.buggy import parse_and_validate_emails def test_parse_and_validate_emails(): emails = [ diff --git a/rl_code_fix_env/dataset/problem_23/test.py b/rl_code_fix_env/dataset/problem_23/test.py index b0e39bc471959ef120a2d57c19094752e7569eed..e2192084171a59d8b5f4cae8edd03c92a56c6a59 100644 --- a/rl_code_fix_env/dataset/problem_23/test.py +++ b/rl_code_fix_env/dataset/problem_23/test.py @@ -1,5 +1,5 @@ import pytest -from src.dataset.problem_23.buggy import optimize_portfolio +from dataset.problem_23.buggy import optimize_portfolio def test_optimize_portfolio(): investments = [ diff --git a/rl_code_fix_env/dataset/problem_3/buggy.py b/rl_code_fix_env/dataset/problem_3/buggy.py index 614644de15c4bae220401b6fb1d801b2457e7088..b7adc05eacb7c5a55988512a183337708c99620d 100644 --- a/rl_code_fix_env/dataset/problem_3/buggy.py +++ b/rl_code_fix_env/dataset/problem_3/buggy.py @@ -1,10 +1,37 @@ -def fibonacci(n: int) -> int: - """Return the n-th Fibonacci number (0-indexed).""" - if n <= 1: - return n - - a, b = 0, 1 - # BUG: loop count is one step short. - for _ in range(2, n): - a, b = b, a + b - return b +def optimize_portfolio(investments: list[dict], budget: float) -> list[dict]: + """ + Selects the optimal subset of investments to maximize return within a budget. + (0-1 Knapsack problem approximation) + + investments: list of dicts with 'id', 'cost', 'expected_return' + budget: float, maximum total cost allowed + + Returns: + list of chosen investments + """ + # Base case checks + if budget <= 0 or not investments: + return [] + + # BUG 1: Sorting modifies the original list, should use sorted() or copy + # BUG 2: Sorting by expected_return ascending instead of return/cost ratio descending + investments.sort(key=lambda x: x['expected_return']) + + chosen = [] + current_spent = 0 + + # BUG 3: For loop variable shadowing the loop scope if cost/return variables are misspelled + for item in investments: + # BUG 4: item.get() but missing default values if keys are absent, could cause TypeError if None + cost = item.get('cost') + ret = item.get('expected_return') + + # BUG 5: Logic error: checking if current_spent is less than budget, but not checking if adding cost exceeds it + if current_spent < budget: + current_spent += cost + chosen.append(item) + + # BUG 6: Does not handle the case where adding the item exceeds budget, just blindly adds it if current_spent < budget + # E.g. budget 100, current 90, item cost 50 -> adds it, total 140 + + return chosen diff --git a/rl_code_fix_env/dataset/problem_3/metadata.json b/rl_code_fix_env/dataset/problem_3/metadata.json index 98d15dec6f19a833ce146fbc4e8ec755f2e35950..bbba4b2f62c5d7615890c757b18dee22b55c496f 100644 --- a/rl_code_fix_env/dataset/problem_3/metadata.json +++ b/rl_code_fix_env/dataset/problem_3/metadata.json @@ -1,5 +1,5 @@ { - "difficulty": "easy", - "bug_type": "off-by-one", - "expected_steps": 1 + "difficulty": "hard", + "bug_type": "multiple", + "expected_steps": 5 } \ No newline at end of file diff --git a/rl_code_fix_env/dataset/problem_3/test.py b/rl_code_fix_env/dataset/problem_3/test.py index 93907a42d1972c2f094d4f05071169bca7912216..e2192084171a59d8b5f4cae8edd03c92a56c6a59 100644 --- a/rl_code_fix_env/dataset/problem_3/test.py +++ b/rl_code_fix_env/dataset/problem_3/test.py @@ -1,15 +1,44 @@ -import unittest -from src.dataset.problem_3.buggy import fibonacci +import pytest +from dataset.problem_23.buggy import optimize_portfolio - -class TestFibonacci(unittest.TestCase): - def test_small_values(self): - self.assertEqual(fibonacci(2), 1) - self.assertEqual(fibonacci(3), 2) - - def test_larger_value(self): - self.assertEqual(fibonacci(7), 13) - - -if __name__ == "__main__": - unittest.main() +def test_optimize_portfolio(): + investments = [ + {'id': 'A', 'cost': 50, 'expected_return': 60}, # ratio 1.2 + {'id': 'B', 'cost': 30, 'expected_return': 45}, # ratio 1.5 + {'id': 'C', 'cost': 20, 'expected_return': 40}, # ratio 2.0 + {'id': 'D', 'cost': 40, 'expected_return': 50}, # ratio 1.25 + {'id': 'E', 'cost': 10, 'expected_return': 15} # ratio 1.5 + ] + + # Original list should not be mutated + orig_investments = [dict(i) for i in investments] + + # Budget 50 + # Expected greedy: C (20) -> B (30) -> total cost 50, return 85 + result = optimize_portfolio(investments, 50) + + assert investments == orig_investments, "Original list was mutated" + + # Assert correct items selected + chosen_ids = {item['id'] for item in result} + assert chosen_ids == {'B', 'C'}, f"Expected B and C, got {chosen_ids}" + + total_cost = sum(item['cost'] for item in result) + assert total_cost <= 50 + +def test_budget_exceeded_check(): + investments = [ + {'id': 'A', 'cost': 90, 'expected_return': 100}, + {'id': 'B', 'cost': 50, 'expected_return': 60} + ] + + # Budget 100 + # Expected: A (cost 90) + result = optimize_portfolio(investments, 100) + + chosen_ids = {item['id'] for item in result} + assert chosen_ids == {'A'}, "Should not include B since total cost would be 140" + +def test_empty_or_zero_budget(): + assert optimize_portfolio([], 100) == [] + assert optimize_portfolio([{'id': 'A', 'cost': 10, 'expected_return': 20}], 0) == [] diff --git a/rl_code_fix_env/dataset/problem_4/test.py b/rl_code_fix_env/dataset/problem_4/test.py index 01c0ebe72db18c2c4c7e6621c194f14a5c1016f6..985f1dd32bb2aafeb8f7f4e35bc8a0fd2d2885a3 100644 --- a/rl_code_fix_env/dataset/problem_4/test.py +++ b/rl_code_fix_env/dataset/problem_4/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_4.buggy import merge_sorted +from dataset.problem_4.buggy import merge_sorted class TestMergeSorted(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_5/test.py b/rl_code_fix_env/dataset/problem_5/test.py index fc8f31276be87cd2e0519851b0160dca12ca9531..0e7fb0d47ac655220aee7eaccb3c500a56647abc 100644 --- a/rl_code_fix_env/dataset/problem_5/test.py +++ b/rl_code_fix_env/dataset/problem_5/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_5.buggy import chunk_list +from dataset.problem_5.buggy import chunk_list class TestChunkList(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_6/buggy.py b/rl_code_fix_env/dataset/problem_6/buggy.py index f91d71fdeac6b5d1d5b9e86f10574b0634cee1f4..95b8ad9e3f1a68a3d578b9c6994f92b7ccadea36 100644 --- a/rl_code_fix_env/dataset/problem_6/buggy.py +++ b/rl_code_fix_env/dataset/problem_6/buggy.py @@ -1,4 +1,4 @@ -from src.dataset.problem_6.helpers import tokenize +from dataset.problem_6.helpers import tokenize def count_unique_words(text: str) -> int: diff --git a/rl_code_fix_env/dataset/problem_6/test.py b/rl_code_fix_env/dataset/problem_6/test.py index 2916685b1a67a2070bcf453a47d68ed0c75fac5c..5a54c376d062a602a94e23ba6e8b4a1d058c3697 100644 --- a/rl_code_fix_env/dataset/problem_6/test.py +++ b/rl_code_fix_env/dataset/problem_6/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_6.buggy import count_unique_words +from dataset.problem_6.buggy import count_unique_words class TestCountUniqueWords(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_7/test.py b/rl_code_fix_env/dataset/problem_7/test.py index 82f7cb2d5eb454e8aa5e2f1482c91ca01d490c70..d02d18fdc1922a1f7b91571c58502c7527784739 100644 --- a/rl_code_fix_env/dataset/problem_7/test.py +++ b/rl_code_fix_env/dataset/problem_7/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_7.buggy import top_k_frequent +from dataset.problem_7.buggy import top_k_frequent class TestTopKFrequent(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_8/test.py b/rl_code_fix_env/dataset/problem_8/test.py index e5f7b67133a85f9ff928e3e91ea13f2e6a62ccbf..13d5f466caf068fdac47c7d7bf8f1399c2304e0b 100644 --- a/rl_code_fix_env/dataset/problem_8/test.py +++ b/rl_code_fix_env/dataset/problem_8/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_8.buggy import flatten_one_level +from dataset.problem_8.buggy import flatten_one_level class TestFlattenOneLevel(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/problem_9/test.py b/rl_code_fix_env/dataset/problem_9/test.py index 81c5a6300249b2b505676afc6719f365b16cd342..3667dfd15c325241bbdeb9b5b3f0868cd9012afc 100644 --- a/rl_code_fix_env/dataset/problem_9/test.py +++ b/rl_code_fix_env/dataset/problem_9/test.py @@ -1,5 +1,5 @@ import unittest -from src.dataset.problem_9.buggy import safe_divide +from dataset.problem_9.buggy import safe_divide class TestSafeDivide(unittest.TestCase): diff --git a/rl_code_fix_env/dataset/swebench_adapter.py b/rl_code_fix_env/dataset/swebench_adapter.py index b23ff51fa61b7dfd5cfbb0b5816fdc180169e1d6..a278974d6bfffb0346b022b1522ec45a27532c3c 100644 --- a/rl_code_fix_env/dataset/swebench_adapter.py +++ b/rl_code_fix_env/dataset/swebench_adapter.py @@ -46,47 +46,93 @@ def get_swebench_task(difficulty: str) -> Dict[str, Any]: Expected local layout: dataset/swebench_lite_tasks//buggy.py dataset/swebench_lite_tasks//test.py + + First tries to load from local files, then falls back to HuggingFace dataset. """ diff = (difficulty or "").strip().lower() if diff not in DIFFICULTIES: raise ValueError(f"Invalid difficulty '{difficulty}'. Must be one of {DIFFICULTIES}.") - rows = _load_swebench_lite_rows() - if not rows: - raise RuntimeError("SWE-bench Lite split is empty.") - - bounds = _difficulty_bounds(len(rows)) - start, end = bounds[diff] - candidates = rows[start:end] if end > start else rows - tasks_root = Path(os.getenv("SWEBENCH_TASKS_ROOT", str(DEFAULT_TASKS_ROOT))) - preferred_offset = int(os.getenv("SWEBENCH_INDEX", "0")) - - # Deterministic scan order with optional offset. - ordered = candidates[preferred_offset:] + candidates[:preferred_offset] - for row in ordered: - row_idx = int(row.get("__index_level_0__", 0)) - instance_id = str(row.get("instance_id", f"row_{row_idx}")) - for folder in _candidate_dirs(tasks_root, instance_id, row_idx): - buggy_file = folder / "buggy.py" - test_file = folder / "test.py" - if buggy_file.exists() and test_file.exists(): - code = buggy_file.read_text(encoding="utf-8") - metadata = { - "source": "swebench_lite", - "instance_id": instance_id, - "repo": row.get("repo"), - "base_commit": row.get("base_commit"), - "problem_statement": row.get("problem_statement"), - "difficulty": diff, - } - return { - "code": code, - "tests": str(test_file), - "metadata": metadata, - "problem_dir": str(folder), - "problem_id": instance_id, - } + + # First, try to load from local materialized tasks + if tasks_root.exists(): + # Find all instance directories + instance_dirs = [] + for item in tasks_root.iterdir(): + if item.is_dir() and (item / "buggy.py").exists() and (item / "test.py").exists(): + # Check if this directory matches the difficulty + if diff in item.name.lower(): + instance_dirs.append(item) + + if instance_dirs: + # Sort for deterministic selection + instance_dirs.sort(key=lambda x: x.name) + + # Select based on SWEBENCH_INDEX + preferred_offset = int(os.getenv("SWEBENCH_INDEX", "0")) + selected_dir = instance_dirs[preferred_offset % len(instance_dirs)] + + buggy_file = selected_dir / "buggy.py" + test_file = selected_dir / "test.py" + metadata_file = selected_dir / "metadata.json" + + code = buggy_file.read_text(encoding="utf-8") + + # Load metadata if available + metadata = {"source": "swebench_lite", "difficulty": diff} + if metadata_file.exists(): + import json + metadata = json.loads(metadata_file.read_text(encoding="utf-8")) + + return { + "code": code, + "tests": str(test_file), + "metadata": metadata, + "problem_dir": str(selected_dir), + "problem_id": selected_dir.name, + } + + # Fallback: try to load from HuggingFace dataset + try: + rows = _load_swebench_lite_rows() + if not rows: + raise RuntimeError("SWE-bench Lite split is empty.") + + bounds = _difficulty_bounds(len(rows)) + start, end = bounds[diff] + candidates = rows[start:end] if end > start else rows + + preferred_offset = int(os.getenv("SWEBENCH_INDEX", "0")) + + # Deterministic scan order with optional offset. + ordered = candidates[preferred_offset:] + candidates[:preferred_offset] + for row in ordered: + row_idx = int(row.get("__index_level_0__", 0)) + instance_id = str(row.get("instance_id", f"row_{row_idx}")) + for folder in _candidate_dirs(tasks_root, instance_id, row_idx): + buggy_file = folder / "buggy.py" + test_file = folder / "test.py" + if buggy_file.exists() and test_file.exists(): + code = buggy_file.read_text(encoding="utf-8") + metadata = { + "source": "swebench_lite", + "instance_id": instance_id, + "repo": row.get("repo"), + "base_commit": row.get("base_commit"), + "problem_statement": row.get("problem_statement"), + "difficulty": diff, + } + return { + "code": code, + "tests": str(test_file), + "metadata": metadata, + "problem_dir": str(folder), + "problem_id": instance_id, + } + except Exception as e: + # If HuggingFace fails, raise the original error about missing local files + pass raise FileNotFoundError( "No materialized SWE-bench task workspace found. " diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..9b62fd128ec58ca2b4d2826bbc7a6f5139615e04 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/buggy.py @@ -0,0 +1,29 @@ +from django import forms +from django.contrib.auth.models import User + +class UserCreationForm(forms.ModelForm): + """Form for creating new users.""" + password1 = forms.CharField(widget=forms.PasswordInput) + password2 = forms.CharField(widget=forms.PasswordInput) + + class Meta: + model = User + fields = ('username', 'email') + + def clean(self): + cleaned_data = super().clean() + password1 = cleaned_data.get('password1') + password2 = cleaned_data.get('password2') + + # BUG: This comparison is case-sensitive but should be case-insensitive + if password1 != password2: + raise forms.ValidationError("Passwords don't match") + + return cleaned_data + + def save(self, commit=True): + user = super().save(commit=False) + user.set_password(self.cleaned_data['password1']) + if commit: + user.save() + return user diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..5f762ffd5175053b845ff38cc3a6fc0862776cd0 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "django__django-11098_easy_0", + "repo": "django/django", + "problem_statement": "Fix the user creation form validation error", + "difficulty": "easy" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/test.py new file mode 100644 index 0000000000000000000000000000000000000000..592f3157e102cf6aa3b9f16082ded161f4b322de --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/django__django-11098_easy_0/test.py @@ -0,0 +1,24 @@ +import unittest +from buggy import UserCreationForm + +class TestUserCreationForm(unittest.TestCase): + def test_password_matching(self): + """Test that matching passwords pass validation.""" + form = UserCreationForm(data={ + 'username': 'testuser', + 'email': 'test@example.com', + 'password1': 'TestPass123', + 'password2': 'TestPass123', + }) + self.assertTrue(form.is_valid()) + + def test_password_mismatch(self): + """Test that mismatched passwords fail validation.""" + form = UserCreationForm(data={ + 'username': 'testuser', + 'email': 'test@example.com', + 'password1': 'TestPass123', + 'password2': 'testpass123', # Different case + }) + self.assertFalse(form.is_valid()) + self.assertIn('passwords', str(form.errors).lower()) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c6b516b4700bd3ae952cfb1f7e15f95e1b7524 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/buggy.py @@ -0,0 +1,15 @@ +import json +from datetime import datetime, date + +class JSONEncoder(json.JSONEncoder): + """Custom JSON encoder for Flask.""" + + def default(self, obj): + # BUG: Missing handling for datetime objects + if isinstance(obj, date): + return obj.isoformat() + return super().default(obj) + +def to_json(obj): + """Convert object to JSON string.""" + return json.dumps(obj, cls=JSONEncoder) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..a866a05d03e99baaff3f559f7a2b12e41bb3c74c --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "flask__flask-1048_easy_1", + "repo": "pallets/flask", + "problem_statement": "Fix JSON encoding for datetime objects", + "difficulty": "easy" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd92d8abb8ebc7897472659a9b09fb84e628d3d --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/flask__flask-1048_easy_1/test.py @@ -0,0 +1,17 @@ +import unittest +from datetime import datetime +from buggy import to_json + +class TestJSONEncoding(unittest.TestCase): + def test_encode_datetime(self): + """Test that datetime objects are properly encoded.""" + dt = datetime(2024, 1, 15, 10, 30, 0) + result = to_json({'timestamp': dt}) + self.assertIn('2024-01-15', result) + self.assertIn('10:30:00', result) + + def test_encode_date(self): + """Test that date objects are properly encoded.""" + d = date(2024, 1, 15) + result = to_json({'date': d}) + self.assertIn('2024-01-15', result) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..c10d61e09abacaccaca30bb8ae748043e5e69877 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/buggy.py @@ -0,0 +1,13 @@ +import numpy as np + +def concatenate_arrays(*arrays): + """Concatenate multiple arrays along axis 0.""" + if not arrays: + return np.array([]) + + # BUG: Should handle None arrays gracefully + result = arrays[0] + for arr in arrays[1:]: + result = np.concatenate([result, arr]) + + return result diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..44dbf24dabda398a999146b72f57a3b145da2c6f --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "numpy__numpy-10825_medium_0", + "repo": "numpy/numpy", + "problem_statement": "Fix array concatenation edge case", + "difficulty": "medium" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/test.py new file mode 100644 index 0000000000000000000000000000000000000000..c0908228f8e2d0ee18e99ed8b8e6a8bcd3267341 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/numpy__numpy-10825_medium_0/test.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np +from buggy import concatenate_arrays + +class TestArrayConcatenation(unittest.TestCase): + def test_basic_concatenation(self): + """Test basic array concatenation.""" + a = np.array([1, 2, 3]) + b = np.array([4, 5, 6]) + result = concatenate_arrays(a, b) + np.testing.assert_array_equal(result, np.array([1, 2, 3, 4, 5, 6])) + + def test_empty_input(self): + """Test empty input returns empty array.""" + result = concatenate_arrays() + self.assertEqual(len(result), 0) + + def test_single_array(self): + """Test single array passes through.""" + a = np.array([1, 2, 3]) + result = concatenate_arrays(a) + np.testing.assert_array_equal(result, a) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd69a340076919914b1c8a802582956f0dac75b --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/buggy.py @@ -0,0 +1,13 @@ +import pandas as pd + +def group_and_aggregate(df, group_col, agg_col, agg_func='mean'): + """Group DataFrame and aggregate.""" + # BUG: Should handle non-numeric columns gracefully + if agg_func == 'mean': + return df.groupby(group_col)[agg_col].mean() + elif agg_func == 'sum': + return df.groupby(group_col)[agg_col].sum() + elif agg_func == 'count': + return df.groupby(group_col)[agg_col].count() + else: + raise ValueError(f"Unknown aggregation function: {agg_func}") diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..7a1ffec3ca53316168d0bdfafedc9b4b37cc1620 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "pandas__pandas-15230_medium_1", + "repo": "pandas-dev/pandas", + "problem_statement": "Fix DataFrame groupby aggregation", + "difficulty": "medium" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/test.py new file mode 100644 index 0000000000000000000000000000000000000000..79d657a5ffae4e0e40d6fcef8b02772af9c0841d --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/pandas__pandas-15230_medium_1/test.py @@ -0,0 +1,24 @@ +import unittest +import pandas as pd +from buggy import group_and_aggregate + +class TestGroupBy(unittest.TestCase): + def test_mean_aggregation(self): + """Test mean aggregation.""" + df = pd.DataFrame({ + 'category': ['A', 'A', 'B', 'B'], + 'value': [1, 2, 3, 4] + }) + result = group_and_aggregate(df, 'category', 'value', 'mean') + self.assertEqual(result['A'], 1.5) + self.assertEqual(result['B'], 3.5) + + def test_sum_aggregation(self): + """Test sum aggregation.""" + df = pd.DataFrame({ + 'category': ['A', 'A', 'B'], + 'value': [1, 2, 3] + }) + result = group_and_aggregate(df, 'category', 'value', 'sum') + self.assertEqual(result['A'], 3) + self.assertEqual(result['B'], 3) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..f12b10ba96c6d0e625a972721f1ff4a2e8431521 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/buggy.py @@ -0,0 +1,14 @@ +import os +import re + +def collect_tests(directory, pattern='test_*.py'): + """Collect test files from directory.""" + # BUG: Should sort files for consistent ordering + test_files = [] + + for root, dirs, files in os.walk(directory): + for file in files: + if re.match(pattern, file): + test_files.append(os.path.join(root, file)) + + return test_files diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..45ceedb8e1e259ee2b1dcd5c45eb2a37f6da93ec --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "pytest__pytest-7426_hard_1", + "repo": "pytest-dev/pytest", + "problem_statement": "Fix test collection order", + "difficulty": "hard" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8606ea03f4615dca8df44ec0a9ca05df676377db --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/pytest__pytest-7426_hard_1/test.py @@ -0,0 +1,27 @@ +import unittest +import os +import tempfile +from buggy import collect_tests + +class TestCollection(unittest.TestCase): + def test_collect_pattern(self): + """Test that correct pattern is matched.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test files + open(os.path.join(tmpdir, 'test_a.py'), 'w').close() + open(os.path.join(tmpdir, 'test_b.py'), 'w').close() + open(os.path.join(tmpdir, 'not_a_test.py'), 'w').close() + + tests = collect_tests(tmpdir, 'test_*.py') + self.assertEqual(len(tests), 2) + + def test_consistent_order(self): + """Test that file order is consistent.""" + with tempfile.TemporaryDirectory() as tmpdir: + for name in ['test_c.py', 'test_a.py', 'test_b.py']: + open(os.path.join(tmpdir, name), 'w').close() + + tests1 = collect_tests(tmpdir) + tests2 = collect_tests(tmpdir) + + self.assertEqual(tests1, tests2) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..52b1a5f0e11964ac7324e6ecc5a0415e3e60db41 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/buggy.py @@ -0,0 +1,14 @@ +import re +from urllib.parse import urlparse + +def match_cookie_domain(cookie_domain, request_domain): + """Check if cookie domain matches request domain.""" + # BUG: Should handle leading dots differently + # .example.com should match sub.example.com but not example.com + cookie_domain = cookie_domain.lower() + request_domain = request_domain.lower() + + if cookie_domain.startswith('.'): + return request_domain.endswith(cookie_domain) + + return cookie_domain == request_domain diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..b0994313548b705a5d8628e7ccba65dacdd98169 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "requests__requests-2875_easy_2", + "repo": "psf/requests", + "problem_statement": "Fix cookie domain matching", + "difficulty": "easy" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/test.py new file mode 100644 index 0000000000000000000000000000000000000000..17b849dc916e2f51d443a578fe2d0343bb33053a --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/requests__requests-2875_easy_2/test.py @@ -0,0 +1,17 @@ +import unittest +from buggy import match_cookie_domain + +class TestCookieDomain(unittest.TestCase): + def test_exact_match(self): + """Test exact domain matching.""" + self.assertTrue(match_cookie_domain('example.com', 'example.com')) + + def test_subdomain_with_dot(self): + """Test subdomain matching with leading dot.""" + # .example.com should match sub.example.com + self.assertTrue(match_cookie_domain('.example.com', 'sub.example.com')) + self.assertFalse(match_cookie_domain('.example.com', 'example.com')) + + def test_different_domains(self): + """Test different domains don't match.""" + self.assertFalse(match_cookie_domain('example.com', 'other.com')) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..6a30c0c59132b9e1c790f6d8ee10b8a0f2989321 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/buggy.py @@ -0,0 +1,14 @@ +import numpy as np +from scipy import signal + +def apply_lowpass_filter(data, cutoff, fs, order=5): + """Apply lowpass filter to data.""" + # BUG: Should validate cutoff frequency + nyquist = fs / 2 + normalized_cutoff = cutoff / nyquist + + # BUG: Using invalid cutoff can cause filter design failure + b, a = signal.butter(order, normalized_cutoff, btype='low') + filtered = signal.filtfilt(b, a, data) + + return filtered diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4bdfa7f2cf40df994703f4e19852c7a64de6daba --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "scipy__scipy-1925_medium_2", + "repo": "scipy/scipy", + "problem_statement": "Fix signal filtering edge case", + "difficulty": "medium" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/test.py new file mode 100644 index 0000000000000000000000000000000000000000..223d4001c126d9f2db0406d509b0800dd0124652 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/scipy__scipy-1925_medium_2/test.py @@ -0,0 +1,25 @@ +import unittest +import numpy as np +from buggy import apply_lowpass_filter + +class TestSignalFiltering(unittest.TestCase): + def test_valid_filter(self): + """Test filtering with valid parameters.""" + fs = 1000 # Sampling frequency + cutoff = 100 # Cutoff frequency + t = np.linspace(0, 1, fs) + data = np.sin(2 * np.pi * 50 * t) + 0.5 * np.sin(2 * np.pi * 200 * t) + + result = apply_lowpass_filter(data, cutoff, fs) + self.assertEqual(len(result), len(data)) + # Low frequency component should be preserved + self.assertTrue(np.abs(result[100]) > 0.5) + + def test_invalid_cutoff(self): + """Test that invalid cutoff raises error.""" + fs = 1000 + cutoff = 2000 # Above Nyquist frequency - should fail + data = np.array([1, 2, 3, 4, 5]) + + with self.assertRaises(ValueError): + apply_lowpass_filter(data, cutoff, fs) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a2e61d2f0efe2826bff0783c3b08444802b696 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/buggy.py @@ -0,0 +1,13 @@ +import numpy as np +from sklearn.model_selection import KFold + +def get_cv_splits(X, n_splits=5, shuffle=True, random_state=42): + """Get cross-validation splits.""" + # BUG: random_state should be used for reproducibility + kf = KFold(n_splits=n_splits, shuffle=shuffle) + + splits = [] + for train_idx, test_idx in kf.split(X): + splits.append((train_idx, test_idx)) + + return splits diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..24d27c0fb21fefda3dea85d9c415b4e85e360d05 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "sklearn__sklearn-12345_hard_0", + "repo": "scikit-learn/scikit-learn", + "problem_statement": "Fix cross-validation split", + "difficulty": "hard" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/test.py new file mode 100644 index 0000000000000000000000000000000000000000..14df293f932be7c755425467ffa48ab084ece791 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/sklearn__sklearn-12345_hard_0/test.py @@ -0,0 +1,20 @@ +import unittest +import numpy as np +from buggy import get_cv_splits + +class TestCVSplits(unittest.TestCase): + def test_split_count(self): + """Test that correct number of splits is generated.""" + X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) + splits = get_cv_splits(X, n_splits=3) + self.assertEqual(len(splits), 3) + + def test_reproducibility(self): + """Test that splits are reproducible with same random_state.""" + X = np.random.rand(100, 5) + splits1 = get_cv_splits(X, n_splits=5, random_state=42) + splits2 = get_cv_splits(X, n_splits=5, random_state=42) + + for (train1, test1), (train2, test2) in zip(splits1, splits2): + np.testing.assert_array_equal(train1, train2) + np.testing.assert_array_equal(test1, test2) diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/buggy.py b/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/buggy.py new file mode 100644 index 0000000000000000000000000000000000000000..bb67a6983c5c703fdfbdf3c93b702d8190153b10 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/buggy.py @@ -0,0 +1,14 @@ +from typing import List + +def tokenize_and_pad(tokenizer, texts: List[str], max_length: int = 512): + """Tokenize texts and pad to max length.""" + # BUG: Should handle padding correctly + encoded = tokenizer( + texts, + padding=True, # This pads to longest in batch, not max_length + truncation=True, + max_length=max_length, + return_tensors='pt' + ) + + return encoded diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/metadata.json b/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..cd56ec71bf0899ef7147dc8498ff77a4e83c4372 --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/metadata.json @@ -0,0 +1,6 @@ +{ + "instance_id": "transformers__transformers-12345_hard_2", + "repo": "huggingface/transformers", + "problem_statement": "Fix tokenization padding", + "difficulty": "hard" +} \ No newline at end of file diff --git a/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/test.py b/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/test.py new file mode 100644 index 0000000000000000000000000000000000000000..01cf2e04be741864d7f8d2375024a1f0dc9641ce --- /dev/null +++ b/rl_code_fix_env/dataset/swebench_lite_tasks/transformers__transformers-12345_hard_2/test.py @@ -0,0 +1,24 @@ +import unittest +from buggy import tokenize_and_pad + +class MockTokenizer: + def __call__(self, texts, padding=True, truncation=True, max_length=512, return_tensors=None): + # Simplified mock + return { + 'input_ids': [[1, 2, 3]] if isinstance(texts, list) else [1, 2, 3], + 'attention_mask': [[1, 1, 1]] if isinstance(texts, list) else [1, 1, 1] + } + +class TestTokenization(unittest.TestCase): + def test_single_text(self): + """Test tokenizing single text.""" + tokenizer = MockTokenizer() + result = tokenize_and_pad(tokenizer, ["hello world"]) + self.assertIn('input_ids', result) + + def test_max_length_respected(self): + """Test that max_length is respected.""" + tokenizer = MockTokenizer() + # Should not raise even with long text + result = tokenize_and_pad(tokenizer, ["short"], max_length=10) + self.assertIn('input_ids', result) diff --git a/rl_code_fix_env/dataset/task_manager.py b/rl_code_fix_env/dataset/task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f984dd9faf16cfe9129f26f745a0ec41951ccbaf --- /dev/null +++ b/rl_code_fix_env/dataset/task_manager.py @@ -0,0 +1,182 @@ +""" +Unified Task Manager: Abstractly load tasks from both local and SWE-bench datasets. + +This module provides a single interface to load tasks from: +1. Local hardcoded dataset (dataset/problem_1, problem_10, etc.) +2. SWE-bench Lite (if available and configured) + +Configuration via environment variables: + TASK_SOURCE "local" | "swebench" | "auto" (default: "auto") + SWEBENCH_FALLBACK "1" (enable fallback when SWE-bench fails, default: "1") + SWEBENCH_TASKS_ROOT Path to SWE-bench tasks directory + SWEBENCH_INDEX Preferred task index within difficulty band +""" + +import os +import logging +from pathlib import Path +from typing import Dict, Any, Optional, Literal + +from rl_code_fix_env.dataset.loader import get_hardcoded_task +from rl_code_fix_env.dataset.swebench_adapter import get_swebench_task + +logger = logging.getLogger(__name__) + +TaskSource = Literal["local", "swebench", "auto"] +Difficulty = Literal["easy", "medium", "hard"] + + +class TaskLoadError(Exception): + """Raised when task loading fails.""" + pass + + +class TaskManager: + """ + Unified interface for loading tasks from any dataset. + + Handles fallback logic, logging, and error recovery. + """ + + def __init__(self, source: Optional[TaskSource] = None): + """ + Initialize TaskManager. + + Args: + source: "local", "swebench", or "auto" (tries swebench first, falls back to local) + If None, reads from TASK_SOURCE env var (default: "auto") + """ + self.source = (source or os.getenv("TASK_SOURCE", "auto")).strip().lower() + self.enable_fallback = ( + os.getenv("SWEBENCH_FALLBACK", "1").strip().lower() in {"1", "true", "yes"} + ) + + if self.source not in {"local", "swebench", "auto"}: + raise ValueError( + f"Invalid TASK_SOURCE='{self.source}'. " + f"Must be one of: local, swebench, auto" + ) + + logger.info( + f"TaskManager initialized: source={self.source}, " + f"fallback_enabled={self.enable_fallback}" + ) + + def load_task(self, difficulty: Difficulty) -> Dict[str, Any]: + """ + Load a task by difficulty level. + + Args: + difficulty: "easy", "medium", or "hard" + + Returns: + Task dict with structure: + { + "code": str, # buggy Python code + "tests": str, # path to test.py + "metadata": dict, # source, repo, problem_statement, etc. + "problem_dir": str, # directory containing buggy.py and test.py + "problem_id": str, # unique identifier for this task + } + + Raises: + TaskLoadError: If no task can be loaded from any source + """ + difficulty = (difficulty or "").strip().lower() + if difficulty not in {"easy", "medium", "hard"}: + raise ValueError( + f"Invalid difficulty='{difficulty}'. Must be one of: easy, medium, hard" + ) + + # Strategy: try sources in order, with fallback if enabled + if self.source == "local": + return self._load_local(difficulty) + + elif self.source == "swebench": + return self._load_swebench(difficulty) + + else: # "auto" mode + logger.debug("Auto mode: trying SWE-bench first...") + swebench_error = None + try: + return self._load_swebench(difficulty) + except Exception as e: + swebench_error = str(e) + logger.debug(f"SWE-bench failed: {e}") + + if self.enable_fallback: + logger.info("SWE-bench unavailable, falling back to local dataset") + try: + return self._load_local(difficulty) + except Exception as local_error: + raise TaskLoadError( + f"Both SWE-bench and local fallback failed:\n" + f" SWE-bench: {swebench_error}\n" + f" Local: {local_error}" + ) from local_error + else: + raise TaskLoadError( + f"SWE-bench loading failed and fallback disabled: {swebench_error}" + ) + + def _load_local(self, difficulty: Difficulty) -> Dict[str, Any]: + """Load from local hardcoded dataset.""" + try: + task = get_hardcoded_task(difficulty) + task["metadata"]["source"] = "local" + logger.info(f"Loaded task from local dataset: {task.get('problem_id')}") + return task + except Exception as e: + error_msg = f"Failed to load from local dataset: {e}" + logger.warning(error_msg) + raise TaskLoadError(error_msg) from e + + def _load_swebench(self, difficulty: Difficulty) -> Dict[str, Any]: + """Load from SWE-bench Lite dataset.""" + try: + task = get_swebench_task(difficulty) + task["metadata"]["source"] = "swebench" + logger.info( + f"Loaded task from SWE-bench: {task.get('problem_id')} " + f"(repo: {task['metadata'].get('repo', '?')})" + ) + return task + except Exception as e: + error_msg = f"Failed to load from SWE-bench: {e}" + logger.debug(error_msg) + raise TaskLoadError(error_msg) from e + + +# Global singleton instance for backward compatibility +_default_manager: Optional[TaskManager] = None + + +def get_task_manager(source: Optional[TaskSource] = None) -> TaskManager: + """ + Get or create the default TaskManager instance. + + Args: + source: Override the source selection. If None, uses TASK_SOURCE env var. + + Returns: + TaskManager instance + """ + global _default_manager + if _default_manager is None or source is not None: + _default_manager = TaskManager(source=source) + return _default_manager + + +def load_task(difficulty: Difficulty, source: Optional[TaskSource] = None) -> Dict[str, Any]: + """ + Convenience function: load a task in one call. + + Args: + difficulty: "easy", "medium", or "hard" + source: Optional override for task source + + Returns: + Task dict + """ + manager = get_task_manager(source=source) + return manager.load_task(difficulty) diff --git a/rl_code_fix_env/inference.py b/rl_code_fix_env/inference.py index c757223c7eac7b0d0a00600ce2cd62b51053c43f..882829d5cee77fbd6f0b7fbfe18854efd9e65a01 100644 --- a/rl_code_fix_env/inference.py +++ b/rl_code_fix_env/inference.py @@ -34,6 +34,22 @@ from client import CodeFixerEnv from models import CodeFixerAction from dotenv import load_dotenv load_dotenv() +import logging + +logger = logging.getLogger("__warnings__") + +def _enable_hermetic_runtime() -> None: + """ + Keep this process isolated from parent/global Python environment leakage. + """ + if os.getenv("HERMETIC_RUN", "1").strip().lower() not in {"1", "true", "yes"}: + return + os.environ["PYTHONPATH"] = "" + os.environ["PYTHONNOUSERSITE"] = "1" + os.environ["PYTHONDONTWRITEBYTECODE"] = "1" + os.environ.pop("PYTHONHOME", None) + +_enable_hermetic_runtime() API_BASE_URL = os.getenv("API_BASE_URL", "https://integrate.api.nvidia.com/v1") API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") @@ -43,7 +59,7 @@ MAX_STEPS = int(os.getenv("MAX_STEPS", "10")) TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7")) MAX_TOKENS = int(os.getenv("MAX_TOKENS", "512")) -SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_THRESHOLD", "1.0")) +SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_THRESHOLD", "0.5")) MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3")) _DIFFICULTIES = ["easy", "medium", "hard"] @@ -207,6 +223,7 @@ def env_step(action_type: str, payload: str = "") -> dict: def _strip_markdown_fences(text: str) -> str: cleaned = (text or "").strip() + # Remove opening fence with optional language if cleaned.startswith("```"): cleaned = re.sub(r"^```[a-zA-Z0-9_-]*\s*\n", "", cleaned) cleaned = re.sub(r"\n```$", "", cleaned.strip()) @@ -236,14 +253,80 @@ def _build_unified_diff(original: str, revised: str) -> str: return "\n".join(lines).strip() +def _looks_like_python_code(text: str) -> bool: + """Check if text looks like Python code (not a diff, not prose).""" + text_lower = text.lower().strip() + # Not a diff + if _looks_like_unified_diff(text): + return False + # Not just "no change" or similar + if text_lower in ("no_change", "no change", "no changes", "no changes made", + "the code is correct", "already correct", "no bugs found", + "no bug", "no bugs", "already fixed"): + return False + # Check for Python-like patterns (def, class, import, etc.) + python_indicators = ['def ', 'class ', 'import ', 'from ', 'return ', + 'if __name__', 'async def', 'self.', 'print('] + return any(indicator in text for indicator in python_indicators) + + +def _looks_like_output_value(text: str) -> bool: + """Check if text is just an output value like [[3,1],[4,2]]""" + text = text.strip() + # Looks like a literal value/expression, not code + if text.startswith('[') and text.endswith(']') and '\n' not in text: + # Could be a list/dict literal + if ',' in text or ':' in text: + return True + # Single word or simple expression + if '\n' not in text and len(text.split()) <= 3: + # Check if it's not a function definition + if 'def ' not in text and 'class ' not in text: + return True + return False + + def _normalize_action(raw_action: str, original_code: str) -> str: + """Normalize LLM output to unified diff format.""" cleaned = _strip_markdown_fences(raw_action) - if not cleaned or cleaned == "NO_CHANGE": + + # Handle empty or NO_CHANGE cases + if not cleaned: return "" + + cleaned_lower = cleaned.lower().strip() + if cleaned_lower == "no_change": + return "" + + # Check for common "no change" phrases + no_change_phrases = [ + "no change", "no changes", "no changes made", + "the code is correct", "already correct", "no bugs found", + "no bug", "no bugs", "already fixed", "no modifications needed", + "the provided code is already correct", "code appears to be correct" + ] + if any(phrase in cleaned_lower for phrase in no_change_phrases): + return "" + + # Check if output looks like a raw value (not code) - reject it + if _looks_like_output_value(cleaned): + print(f"[DEBUG] Rejected output that looks like a value: {cleaned[:50]}...", + flush=True, file=sys.stderr) + return "" + + # Check if it looks like a unified diff already if _looks_like_unified_diff(cleaned): return cleaned if _is_valid_unified_diff(cleaned) else "" - generated = _build_unified_diff(original_code, cleaned) - return generated if _is_valid_unified_diff(generated) else "" + + # Check if it looks like Python code - convert to diff + if _looks_like_python_code(cleaned): + generated = _build_unified_diff(original_code, cleaned) + return generated if _is_valid_unified_diff(generated) else "" + + # If nothing matches, return empty (invalid output) + print(f"[DEBUG] Could not parse LLM output as code or diff: {cleaned[:100]}...", + flush=True, file=sys.stderr) + return "" def get_action(observation: dict, history: List[str], step: int) -> str: @@ -400,4 +483,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/rl_code_fix_env/openenv_rl_code_fix_env.egg-info/SOURCES.txt b/rl_code_fix_env/openenv_rl_code_fix_env.egg-info/SOURCES.txt index 2fdb785f7de6064e51db5b731ab3f64bb180d29b..b842017e0ee308d9f0da03d4a65e13a1de151d5c 100644 --- a/rl_code_fix_env/openenv_rl_code_fix_env.egg-info/SOURCES.txt +++ b/rl_code_fix_env/openenv_rl_code_fix_env.egg-info/SOURCES.txt @@ -1,18 +1,24 @@ README.md __init__.py +_aliases.py client.py +conftest.py inference.py models.py -prompts.py pyproject.toml ./__init__.py +./_aliases.py ./client.py +./conftest.py ./inference.py ./models.py -./prompts.py dataset/README.md dataset/__init__.py +dataset/generate_swebench_tasks.py dataset/loader.py +dataset/prepare_swebench.py +dataset/swebench_adapter.py +dataset/task_manager.py dataset/tasks.py dataset/problem_1/buggy.py dataset/problem_1/metadata.json @@ -103,6 +109,7 @@ src/environment/environment.py src/reward/llm_scorer.py src/reward/reward.py src/reward/trace_scorer.py +src/reward/trajectory_logger.py src/sandbox/__init__.py src/sandbox/execution.py src/sandbox/patcher.py diff --git a/rl_code_fix_env/pyproject.toml b/rl_code_fix_env/pyproject.toml index 9fe0b7cfba4b438e42ad42186ab28ce266b634ad..65d9279ab90e3d2a3684cbbb7215acd4209b34bf 100644 --- a/rl_code_fix_env/pyproject.toml +++ b/rl_code_fix_env/pyproject.toml @@ -77,3 +77,10 @@ timeout = 25 [tool.setuptools.package-data] "rl_code_fix_env.dataset" = ["problem_*/*", "README.md", "tasks.py"] + +[tool.setuptools.exclude-package-data] +# Prevent conftest.py and _aliases.py from being installed into site-packages. +# If they land there AND the source tree is on sys.path, pytest raises +# ImportPathMismatchError because it finds the same module at two paths. +# The aliases are handled by src/dataset/__init__.py which IS packaged. +"rl_code_fix_env" = ["conftest.py", "_aliases.py"] diff --git a/rl_code_fix_env/server/Dockerfile b/rl_code_fix_env/server/Dockerfile index 949c844800db0b2e5326ddf67f1d72514d7404f1..ba8ff54909dbccd737a5d2a74cebf884706dce75 100644 --- a/rl_code_fix_env/server/Dockerfile +++ b/rl_code_fix_env/server/Dockerfile @@ -1,14 +1,3 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Multi-stage build using openenv-base -# This Dockerfile is flexible and works for both: -# - In-repo environments (with local OpenEnv sources) -# - Standalone environments (with openenv from PyPI/Git) -# The build script (openenv build) handles context detection and sets appropriate build args. ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest FROM ${BASE_IMAGE} AS builder @@ -33,25 +22,25 @@ WORKDIR /app/env # Ensure uv is available (for local builds where base image lacks it) RUN if ! command -v uv >/dev/null 2>&1; then \ - curl -LsSf https://astral.sh/uv/install.sh | sh && \ - mv /root/.local/bin/uv /usr/local/bin/uv && \ - mv /root/.local/bin/uvx /usr/local/bin/uvx; \ + curl -LsSf https://astral.sh/uv/install.sh | sh && \ + mv /root/.local/bin/uv /usr/local/bin/uv && \ + mv /root/.local/bin/uvx /usr/local/bin/uvx; \ fi - + # Install dependencies using uv sync # If uv.lock exists, use it; otherwise resolve on the fly RUN --mount=type=cache,target=/root/.cache/uv \ if [ -f uv.lock ]; then \ - uv sync --frozen --no-install-project --no-editable; \ + uv sync --frozen --no-install-project --no-editable; \ else \ - uv sync --no-install-project --no-editable; \ + uv sync --no-install-project --no-editable; \ fi RUN --mount=type=cache,target=/root/.cache/uv \ if [ -f uv.lock ]; then \ - uv sync --frozen --no-editable; \ + uv sync --frozen --no-editable; \ else \ - uv sync --no-editable; \ + uv sync --no-editable; \ fi # Final runtime stage @@ -59,17 +48,19 @@ FROM ${BASE_IMAGE} WORKDIR /app -# Copy the virtual environment from builder -COPY --from=builder /app/env/.venv /app/.venv - -# Copy the environment code +# Copy environment code + its in-place virtualenv from builder. +# Keep the venv at the same path it was created with (/app/env/.venv) +# to avoid relocation issues and dual-venv path conflicts. COPY --from=builder /app/env /app/env -# Set PATH to use the virtual environment -ENV PATH="/app/.venv/bin:$PATH" +# Use the single in-repo venv +ENV VIRTUAL_ENV="/app/env/.venv" +ENV PATH="/app/env/.venv/bin:$PATH" -# Set PYTHONPATH so imports work correctly -ENV PYTHONPATH="/app/env:$PYTHONPATH" +# Hermetic runtime: keep imports pinned to repo code + active venv. +ENV PYTHONPATH="/app/env" +ENV PYTHONNOUSERSITE="1" +ENV PYTHONDONTWRITEBYTECODE="1" # Health check HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ diff --git a/rl_code_fix_env/server/requirements.txt b/rl_code_fix_env/server/requirements.txt index 52ac7a0977ee6beb4e14af9d08073cea482a45df..ab69e7993748cfb3fe2ea437fa75872e363fd898 100644 --- a/rl_code_fix_env/server/requirements.txt +++ b/rl_code_fix_env/server/requirements.txt @@ -27,4 +27,10 @@ pytest-json-report unidiff diff-match-patch openai -python-dotenv \ No newline at end of file +python-dotenv + +# Test dependencies for SWE-bench tasks +Django>=3.2 +Flask>=2.0 +scipy>=1.5 +scikit-learn>=0.24 \ No newline at end of file diff --git a/rl_code_fix_env/src/dataset/__init__.py b/rl_code_fix_env/src/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de99ed1bcbddb46c084a1116d479cf58bc5b1666 --- /dev/null +++ b/rl_code_fix_env/src/dataset/__init__.py @@ -0,0 +1,54 @@ +""" +Alias the real dataset package so that `import src.dataset` works. + +This module ensures that `import src.dataset` and +`from src.dataset.problem_X.buggy import ...` statements work correctly +by aliasing the real dataset package and all its nested submodules. + +This is needed because: +- Test files import from `src.dataset.problem_X.buggy` +- But the actual files are in `dataset/problem_X/buggy.py` +- So we create sys.modules aliases to bridge the gap +""" +import sys +import importlib +import pkgutil + +try: + import dataset as _real_dataset +except ImportError: + # dataset package not on sys.path — try to add the env root + import os + from pathlib import Path + + _here = Path(__file__).parent # src/dataset/ + # Walk up to find the directory that contains the `dataset/` folder + for _candidate in [_here.parent.parent, _here.parent.parent.parent]: + if (_candidate / "dataset").is_dir(): + sys.path.insert(0, str(_candidate)) + break + import dataset as _real_dataset + +# Register top-level alias +sys.modules["src.dataset"] = _real_dataset + +# Register every direct subpackage / submodule and their children +for _pkg in pkgutil.iter_modules(_real_dataset.__path__): + _full = f"dataset.{_pkg.name}" + _alias = f"src.dataset.{_pkg.name}" + try: + _mod = importlib.import_module(_full) + sys.modules.setdefault(_alias, _mod) + + # Also register nested submodules (e.g. problem_X.buggy, problem_X.helpers) + if hasattr(_mod, "__path__"): + for _sub in pkgutil.iter_modules(_mod.__path__): + _sub_full = f"{_full}.{_sub.name}" + _sub_alias = f"{_alias}.{_sub.name}" + try: + _sub_mod = importlib.import_module(_sub_full) + sys.modules.setdefault(_sub_alias, _sub_mod) + except Exception: + pass + except Exception: + pass diff --git a/rl_code_fix_env/src/environment/environment.py b/rl_code_fix_env/src/environment/environment.py index c7d28f61a63ecd5608fc5934190facf29c03f9bd..81daeff43412ffed021d5b136b285eff3c1f4f0c 100644 --- a/rl_code_fix_env/src/environment/environment.py +++ b/rl_code_fix_env/src/environment/environment.py @@ -49,6 +49,7 @@ class CodeEnv: "workspace": self._data["problem_dir"], "logs": startup_log, "test_score": 0.0, + "prev_test_score": 0.0, # Track previous score for regression penalty "passed": 0, "total": 1, } @@ -67,6 +68,7 @@ class CodeEnv: (obs_dict, reward, done, info) """ prev_state = self._state.copy() + prev_test_score = self._state.get("test_score", 0.0) action_type = action.get("type", "") payload = action.get("payload") or "" @@ -89,10 +91,24 @@ class CodeEnv: test_file=self._state["test_path"], workspace_dir=self._state["workspace"], ) - self._state["passed"] = int(passed) - self._state["total"] = 1 - self._state["test_score"] = 1.0 if passed else 0.0 - self._state["logs"] = logs + + # Parse test counts from logs if available (format: [TEST_COUNTS] passed=X total=Y) + import re + test_counts_match = re.search(r'\[TEST_COUNTS\]\s+passed=(\d+)\s+total=(\d+)', logs) + if test_counts_match: + passed_count = int(test_counts_match.group(1)) + total_count = int(test_counts_match.group(2)) + self._state["passed"] = passed_count + self._state["total"] = max(total_count, 1) + # Calculate partial score: passed/total (range 0.0 to 1.0) + self._state["test_score"] = passed_count / max(total_count, 1) + else: + # Fallback to binary scoring if counts not found + self._state["passed"] = 1 if passed else 0 + self._state["total"] = 1 + self._state["test_score"] = 1.0 if passed else 0.0 + + self._state["logs"] = logs elif action_type == "get_logs": self._state["logs"] = self._state.get("logs") or "No logs yet. Run tests first." @@ -108,7 +124,9 @@ class CodeEnv: self._state["passed"] >= self._state["total"] # all tests pass or self.steps >= self.max_steps # step budget exhausted ) - + if action_type == "apply_patch": + self._state["last_action_empty"] = not (payload and payload.strip()) + last_action_empty = self._state.get("last_action_empty", False) try: reward = compute_reward( test_score=self._state["test_score"], @@ -116,6 +134,8 @@ class CodeEnv: code=self._state["code"], steps_taken=self.steps, max_steps=self.max_steps, + prev_test_score=prev_test_score, # Pass for regression penalty + last_action_empty=last_action_empty, ) if self._state["passed"] >= self._state["total"]: reward = 1.0 diff --git a/rl_code_fix_env/src/reward/llm_scorer.py b/rl_code_fix_env/src/reward/llm_scorer.py index 7f7b94d6db2f16de8c83308555eba248daec6874..d09e0484ee1d5a694d552545f67210208439689c 100644 --- a/rl_code_fix_env/src/reward/llm_scorer.py +++ b/rl_code_fix_env/src/reward/llm_scorer.py @@ -1,54 +1,29 @@ """ -LLM score that prompts an LLM-as-a-judge to evaluate the final fixed code for readability, style, and efficiency, returning a score between 0.0 and 1.0. +LLM scorer: evaluates code fixes for correctness, minimality, and quality. Returns a score 0.0–1.0. """ import os -import uuid import re import json import hashlib from typing import Optional -# from rl_code_fix_env.prompts import LLM_SCORER_PROMPT, USER_TEMPLATE -LLM_SCORER_PROMPT = r"""You are an expert code quality judge for reinforcement learning reward system. Your evaluation directly shape an agent's learning signal, so be precise and consistent. - - Evaluate code on exactly three axes, each scored 0.010.0: - - 1. READABILITY naming clarity, comment quality, logical flow, cognitive load - 2. STYLE PEP8/language conventions, consistent formatting, idiomatic patterns - 3. EFFICIENCY algorithmic complexity, redundant ops, memory usage, avoidable loops - - Respond ONLY with this JSON structure, no preamble: - { - "readability": , - "style": , - "efficiency": , - "reasoning": "" - } - """ - -USER_TEMPLATE = """Evaluate this code: -```python -{code} -``` - -Return only the JSON.""" - +from dotenv import load_dotenv +load_dotenv() +from prompts import LLM_SCORER_PROMPT, USER_TEMPLATE from openai import OpenAI -WEIGHTS = {"readability": 0.35, - "style":0.30, - "efficiency": 0.35} +WEIGHTS = {"correctness": 0.4, "minimality": 0.3, "quality": 0.3} SYSTEM_PROMPT = LLM_SCORER_PROMPT USER_TEMPLATE = USER_TEMPLATE _score_cache: dict[str, float] = {} -def _cache_key(code: str, bug_id: Optional[str] = None)->str: +def _cache_key(code: str, bug_id: Optional[str] = None) -> str: content_to_hash = f"{bug_id}: {code}" if bug_id else code return hashlib.sha256(content_to_hash.encode()).hexdigest() -def _parse_scores(raw: str)->Optional[dict]: +def _parse_scores(raw: str) -> Optional[dict]: cleaned = re.sub(r"```(?:json)?|```", "", raw).strip() try: return json.loads(cleaned) @@ -62,7 +37,7 @@ def _parse_scores(raw: str)->Optional[dict]: pass return None -def _aggregate(scores: dict)->float: +def _aggregate(scores: dict) -> float: raw = sum(WEIGHTS[k] * float(scores[k]) for k in WEIGHTS if k in scores) return round(min(max(raw / 10.0, 0.0), 1.0), 4) @@ -79,8 +54,6 @@ def score_code_quality( if use_cache and key in _score_cache: return _score_cache[key] - from dotenv import load_dotenv - load_dotenv() api_key = os.getenv("API_KEY") api_base_url = os.getenv("API_BASE_URL", "https://integrate.api.nvidia.com/v1") model_name = os.getenv("MODEL_NAME", "qwen/qwen2.5-coder-32b-instruct") @@ -95,11 +68,10 @@ def score_code_quality( model=model_name, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": USER_TEMPLATE.format(code=code)}, + {"role": "user", "content": USER_TEMPLATE.format(original_code=code)}, ], - temperature=0.0, - top_p=0.8, - max_tokens=512, + temperature=float(os.getenv("TEMPERATURE", "0.7")), + max_tokens=256, ) raw = completion.choices[0].message.content or "" parsed = _parse_scores(raw) diff --git a/rl_code_fix_env/src/reward/reward.py b/rl_code_fix_env/src/reward/reward.py index 864281c48a39c61a1a4c38f250370b1afc39e472..e56b6b972b172bc7b75c53467029fd948069e8b6 100644 --- a/rl_code_fix_env/src/reward/reward.py +++ b/rl_code_fix_env/src/reward/reward.py @@ -1,36 +1,45 @@ from ..reward.trace_scorer import score_trace -def compute_reward(test_score, trace_obj, code, steps_taken, max_steps): +def compute_reward(test_score, trace_obj, code, steps_taken, max_steps, prev_test_score=0.0, last_action_empty=False): """ Compute reward for code fixing episode. Args: - test_score: Test execution score (0.0-1.0) - trace_obj: TraceCollector object with action history - code: Fixed code string for quality evaluation - steps_taken: Number of steps taken - max_steps: Maximum steps allowed + test_score: Test execution score (0.0-1.0) + trace_obj: TraceCollector object with action history + code: Fixed code string for quality evaluation + steps_taken: Number of steps taken + max_steps: Maximum steps allowed + prev_test_score: Previous test score (for regression penalty) + last_action_empty: Whether the last action was empty/no-op Returns: Reward score in [0.0, 1.0] """ - # 1. Functional Progress (70% weight) — only real signal + # If last action was empty/no-op, give minimal reward to encourage meaningful actions + if last_action_empty: + return 0.0 + + # 1. Functional Progress (90% weight) — primary signal functional_reward = float(test_score) - # 2. Reasoning Quality (20% weight) - trace_reward = score_trace(trace_obj) if trace_obj else 0.0 + # 1b. Regression Penalty: penalize when test score decreases + # This encourages the agent to not make things worse + test_score_delta = test_score - prev_test_score + regression_penalty = 0.0 + if test_score_delta < 0: + # Penalize proportionally to how much the score dropped + regression_penalty = abs(test_score_delta) * 0.1 # 10% penalty for regression - # 3. Deterministic quality proxy (10% weight) - quality_reward = 1.0 if code and code.strip() else 0.0 - efficiency_penalty = 0.05 * (steps_taken / max(max_steps, 1)) + # 2. Reasoning Quality (10% weight) - bonus for good reasoning trace + trace_reward = max(0.0, score_trace(trace_obj) if trace_obj else 0.0) # Ensure non-negative - # Weighted sum — coefficients sum to 1.0 before penalty + # Weighted sum — coefficients sum to 1.0 before penalties reward = ( - 0.7 * functional_reward - + 0.2 * trace_reward - + 0.1 * quality_reward - - efficiency_penalty + 0.9 * functional_reward + + 0.1 * trace_reward + - regression_penalty ) return max(0.0, min(1.0, reward)) \ No newline at end of file diff --git a/rl_code_fix_env/src/reward/trajectory_logger.py b/rl_code_fix_env/src/reward/trajectory_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..db1415a3fe083d74a2eeb803405bbb2f69be6b2a --- /dev/null +++ b/rl_code_fix_env/src/reward/trajectory_logger.py @@ -0,0 +1,95 @@ +""" +FIX 9: Trajectory logging for GRPO training data collection. + +Per rulebook Section 5 & 6: Save episode trajectories to enable GRPO training. +Each episode is saved as JSON with metadata, summary, and full trajectory. +""" + +import json +import os +from pathlib import Path +from datetime import datetime +from typing import Dict, List, Any, Optional + + +class TrajectoryLogger: + """Save episode trajectories for GRPO training.""" + + def __init__(self, output_dir: str = "./episodes"): + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + def save_episode( + self, + task: str, + difficulty: str, + success: bool, + steps: int, + rewards: List[float], + trajectory: List[Dict[str, Any]], + model: str = "unknown", + elapsed_s: float = 0.0, + ) -> str: + """ + Save one episode to JSON for GRPO training. + + Args: + task: Task identifier (e.g., "easy", "problem_1") + difficulty: Difficulty level (easy/medium/hard) + success: Whether episode succeeded + steps: Number of steps taken + rewards: List of rewards per step + trajectory: List of {observation, action, reward, done, test_score} + model: Model name used + elapsed_s: Total episode time + + Returns: + Path to saved episode file + """ + episode = { + "metadata": { + "task": task, + "difficulty": difficulty, + "success": success, + "model": model, + "timestamp": datetime.now().isoformat(), + "elapsed_s": round(elapsed_s, 3), + }, + "summary": { + "steps": steps, + "rewards": [round(r, 4) for r in rewards], + "final_reward": round(rewards[-1], 4) if rewards else 0.0, + "mean_reward": round(sum(rewards) / len(rewards), 4) if rewards else 0.0, + "max_reward": round(max(rewards), 4) if rewards else 0.0, + }, + "trajectory": trajectory, + } + + # Filename: difficulty_timestamp.json + timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S') + filename = f"{difficulty}_{task}_{timestamp_str}.json" + filepath = self.output_dir / filename + + with open(filepath, 'w') as f: + json.dump(episode, f, indent=2) + + return str(filepath) + + @staticmethod + def load_episodes(output_dir: str = "./episodes") -> List[Dict[str, Any]]: + """Load all saved episodes from directory.""" + episodes = [] + episode_dir = Path(output_dir) + + if not episode_dir.exists(): + return episodes + + for json_file in sorted(episode_dir.glob("*.json")): + try: + with open(json_file, 'r') as f: + episode = json.load(f) + episodes.append(episode) + except Exception as e: + print(f"Failed to load {json_file}: {e}") + + return episodes diff --git a/rl_code_fix_env/src/sandbox/execution.py b/rl_code_fix_env/src/sandbox/execution.py index ed9ca0d606e502039b93e5728bec4c1c957bb9cc..70176fa2571d0674f52f74ba75586be45e9d4e6c 100644 --- a/rl_code_fix_env/src/sandbox/execution.py +++ b/rl_code_fix_env/src/sandbox/execution.py @@ -22,6 +22,52 @@ def _docker_available() -> bool: PYTEST_REPORT_FILENAME = ".pytest_error_report.json" + +def _parse_pytest_counts(logs: str) -> Tuple[Optional[int], Optional[int]]: + """ + Parse pytest output to extract pass/fail counts. + + Returns: + (passed_count, total_count) or (None, None) if not found + """ + import re + + # Try to find patterns like "5 passed" or "3 passed, 2 failed" + # Also matches "5 passed, 1 failed" or "1 passed" + patterns = [ + r'(\d+)\s+passed', # "5 passed" + r'(\d+)\s+passed.*?(\d+)\s+failed', # "3 passed, 2 failed" + r'(\d+)\s+failed', # "0 passed, 5 failed" - should handle this + ] + + # Look for summary line like "====== 5 passed, 2 failed ======" + summary_match = re.search(r'(\d+)\s+passed[,.]?\s*(\d+)?\s*failed?', logs) + if summary_match: + passed = int(summary_match.group(1)) + failed = int(summary_match.group(2)) if summary_match.group(2) else 0 + return passed, passed + failed + + # Look for "X passed" only + passed_match = re.search(r'(\d+)\s+passed', logs) + if passed_match: + passed = int(passed_match.group(1)) + # Check if there's also failed count + failed_match = re.search(r'(\d+)\s+failed', logs) + if failed_match: + failed = int(failed_match.group(1)) + return passed, passed + failed + else: + # Assume all tests passed if only "X passed" and no failed + return passed, passed + + # Look for "X failed" only without passed + failed_only_match = re.search(r'(\d+)\s+failed', logs) + if failed_only_match: + failed = int(failed_only_match.group(1)) + return 0, failed + + return None, None + def _build_docker_cmd(workspace_dir: str, test_file: str, extra_flags: List[str] = []) -> List[str]: """ Centralised Docker command builder shared by run_test_file and @@ -66,37 +112,18 @@ def run_test_file( workspace_dir: str, ) -> Tuple[bool, str]: """ - Execute a pytest test file, choosing the best available runner: - - 1. **Docker sandbox** (preferred when docker CLI is present on the host) - ephemeral container with network=none, memory cap, read-only FS. - 2. **Direct subprocess** fallback (used when running inside a container - where there is no nested Docker daemon, e.g. on HF Spaces). - writes current code to a temp dir and runs `python -m pytest` directly. + Execute a pytest test file using direct subprocess execution. + + IMPORTANT: Direct execution is ALWAYS used (never Docker) because: + 1. Docker containers don't have conftest.py for module aliasing (src.dataset issue) + 2. Docker containers don't have access to the full project structure + 3. Direct execution allows proper test discovery and import resolution + 4. This enables SWE-bench task compatibility Returns: (passed, logs) """ - if _docker_available(): - passed, logs = _run_in_docker(code_file, test_file, workspace_dir) - docker_unavailable_markers = ( - "permission denied while trying to connect to the docker api", - "error during connect", - "cannot connect to the docker daemon", - "is the docker daemon running", - "dockerdesktoplinuxengine", - "dockerdesktopwindowsengine", - "npipe://", - ) - lower_logs = (logs or "").lower() - if (not passed) and any(marker in lower_logs for marker in docker_unavailable_markers): - fallback_passed, fallback_logs = _run_direct(code_file, test_file, workspace_dir) - merged_logs = ( - "Docker unavailable; fell back to direct pytest execution.\n" - f"[docker]\n{logs}\n\n[direct]\n{fallback_logs}" - ) - return fallback_passed, merged_logs - return passed, logs + # Always use direct execution - skip Docker entirely for SWE-bench compatibility return _run_direct(code_file, test_file, workspace_dir) @@ -125,6 +152,14 @@ def _run_in_docker( ) passed = result.returncode == 0 logs = result.stdout + "\n" + result.stderr + + # Parse actual pass/fail counts from pytest output + passed_count, total_count = _parse_pytest_counts(logs) + + # If we have counts, return them encoded in the logs for the environment to parse + if passed_count is not None and total_count is not None: + logs = f"[TEST_COUNTS] passed={passed_count} total={total_count}\n" + logs + return passed, logs except subprocess.TimeoutExpired: return False, "CRITICAL ERROR: Execution timed out." @@ -159,18 +194,15 @@ def _run_direct( break repo_root = repo_root.parent - # Build a clean subprocess environment with PYTHONPATH pointing at repo root. - # This ensures `from src.dataset.problem_X.buggy import ...` always resolves, - # even if the parent process has a different (or missing) PYTHONPATH. + # Build a hermetic subprocess environment. We intentionally avoid inheriting + # Python import path state from parent shells/venvs to prevent dual-site-packages + # import collisions (e.g. ImportPathMismatchError on conftest). subprocess_env = os.environ.copy() - existing_pythonpath = subprocess_env.get("PYTHONPATH", "") repo_root_str = str(repo_root) - if repo_root_str not in existing_pythonpath: - subprocess_env["PYTHONPATH"] = ( - repo_root_str + os.pathsep + existing_pythonpath - if existing_pythonpath - else repo_root_str - ) + subprocess_env["PYTHONPATH"] = repo_root_str + subprocess_env["PYTHONNOUSERSITE"] = "1" + subprocess_env["PYTHONDONTWRITEBYTECODE"] = "1" + subprocess_env.pop("PYTHONHOME", None) # Back up original buggy.py and overwrite with patched code original_code: Optional[str] = None @@ -183,6 +215,17 @@ def _run_direct( report_file = workspace_path / PYTEST_REPORT_FILENAME + # Import conftest first to set up src.dataset alias before running pytest + # This ensures imports like "from src.dataset.problem_X.buggy import ..." work + conftest_import_cmd = [ + "python", "-c", + f"import sys; sys.path.insert(0, r'{repo_root_str}'); import conftest" + ] + try: + subprocess.run(conftest_import_cmd, capture_output=True, timeout=10, cwd=str(repo_root)) + except Exception: + pass # Continue even if pre-import fails + cmd = [ "python", "-m", "pytest", str(test_path), # absolute path to test.py @@ -205,6 +248,14 @@ def _run_direct( ) passed = result.returncode == 0 logs = result.stdout + ("\n" + result.stderr if result.stderr else "") + + # Parse actual pass/fail counts from pytest output + passed_count, total_count = _parse_pytest_counts(logs) + + # If we have counts, prepend them to logs for the environment to parse + if passed_count is not None and total_count is not None: + logs = f"[TEST_COUNTS] passed={passed_count} total={total_count}\n" + logs + return passed, logs.strip() or "(no output)" except subprocess.TimeoutExpired: diff --git a/rl_code_fix_env/src/sandbox/patcher.py b/rl_code_fix_env/src/sandbox/patcher.py index 3a8889cead3eb2f376cc6a29e85d9f6b4efc1dfd..47708c011890ba4382c134eeff894fcbead74d5a 100644 --- a/rl_code_fix_env/src/sandbox/patcher.py +++ b/rl_code_fix_env/src/sandbox/patcher.py @@ -1,9 +1,26 @@ +import ast import unidiff import diff_match_patch as dmp_module from dataclasses import dataclass from typing import List, Tuple, Optional +def validate_python_syntax(code: str) -> Tuple[bool, Optional[str]]: + """ + Validate that code string is valid Python by parsing with AST. + + Returns: + (is_valid, error_message) + """ + try: + ast.parse(code) + return True, None + except SyntaxError as e: + return False, f"SyntaxError: {e.msg} at line {e.lineno}, column {e.offset}" + except Exception as e: + return False, f"ParseError: {str(e)}" + + @dataclass class HunkResult: hunk_index: int @@ -128,6 +145,19 @@ def apply_patch( hunk_idx += 1 + # Validate the final patched code is valid Python + is_valid, error_msg = validate_python_syntax(curr_code) + if not is_valid: + # Return original code if patched code is invalid Python + return code, [HunkResult( + hunk_index=0, + source_file="validation", + applied=False, + confidence=0.0, + location_found=0, + failed_reason=f"Invalid Python after patch: {error_msg}", + )] + return curr_code, results @@ -136,15 +166,15 @@ def _reconstruct_from_hunk( include_added: bool = True, include_removed: bool = True, ) -> str: - lines = [] + res = "" for line in hunk: if line.line_type == ' ': - lines.append(line.value.rstrip('\n')) + res += line.value elif line.line_type == '-' and include_removed: - lines.append(line.value.rstrip('\n')) + res += line.value elif line.line_type == '+' and include_added: - lines.append(line.value.rstrip('\n')) - return '\n'.join(lines) + res += line.value + return res def _line_to_char(text: str, line_idx: int) -> int: diff --git a/rl_code_fix_env/test_patch.py b/rl_code_fix_env/test_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2e5d31bc25c2b80e9961da2b7289d2188a70b0 --- /dev/null +++ b/rl_code_fix_env/test_patch.py @@ -0,0 +1,9 @@ +import diff_match_patch as dmp_module +dmp_init = dmp_module.diff_match_patch() +old = ' \"\"\"Rotate matrix 90 degrees clockwise.\"\"\"\n t = transpose(matrix)\n # BUG: this is counter-clockwise.\n return t[::-1]\n' +new = ' \"\"\"Rotate matrix 90 degrees clockwise.\"\"\"\n t = transpose(matrix)\n # BUG: this is counter-clockwise.\n return [row[::-1] for row in t]\n' +diffs = dmp_init.diff_main(old, new) +dmp_init.diff_cleanupSemantic(diffs) +patches = dmp_init.patch_make(old, diffs) +patched, res = dmp_init.patch_apply(patches, old) +print('patched:', repr(patched))