# shinka-backup/tests/test_edit_base.py
# Uploaded by JustinTX: "Add files using upload-large-folder tool" (commit 1ca9dbd, verified)
from shinka.edit import apply_diff_patch, apply_full_patch
from shinka.edit.apply_diff import (
_find_indented_match,
_apply_indentation_to_replace,
_strip_trailing_whitespace,
)
# Two-hunk SEARCH/REPLACE patch used by test_edit: the first hunk replaces the
# run_experiment() definition with a placeholder line, and the second hunk then
# edits the text the first hunk inserted, so it only matches if hunks are
# applied in order.
# NOTE(review): the first SEARCH must match tests/file.py (up to the base
# indentation correction done by apply_diff_patch). The function body below
# shows no indentation or blank lines, which looks like whitespace lost in a
# copy of this file — confirm against tests/file.py.
patch_str = """
<<<<<<< SEARCH
def run_experiment(train_dataset, device):
epochs = 5
batch_size = 64
learning_rate = 0.01
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Initialize model, loss function, and optimizer
model = MNISTNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# Training loop
for epoch in range(1, epochs + 1):
train(model, device, train_loader, optimizer, criterion, epoch)
return model
=======
THIS IS A TEST
>>>>>>> REPLACE
<<<<<<< SEARCH
THIS IS A TEST
=======
THIS IS A TEST PART 2
>>>>>>> REPLACE
"""
# Content test_edit expects after both hunks are applied to tests/file.py.
new_str = """# EVOLVE-BLOCK-START
THIS IS A TEST PART 2
# EVOLVE-BLOCK-END"""
def test_edit():
    """Apply the two-hunk ``patch_str`` to tests/file.py and check the result.

    The second hunk searches for the text inserted by the first one, so this
    also checks that hunks are applied sequentially.
    """
    from pathlib import Path

    # Resolve the fixture next to this test module instead of relying on
    # pytest being launched from the repository root.
    original_path = Path(__file__).resolve().parent / "file.py"
    result = apply_diff_patch(
        original_path=str(original_path),
        patch_str=patch_str,
        patch_dir=None,
    )
    updated_str, num_applied, output_path, error, patch_txt, diff_path = result
    assert updated_str == new_str
    assert num_applied == 2
    assert output_path is None
    assert error is None
