File size: 4,674 Bytes
f02b845
1341a55
 
f02b845
 
 
09b7578
f02b845
 
 
 
 
1341a55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09b7578
1341a55
 
f02b845
1341a55
09b7578
1341a55
09b7578
f02b845
09b7578
1341a55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09b7578
1341a55
 
 
 
 
 
 
 
 
 
09b7578
 
1341a55
f02b845
 
09b7578
f02b845
 
1341a55
f02b845
09b7578
 
 
1341a55
 
 
09b7578
1341a55
f02b845
 
 
09b7578
1341a55
f02b845
 
1341a55
 
 
09b7578
 
1341a55
09b7578
1341a55
09b7578
f02b845
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
inference.py — Phase 1/2 automated check entry point (v0.3).
Demonstrates the full query/apply episode with session management.
"""
import requests
import sys
import json

# Default server address: used as the default `base_url` of run_episode() and
# as the fallback when no URL is given on the command line.
BASE_URL = "http://localhost:8000"


def _fmt_metric(value) -> str:
    """Render *value* with four decimals, or as plain text if it is not numeric."""
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        # e.g. the "?" placeholder used when the server omits the field.
        return str(value)


def run_episode(base_url: str = BASE_URL) -> bool:
    """Run the Phase 1 / Phase 2 automated checks against a running server.

    Phase 1 POSTs to ``/reset`` and verifies the observation carries the
    expected fields. Phase 2 repeatedly POSTs to ``/step`` using a simple
    strategy (query the analyst first, then query further agents based on the
    dataset stats, and apply pending recommendations by priority) until the
    budget is exhausted, the episode is done, or no action is left. Finally
    ``/metrics`` is fetched and checked for a ``sessions`` entry.

    Args:
        base_url: Root URL of the environment server; a trailing ``/`` is
            stripped.

    Returns:
        ``True`` when every check passed.

    Raises:
        AssertionError: If a response has a non-200 status, lacks a required
            field, or a reward is not a float strictly between 0.0 and 1.0.

    Exceptions raised by ``requests`` (connection errors, timeouts, invalid
    JSON) are not caught here and propagate to the caller.
    """
    base_url = base_url.rstrip("/")
    print(f"DataCentric-Env v0.3 | {base_url}\n")

    # Phase 1: reset — get session_id
    resp = requests.post(f"{base_url}/reset", json={}, timeout=30)
    assert resp.status_code == 200, f"Reset failed: {resp.status_code} {resp.text}"
    obs = resp.json()
    assert "session_id" in obs, "Missing session_id in reset response"
    assert "current_accuracy" in obs, "Missing current_accuracy"
    assert "budget_remaining" in obs, "Missing budget_remaining"
    assert "available_actions" in obs, "Missing available_actions"
    # These two are printed below, so check them like the other fields
    # instead of failing with a bare KeyError.
    assert "target_accuracy" in obs, "Missing target_accuracy"
    assert "dataset_stats" in obs, "Missing dataset_stats"

    session_id = obs["session_id"]
    initial_stats = obs["dataset_stats"]
    print("Phase 1 reset: PASS")
    print(f"  session_id     = {session_id}")
    print(f"  accuracy       = {obs['current_accuracy']} -> target {obs['target_accuracy']}")
    print(f"  budget         = {obs['budget_remaining']}")
    print(f"  missing_pct    = {initial_stats.get('missing_pct', '?')}")
    print(f"  balance_ratio  = {initial_stats.get('balance_ratio', '?')}\n")

    # Phase 2: run a full episode
    rewards = []
    step_count = 0
    queried_agents = set()

    while True:
        budget = obs.get("budget_remaining", 0)
        if budget <= 0 or obs.get("done"):
            break

        pending = obs.get("pending_recommendations", {})
        stats = obs.get("dataset_stats", {})

        # Strategy: query analyst first, then follow plan, then apply
        if not queried_agents:
            action = {"session_id": session_id, "action": "query_analyst"}
        elif not pending:
            # Pick next agent based on dataset stats
            if stats.get("missing_pct", 0) > 0.05 and "cleaner" not in queried_agents:
                action = {"session_id": session_id, "action": "query_cleaner"}
            elif stats.get("balance_ratio", 1.0) < 0.45 and "balancer" not in queried_agents:
                action = {"session_id": session_id, "action": "query_balancer"}
            elif "augmenter" not in queried_agents:
                action = {"session_id": session_id, "action": "query_augmenter", "target_class": 1}
            else:
                break  # out of ideas
        else:
            # Apply the pending rec with the lowest priority number
            # (recs without a priority sort last).
            best_rec_id = min(pending, key=lambda k: pending[k].get("priority", 99))
            action = {"session_id": session_id, "action": "apply", "rec_id": best_rec_id}

        step_resp = requests.post(f"{base_url}/step", json=action, timeout=30)
        assert step_resp.status_code == 200, f"Step failed: {step_resp.status_code} {step_resp.text}"
        result = step_resp.json()

        # An error payload ends the episode unless it is flagged as an
        # exploit detection, in which case it is scored like any other step.
        # NOTE(review): the summary below still reports PASS after this
        # break — confirm that is intended.
        if "error" in result and "exploit_detected" not in result:
            print(f"  Step {step_count+1}: ERROR - {result['error']}")
            break

        reward = result.get("reward", 0.0)
        info = result.get("info", {})
        step_count += 1
        rewards.append(reward)

        assert isinstance(reward, float), f"Reward must be float, got {type(reward)}"
        assert 0.0 < reward < 1.0, f"Reward {reward} out of range (0.0, 1.0)"

        action_type = info.get("action_type", "?")
        if action_type == "query":
            agent = info.get("agent_queried", "?")
            queried_agents.add(agent)
            n_recs = info.get("n_recommendations", 0)
            print(f"  Step {step_count:02d}: QUERY  {agent:12s} -> {n_recs} recs | reward={reward:.4f} | budget={info.get('budget_remaining','?')}")
        else:
            # Accuracies may be absent from `info`; _fmt_metric keeps the "?"
            # placeholder printable instead of raising on the ".4f" format.
            rec_type = str(info.get('rec_type', '?'))
            prev_acc = _fmt_metric(info.get('prev_accuracy', '?'))
            new_acc = _fmt_metric(info.get('new_accuracy', '?'))
            print(f"  Step {step_count:02d}: APPLY  {rec_type:15s} -> {prev_acc}->{new_acc} | reward={reward:.4f} | success={info.get('success',False)}")

        # Keep the previous observation if the server did not send a new one.
        obs = result.get("observation", obs)

        if result.get("done"):
            print(f"\n  Episode done. Success={info.get('success', False)}")
            break

    # Verify metrics endpoint
    metrics_resp = requests.get(f"{base_url}/metrics", timeout=10)
    assert metrics_resp.status_code == 200, (
        f"Metrics failed: {metrics_resp.status_code} {metrics_resp.text}"
    )
    m = metrics_resp.json()
    assert "sessions" in m, "Missing sessions in /metrics"

    print("\nPhase 2 full episode: PASS")
    # max(..., 1) avoids dividing by zero when no step was taken.
    print(f"  Steps: {step_count} | Mean reward: {sum(rewards)/max(len(rewards),1):.4f}")
    print(f"  All {len(rewards)} rewards in (0.0, 1.0): {all(0.0 < r < 1.0 for r in rewards)}")
    print(f"  Active sessions: {m['sessions']['active_sessions']}")
    print("\nAll automated checks passed.")
    return True


if __name__ == "__main__":
    # An optional first command-line argument overrides the default server URL.
    cli_args = sys.argv[1:]
    run_episode(cli_args[0] if cli_args else BASE_URL)