File size: 3,389 Bytes
6a32325
71fa486
 
6a32325
71fa486
 
6a32325
71fa486
6a32325
71fa486
6a32325
 
71fa486
 
6a32325
 
 
 
 
71fa486
 
6a32325
71fa486
6a32325
 
 
71fa486
6a32325
 
 
 
 
71fa486
 
6a32325
 
 
 
 
 
71fa486
 
6a32325
 
 
71fa486
6a32325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71fa486
6a32325
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71fa486
6a32325
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Test all 7 tasks: seed, golden migration, grade, reset, close."""
import sys
import os
import sqlite3

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'OpenEnv', 'src'))

import seeds
from server.grader import StateReconciler
from server.environment import DbMigrationEnvironment
from models import MigrationAction


def test_golden_migration(task_name: str) -> None:
    """Test that golden migration produces a near-perfect grader score."""
    config = seeds.TASKS[task_name]
    
    # 1. Create DB and seed
    conn = sqlite3.connect(":memory:")
    conn.execute("PRAGMA foreign_keys = ON")
    config["seed_fn"](conn)
    
    # 2. Score before migration (should be low)
    reconciler = StateReconciler(task_name)
    score_before = reconciler.score(conn)
    
    # 3. Run golden migration
    config["golden_fn"](conn)
    
    # 4. Score after migration (should be >0.90)
    score_after = reconciler.score(conn)
    
    conn.close()
    
    status = "PASS" if score_after >= 0.90 else "FAIL"
    print(f"  [{status}] {task_name}: before={score_before:.2f} after={score_after:.2f}")
    
    if score_after < 0.90:
        raise AssertionError(f"{task_name}: golden migration only scored {score_after:.2f}")


def test_environment_lifecycle(task_name: str) -> None:
    """Test that environment can reset, step, and close without crashes."""
    env = DbMigrationEnvironment(task_name=task_name)
    obs = env.reset()
    
    assert not obs.done, f"{task_name}: obs.done should be False after reset"
    assert obs.step_number == 0, f"{task_name}: step should be 0 after reset"
    assert obs.current_schema_sql, f"{task_name}: should have current schema"
    assert obs.target_schema_sql, f"{task_name}: should have target schema"
    
    # Run a SELECT to verify data passthrough
    action = MigrationAction(
        sql_command="SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
        reasoning="List tables",
        submit_final=False,
    )
    obs = env.step(action)
    assert "rows total" in obs.last_execution_result or "Query returned" in obs.last_execution_result, \
        f"{task_name}: SELECT should return formatted data, got: {obs.last_execution_result[:100]}"
    
    env.close()
    print(f"  [PASS] {task_name}: environment lifecycle OK (SELECT data passthrough verified)")


def main():
    print("=" * 60)
    print("Testing Golden Migrations (all 7 tasks)")
    print("=" * 60)
    
    errors = []
    for task_name in seeds.TASKS:
        try:
            test_golden_migration(task_name)
        except Exception as e:
            errors.append(f"Golden {task_name}: {e}")
    
    print()
    print("=" * 60)
    print("Testing Environment Lifecycle (all 7 tasks)")
    print("=" * 60)
    
    for task_name in seeds.TASKS:
        try:
            test_environment_lifecycle(task_name)
        except Exception as e:
            errors.append(f"Lifecycle {task_name}: {e}")
    
    print()
    if errors:
        print("=" * 60)
        print(f"FAILURES ({len(errors)}):")
        for e in errors:
            print(f"  ✗ {e}")
        print("=" * 60)
        sys.exit(1)
    else:
        print("=" * 60)
        print("ALL 7 TASKS PASSED!")
        print("=" * 60)


if __name__ == "__main__":
    main()