def test_apply_full_patch_single_evolve_block():
    """Test apply_full_patch with single EVOLVE-BLOCK region.

    The patch is a full file whose immutable parts equal the original's, so
    only the EVOLVE-BLOCK body should change.
    """
    original_content = """# Immutable header
import os
# EVOLVE-BLOCK-START
def old_function():
    return "old"
# EVOLVE-BLOCK-END
# Immutable footer
if __name__ == "__main__":
    pass
"""
    patch_content = """```python
# Immutable header
import os
# EVOLVE-BLOCK-START
def new_function():
    return "new"
def another_function():
    return "another"
# EVOLVE-BLOCK-END
# Immutable footer
if __name__ == "__main__":
    pass
```"""
    expected_result = """# Immutable header
import os
# EVOLVE-BLOCK-START
def new_function():
    return "new"
def another_function():
    return "another"
# EVOLVE-BLOCK-END
# Immutable footer
if __name__ == "__main__":
    pass
"""
    result = apply_full_patch(
        patch_str=patch_content,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert output_path is None
    assert error is None
    # Compare ignoring leading/trailing whitespace around the whole file.
    assert updated_content.strip() == expected_result.strip()
def test_apply_full_patch_with_evolve_blocks_in_patch():
    """Test apply_full_patch when patch contains EVOLVE-BLOCK markers.

    The original has two EVOLVE-BLOCK regions and the patch is a full file that
    carries both regions, so both should be rewritten in one application.
    """
    original_content = """# Header
# EVOLVE-BLOCK-START
def old_func1():
    pass
# EVOLVE-BLOCK-END
# Middle section
# EVOLVE-BLOCK-START
def old_func2():
    pass
# EVOLVE-BLOCK-END
# Footer
"""
    patch_content = """```python
# Header
# EVOLVE-BLOCK-START
def new_func1():
    return 1
# EVOLVE-BLOCK-END
# Middle section
# EVOLVE-BLOCK-START
def new_func2():
    return 2
# EVOLVE-BLOCK-END
# Footer
```"""
    result = apply_full_patch(
        patch_str=patch_content,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert error is None
    # Both evolve blocks should carry the new code, the old code must be gone,
    # and the immutable section between the blocks must survive.
    assert "def new_func1():" in updated_content
    assert "def new_func2():" in updated_content
    assert "old_func1" not in updated_content
    assert "old_func2" not in updated_content
    assert "# Middle section" in updated_content
def test_apply_full_patch_full_file_without_markers_extracts_block_only():
    """A marker-free patch replaces only the EVOLVE-BLOCK payload.

    With no EVOLVE markers in the patch, its fenced code becomes the new block
    body; the original's immutable header/footer must be kept unchanged and
    must not be copied into the block.
    """
    original_content = (
        "# Header line\n"
        "# EVOLVE-BLOCK-START\n"
        "old_line()\n"
        "# EVOLVE-BLOCK-END\n"
        "# Footer line\n"
    )
    # The patch carries only the replacement payload: no EVOLVE markers and
    # none of the immutable header/footer lines.
    patch_content = """```python
new_line()
another_new_line()
```"""
    expected = """# Header line
# EVOLVE-BLOCK-START
new_line()
another_new_line()
# EVOLVE-BLOCK-END
# Footer line
"""
    result = apply_full_patch(
        patch_str=patch_content,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert error is None
    assert num_applied == 1
    assert updated_content == expected
def test_apply_full_patch_patch_with_start_marker_only():
    """A patch with only the START marker still maps onto the single region.

    The original has both markers; the output must contain the new block body
    wrapped in both markers, with header and footer unchanged.
    """
    source_text = """# Header line
# EVOLVE-BLOCK-START
old_line()
# EVOLVE-BLOCK-END
# Footer line
"""
    llm_reply = """```python
# Header line
# EVOLVE-BLOCK-START
new_line()
# Footer line
```"""
    want = """# Header line
# EVOLVE-BLOCK-START
new_line()
# EVOLVE-BLOCK-END
# Footer line
"""
    (
        updated_content,
        num_applied,
        _output_path,
        error,
        _patch_txt,
        _diff_path,
    ) = apply_full_patch(
        patch_str=llm_reply,
        original_str=source_text,
        language="python",
        verbose=False,
    )
    assert error is None
    assert num_applied == 1
    assert updated_content == want
def test_apply_full_patch_patch_with_end_marker_only():
    """A patch with only the END marker still maps onto the single region.

    The original has both markers; the output must contain the new block body
    wrapped in both markers, with header and footer unchanged.
    """
    source_text = """# Header line
# EVOLVE-BLOCK-START
old_line()
# EVOLVE-BLOCK-END
# Footer line
"""
    llm_reply = """```python
# Header line
new_line()
# EVOLVE-BLOCK-END
# Footer line
```"""
    want = """# Header line
# EVOLVE-BLOCK-START
new_line()
# EVOLVE-BLOCK-END
# Footer line
"""
    (
        updated_content,
        num_applied,
        _output_path,
        error,
        _patch_txt,
        _diff_path,
    ) = apply_full_patch(
        patch_str=llm_reply,
        original_str=source_text,
        language="python",
        verbose=False,
    )
    assert error is None
    assert num_applied == 1
    assert updated_content == want
def test_apply_full_patch_no_evolve_blocks():
    """Test apply_full_patch with no EVOLVE-BLOCK regions - should error.

    The original must be returned unchanged, with no output file.
    """
    original_content = """# Just regular code
def function():
    return "no evolve blocks"
"""
    patch_content = """```python
def new_function():
    return "new"
```"""
    result = apply_full_patch(
        patch_str=patch_content,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error == "No EVOLVE-BLOCK regions found in original content"
    assert output_path is None
    assert updated_content == original_content  # Should return original content
def test_apply_full_patch_multiple_evolve_blocks_ambiguous():
    """Test apply_full_patch with multiple EVOLVE-BLOCK regions.

    A marker-free patch cannot say which of two regions to replace, so it must
    be rejected and the original returned unchanged.
    """
    original_content = """# EVOLVE-BLOCK-START
def func1():
    pass
# EVOLVE-BLOCK-END
# EVOLVE-BLOCK-START
def func2():
    pass
# EVOLVE-BLOCK-END
"""
    patch_content = """```python
def new_function():
    return "ambiguous which block to replace"
```"""
    result = apply_full_patch(
        patch_str=patch_content,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error is not None
    assert "Multiple EVOLVE-BLOCK regions found" in error
    assert "doesn't specify which to replace" in error
    assert output_path is None
    assert updated_content == original_content  # Should return original content
def test_apply_full_patch_patch_with_single_marker_ambiguous_multiple_regions():
    """A lone marker in the patch is ambiguous when the original has two regions.

    The patch carries only a START marker, so it cannot be mapped onto either
    of the original's two EVOLVE-BLOCK regions and must be rejected.
    """
    source_text = """# Header
# EVOLVE-BLOCK-START
func1()
# EVOLVE-BLOCK-END
# EVOLVE-BLOCK-START
func2()
# EVOLVE-BLOCK-END
# Footer
"""
    # Only the START marker is present in the patch.
    llm_reply = """```python
# Header
# EVOLVE-BLOCK-START
new_code()
# Footer
```"""
    outcome = apply_full_patch(
        patch_str=llm_reply,
        original_str=source_text,
        language="python",
        verbose=False,
    )
    _updated, num_applied, _output_path, error, _patch_txt, _diff_path = outcome
    assert num_applied == 0
    assert error is not None
    assert "only one EVOLVE-BLOCK marker" in error
def test_apply_full_patch_invalid_extraction():
    """Test apply_full_patch with invalid code extraction.

    A patch without a fenced code block yields no extractable code; this must
    be reported as an error and the original returned unchanged.
    """
    original_content = """# EVOLVE-BLOCK-START
def old_func():
    pass
# EVOLVE-BLOCK-END
"""
    # No ```python fence, so nothing can be extracted from the patch.
    patch_content = "def new_function(): return 'no fences'"
    result = apply_full_patch(
        patch_str=patch_content,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error == "Could not extract code from patch string"
    assert output_path is None
    assert updated_content == original_content  # Should return original content
def test_apply_full_patch_with_patch_dir():
    """Test apply_full_patch with patch directory specified.

    With ``patch_dir`` set, rewrite.txt, original.py, main.py and edit.diff
    must be written there, and the returned output/diff paths must exist.
    """
    import tempfile
    from pathlib import Path

    original_content = """# EVOLVE-BLOCK-START
def old_function():
    return "old"
# EVOLVE-BLOCK-END
"""
    patch_content = """```python
def new_function():
    return "new"
```"""
    with tempfile.TemporaryDirectory() as temp_dir:
        patch_dir = Path(temp_dir) / "test_patch"
        result = apply_full_patch(
            patch_str=patch_content,
            original_str=original_content,
            patch_dir=str(patch_dir),
            language="python",
            verbose=False,
        )
        updated_content, num_applied, output_path, error, patch_txt, diff_path = (
            result
        )
        # All file checks must stay inside the `with`: the directory is
        # deleted when the context exits.
        assert num_applied == 1
        assert error is None
        assert output_path is not None
        assert output_path.exists()
        assert diff_path is not None
        assert diff_path.exists()
        # Check that files were created
        assert (patch_dir / "rewrite.txt").exists()
        assert (patch_dir / "original.py").exists()
        assert (patch_dir / "main.py").exists()
        assert (patch_dir / "edit.diff").exists()
        # Verify the updated content matches what's in the file
        file_content = output_path.read_text(encoding="utf-8")
        assert file_content == updated_content
# ============================================================================
# Tests for Indentation Correction Functionality
# ============================================================================
def test_find_indented_match_exact_match():
    """Test _find_indented_match when exact match is found.

    "x = 1" occurs verbatim inside the indented line, so the search text is
    returned unchanged together with its position.
    """
    original = """def function():
    x = 1
    y = 2
    return x + y"""
    search = "x = 1"
    matched, pos = _find_indented_match(search, original)
    assert matched == search
    assert pos != -1
    assert original[pos : pos + len(matched)] == matched
def test_find_indented_match_needs_indentation():
    """Test _find_indented_match when indentation correction is needed.

    The multi-line search text has no indentation, so it only matches after
    the original's 4-space indentation is applied to every line.
    """
    original = """def function():
    x = 1
    y = 2
    return x + y"""
    # Search text without proper indentation
    search = "x = 1\ny = 2"
    matched, pos = _find_indented_match(search, original)
    expected = "    x = 1\n    y = 2"
    assert matched == expected
    assert pos != -1
    assert original[pos : pos + len(matched)] == matched
def test_find_indented_match_multiline_with_relative_indentation():
    """Test _find_indented_match with multiline blocks having relative indentation.

    Only the base indentation is missing from the search text; the nesting
    between its lines must be kept when the base indent is added.
    """
    original = """def function():
    if True:
        x = 1
        if nested:
            y = 2
    return x + y"""
    # Search text without proper base indentation but with relative indentation
    search = """if True:
    x = 1
    if nested:
        y = 2"""
    matched, pos = _find_indented_match(search, original)
    expected = """    if True:
        x = 1
        if nested:
            y = 2"""
    assert matched == expected
    assert pos != -1
    assert original[pos : pos + len(matched)] == matched
def test_find_indented_match_not_found():
    """Test _find_indented_match when text is not found.

    A miss is reported as an empty match at position -1.
    """
    original = """def function():
    x = 1
    return x"""
    search = "z = 3"
    matched, pos = _find_indented_match(search, original)
    assert matched == ""
    assert pos == -1
def test_find_indented_match_empty_search():
    """Test _find_indented_match with empty search text.

    An empty search is treated as a miss, not as a match at offset 0.
    """
    original = "def function():\n    pass"
    search = ""
    matched, pos = _find_indented_match(search, original)
    assert matched == ""
    assert pos == -1
def test_apply_indentation_to_replace():
    """Test _apply_indentation_to_replace function.

    The base indent is prefixed to every line; nested lines keep their extra
    indentation on top of it.
    """
    replace_text = """x = 10
if x > 5:
    print("big")
else:
    print("small")"""
    indent_str = "    "  # 4 spaces
    result = _apply_indentation_to_replace(replace_text, indent_str)
    expected = """    x = 10
    if x > 5:
        print("big")
    else:
        print("small")"""
    assert result == expected
def test_apply_indentation_to_replace_empty_lines():
    """Test _apply_indentation_to_replace with empty lines.

    Non-empty lines get the indent; the empty line in between stays empty
    instead of becoming whitespace-only.
    """
    # Escapes keep the empty line visible and immune to editor whitespace trimming.
    replace_text = "x = 1\n\ny = 2"
    indent_str = "    "
    result = _apply_indentation_to_replace(replace_text, indent_str)
    expected = "    x = 1\n\n    y = 2"
    assert result == expected
def test_strip_trailing_whitespace():
    """_strip_trailing_whitespace drops trailing spaces and tabs on every line."""
    # Built from escape sequences so editors and linters cannot trim the
    # trailing whitespace this test depends on.
    messy = "line1 \nline2\t\nline3\nline4 \t "
    cleaned = "line1\nline2\nline3\nline4"
    assert _strip_trailing_whitespace(messy) == cleaned
# ============================================================================
# Integration Tests for Indentation Correction in apply_diff_patch
# ============================================================================
def test_indentation_correction_in_patch():
    """Test that apply_diff_patch correctly handles indentation mismatches.

    The SEARCH/REPLACE lines have no indentation while the target lines sit
    inside a function body; the replacement must inherit the body's indent.
    """
    original_content = """# EVOLVE-BLOCK-START
def calculate():
    centers = compute_centers()
    radius = get_radius()
    area = math.pi * radius ** 2
    return area
# EVOLVE-BLOCK-END"""
    # Patch with incorrect (missing) indentation
    patch_str = """<<<<<<< SEARCH
centers = compute_centers()
radius = get_radius()
=======
centers = compute_new_centers()
radius = get_new_radius()
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert error is None
    assert "compute_new_centers()" in updated_content
    assert "get_new_radius()" in updated_content
    # Verify indentation is preserved
    assert "    centers = compute_new_centers()" in updated_content
def test_indentation_correction_multiline_patch():
    """Test indentation correction with multiline search/replace blocks.

    The patch lacks the base indent but keeps relative nesting; after applying
    it, both the base indent and the nesting must be correct.
    """
    original_content = """# EVOLVE-BLOCK-START
def process_data():
    if condition:
        data = load_data()
        result = process(data)
        return result
    return None
# EVOLVE-BLOCK-END"""
    # Patch with no base indentation (relative nesting preserved)
    patch_str = """<<<<<<< SEARCH
if condition:
    data = load_data()
    result = process(data)
    return result
=======
if new_condition:
    data = load_new_data()
    result = new_process(data)
    return enhanced_result
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert error is None
    assert "new_condition" in updated_content
    assert "load_new_data()" in updated_content
    # Verify proper indentation is applied
    assert "    if new_condition:" in updated_content
    assert "        data = load_new_data()" in updated_content
def test_indentation_correction_with_trailing_whitespace():
    """Test that indentation correction works with trailing whitespace.

    The SEARCH/REPLACE lines carry trailing spaces/tabs and no base
    indentation; the patch must still apply and no output line may end in
    whitespace.
    """
    original_content = """# EVOLVE-BLOCK-START
def func():
    x = 1
    y = 2
    return x + y
# EVOLVE-BLOCK-END"""
    # Build the patch from explicit escapes: trailing whitespace inside a
    # triple-quoted literal is easily trimmed by editors/formatters, which
    # would silently turn this into a plain indentation test.
    patch_str = (
        "<<<<<<< SEARCH\n"
        "x = 1  \n"
        "y = 2\t\n"
        "=======\n"
        "x = 10  \n"
        "y = 20\t\n"
        ">>>>>>> REPLACE"
    )
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert error is None
    assert "x = 10" in updated_content
    assert "y = 20" in updated_content
    # Verify trailing whitespace is stripped
    lines = updated_content.split("\n")
    for line in lines:
        assert line == line.rstrip(), f"Line has trailing whitespace: {repr(line)}"
def test_indentation_correction_fails_gracefully():
    """Test that indentation correction fails gracefully when match cannot be found.

    A SEARCH that matches nothing, with or without indentation fixes, must
    produce an error and leave the content untouched.
    """
    original_content = """# EVOLVE-BLOCK-START
def func():
    x = 1
    y = 2
    return x + y
# EVOLVE-BLOCK-END"""
    # Patch with text that doesn't exist
    patch_str = """<<<<<<< SEARCH
z = 3
w = 4
=======
z = 30
w = 40
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error is not None
    assert "SEARCH text not found" in error
    assert updated_content == original_content  # Should remain unchanged
def test_mixed_indentation_styles():
    """Test handling of mixed indentation styles (spaces and tabs).

    The original body is tab-indented while the patch uses 4 spaces; the match
    must succeed and the replacement must take the original's tab indent.
    """
    original_content = """# EVOLVE-BLOCK-START
def func():
\tx = 1 # Tab indented
\ty = 2 # Tab indented
\treturn x + y
# EVOLVE-BLOCK-END"""
    # Search with space indentation (should match tab indented lines)
    patch_str = """<<<<<<< SEARCH
    x = 1 # Tab indented
    y = 2 # Tab indented
=======
    x = 10
    y = 20
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert error is None
    assert "x = 10" in updated_content
    # Verify original tab indentation is preserved
    assert "\tx = 10" in updated_content
    assert "\ty = 20" in updated_content
def test_indentation_with_empty_lines_in_search():
    """Test indentation correction with empty lines in search block.

    The unindented SEARCH contains an empty line; it must still match the
    indented original, where that line is also empty.
    """
    original_content = """# EVOLVE-BLOCK-START
def func():
    x = 1

    y = 2
    return x + y
# EVOLVE-BLOCK-END"""
    patch_str = """<<<<<<< SEARCH
x = 1

y = 2
=======
x = 10

y = 20
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert error is None
    assert "    x = 10" in updated_content
    assert "    y = 20" in updated_content
def test_indentation_correction_preserves_mutable_regions():
    """Test that indentation correction respects EVOLVE-BLOCK boundaries.

    "x = 1" only occurs in the immutable function above the block, so the
    edit must be rejected even though the text is found.
    """
    original_content = """# Immutable section
def immutable_func():
    x = 1
    return x
# EVOLVE-BLOCK-START
def mutable_func():
    y = 2
    return y
# EVOLVE-BLOCK-END
# Another immutable section
def another_immutable():
    z = 3
    return z"""
    # Try to patch something in immutable region (should fail)
    patch_str = """<<<<<<< SEARCH
x = 1
=======
x = 100
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error is not None
    assert "outside EVOLVE-BLOCK" in error
def test_insertion_with_indentation():
    """Test insertion (empty search) with proper indentation context.

    An empty SEARCH inserts the REPLACE text inside the mutable region, that
    is, before the EVOLVE-BLOCK-END marker.
    """
    original_content = """# EVOLVE-BLOCK-START
def func():
    x = 1
    return x
# EVOLVE-BLOCK-END"""
    # Empty search = insertion at end of mutable region
    patch_str = """<<<<<<< SEARCH
=======
# New comment
y = 2
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 1
    assert error is None
    assert "# New comment" in updated_content
    assert "y = 2" in updated_content
    # The inserted code must land inside the block, not after its END marker.
    assert updated_content.index("y = 2") < updated_content.index(
        "# EVOLVE-BLOCK-END"
    )
# ============================================================================
# Tests for Enhanced Error Messages
# ============================================================================
def test_enhanced_search_not_found_error():
    """Test that search not found errors provide helpful suggestions.

    The SEARCH is a near miss (compute_center vs compute_centers), so no edit
    is applied and a "not found" error is reported.
    """
    original_content = """# EVOLVE-BLOCK-START
def calculate():
    centers = compute_centers()
    radius = get_radius()
    area = math.pi * radius ** 2
    return area
# EVOLVE-BLOCK-END"""
    # Search for similar but not exact text
    patch_str = """<<<<<<< SEARCH
centers = compute_center()
=======
centers = compute_new_centers()
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error is not None
    assert "SEARCH text not found" in error
def test_enhanced_evolve_block_violation_error():
    """Test that EVOLVE-BLOCK violation errors show context and suggestions.

    Editing the immutable import header must fail with an error that includes
    surrounding context (with line numbers), the editable regions, and
    suggestions.
    """
    original_content = """# Immutable header
import os
import sys
# EVOLVE-BLOCK-START
def mutable_function():
    return "editable"
# EVOLVE-BLOCK-END
# Immutable footer
if __name__ == "__main__":
    main()"""
    # Try to edit immutable code
    patch_str = """<<<<<<< SEARCH
import os
=======
import os
import json
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error is not None
    assert "Attempted to edit outside EVOLVE-BLOCK regions" in error
    assert "Context around found text:" in error
    assert "Available editable regions" in error
    assert "Line" in error  # Should show line numbers in context
    assert "Suggestions:" in error
def test_enhanced_no_evolve_block_error():
    """Test error message when no EVOLVE-BLOCK regions exist.

    An insertion (empty SEARCH) needs a block to insert into; without one the
    error must describe the file structure, the expected marker format, and
    suggestions.
    """
    original_content = """def regular_function():
    return "no evolve blocks here"
if __name__ == "__main__":
    print("Hello world")"""
    # Try to insert into file with no EVOLVE-BLOCK
    patch_str = """<<<<<<< SEARCH
=======
# New comment
new_var = 42
>>>>>>> REPLACE"""
    result = apply_diff_patch(
        patch_str=patch_str,
        original_str=original_content,
        language="python",
        verbose=False,
    )
    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
    assert num_applied == 0
    assert error is not None
    assert "Cannot perform insertion: No EVOLVE-BLOCK regions found" in error
    assert "Current file structure:" in error
    assert "Expected format:" in error
    assert "EVOLVE-BLOCK-START" in error
    assert "Suggestions:" in error
def test_enhanced_error_with_multiline_search():
"""Test enhanced error messages with multiline search blocks."""
original_content = """# EVOLVE-BLOCK-START
def process():
data = load_data()
result = transform(data)
return result
# EVOLVE-BLOCK-END"""
# Search for multiline block with typo
patch_str = """<<<<<<< SEARCH
data = load_data()
result = transform_data(data)
return result
=======
data = load_new_data()
result = new_transform(data)
return result
>>>>>>> REPLACE"""
result = apply_diff_patch(
patch_str=patch_str,
original_str=original_content,
language="python",
verbose=False,
)
updated_content, num_applied, output_path, error, patch_txt, diff_path = result
assert num_applied == 0
assert error is not None