# code-gen-assistant / scripts / test_repair_loop.py
# Commit b89e6d6 by Rushabh147: "Initial deploy to HF Spaces (clean history, LFS for all binaries)"
"""Throwaway script: exercise the agentic repair loop on a tricky problem.
The intent is chosen to be genuinely hard for a small model: balanced-bracket
checking with all three delimiter types. The check_fn appends a strict test
harness so the sandbox can report a traceback on failure.
Run:
python scripts/test_repair_loop.py
Expected output: iteration trace showing whether the model self-corrects.
"""
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1]))
from src.agent.repair_loop import make_repair_generator, repair_loop
from src.rag.generator import CodeAssistant
# ── test harness appended to every generated function ─────────────────────
_TEST_HARNESS = '''
# --- auto-injected test harness ---
assert is_balanced("") is True, "empty string should be balanced"
assert is_balanced("()") is True, "simple parens"
assert is_balanced("([{}])") is True, "nested mix"
assert is_balanced("([)]") is False, "wrong order"
assert is_balanced("(") is False, "unclosed"
assert is_balanced("{[}]") is False, "interleaved wrong"
assert is_balanced("((()))") is True, "deep nesting"
print("all tests passed")
'''
INTENT = (
"Write a Python function called `is_balanced` that returns True if the input "
"string has balanced parentheses (), square brackets [], and curly braces {}, "
"and False otherwise. Use a stack."
)
def check_fn(code: str) -> str:
    """Return the generated source with the bracket-checking harness appended.

    Executing the combined text runs the asserts right after the model's
    definitions, so a failing case raises and produces a traceback the
    sandbox can report.
    """
    return "".join((code, _TEST_HARNESS))
def main(max_iters: int = 3, timeout: float = 10.0) -> None:
    """Run the repair loop once on INTENT and print a readable iteration trace.

    Args:
        max_iters: Iteration budget forwarded to ``repair_loop``. The same value
            is shown in the progress message, so the two cannot drift apart.
        timeout: Forwarded to ``repair_loop`` unchanged; presumably seconds per
            sandbox run — confirm against ``repair_loop``.
    """
    import textwrap  # local import: only used to format error output

    rule = "=" * 60
    preview_len = 80
    # Only add an ellipsis when the intent was actually truncated.
    intent_preview = (
        INTENT if len(INTENT) <= preview_len else INTENT[:preview_len] + "..."
    )

    print(rule)
    print("REPAIR LOOP TEST")
    print("Intent:", intent_preview)
    print(rule)

    print("\n[setup] Loading CodeAssistant (model + index)...")
    assistant = CodeAssistant.from_config(with_index=True)
    generate_fn = make_repair_generator(assistant)

    print(f"[repair] Starting loop (max_iters={max_iters})...\n")
    trace = repair_loop(
        INTENT, generate_fn, check_fn, max_iters=max_iters, timeout=timeout
    )

    print("\n" + rule)
    print("ITERATION TRACE")
    print(rule)
    # Each history entry is unpacked as (code, error); a falsy error is
    # reported as a pass.
    for i, (code, error) in enumerate(trace.history, 1):
        print(f"\n--- Iteration {i} ---")
        print("Generated code:")
        print(code)
        if error:
            # Indent every line of a multi-line error/traceback, not just the first.
            print("\nError:\n" + textwrap.indent(str(error), "    "))
        else:
            print("\nResult: PASSED")

    print("\n" + rule)
    print(f"Outcome: {'SUCCESS' if trace.success else 'BEST EFFORT (did not pass)'}")
    print(f"Iterations used: {trace.iterations}")
    print(rule)

    if trace.success:
        print("\nFinal passing code:")
        print(trace.final_code)
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()