Spaces:
Sleeping
Sleeping
File size: 4,674 Bytes
f02b845 1341a55 f02b845 09b7578 f02b845 1341a55 09b7578 1341a55 f02b845 1341a55 09b7578 1341a55 09b7578 f02b845 09b7578 1341a55 09b7578 1341a55 09b7578 1341a55 f02b845 09b7578 f02b845 1341a55 f02b845 09b7578 1341a55 09b7578 1341a55 f02b845 09b7578 1341a55 f02b845 1341a55 09b7578 1341a55 09b7578 1341a55 09b7578 f02b845 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 | """
inference.py — Phase 1/2 automated check entry point (v0.3).
Demonstrates the full query/apply episode with session management.
"""
import requests
import sys
import json
BASE_URL = "http://localhost:8000"
def run_episode(base_url: str = BASE_URL):
base_url = base_url.rstrip("/")
print(f"DataCentric-Env v0.3 | {base_url}\n")
# Phase 1: reset — get session_id
resp = requests.post(f"{base_url}/reset", json={}, timeout=30)
assert resp.status_code == 200, f"Reset failed: {resp.status_code} {resp.text}"
obs = resp.json()
assert "session_id" in obs, "Missing session_id in reset response"
assert "current_accuracy" in obs, "Missing current_accuracy"
assert "budget_remaining" in obs, "Missing budget_remaining"
assert "available_actions" in obs, "Missing available_actions"
session_id = obs["session_id"]
print(f"Phase 1 reset: PASS")
print(f" session_id = {session_id}")
print(f" accuracy = {obs['current_accuracy']} -> target {obs['target_accuracy']}")
print(f" budget = {obs['budget_remaining']}")
print(f" missing_pct = {obs['dataset_stats']['missing_pct']}")
print(f" balance_ratio = {obs['dataset_stats']['balance_ratio']}\n")
# Phase 2: run a full episode
rewards = []
step_count = 0
queried_agents = set()
while True:
budget = obs.get("budget_remaining", 0)
if budget <= 0 or obs.get("done"):
break
pending = obs.get("pending_recommendations", {})
stats = obs.get("dataset_stats", {})
# Strategy: query analyst first, then follow plan, then apply
if not queried_agents:
action = {"session_id": session_id, "action": "query_analyst"}
elif not pending:
# Pick next agent based on dataset stats
if stats.get("missing_pct", 0) > 0.05 and "cleaner" not in queried_agents:
action = {"session_id": session_id, "action": "query_cleaner"}
elif stats.get("balance_ratio", 1.0) < 0.45 and "balancer" not in queried_agents:
action = {"session_id": session_id, "action": "query_balancer"}
elif "augmenter" not in queried_agents:
action = {"session_id": session_id, "action": "query_augmenter", "target_class": 1}
else:
break # out of ideas
else:
# Apply highest priority pending rec
best_rec_id = min(pending, key=lambda k: pending[k].get("priority", 99))
action = {"session_id": session_id, "action": "apply", "rec_id": best_rec_id}
result = requests.post(f"{base_url}/step", json=action, timeout=30)
assert result.status_code == 200, f"Step failed: {result.status_code} {result.text}"
result = result.json()
if "error" in result and "exploit_detected" not in result:
print(f" Step {step_count+1}: ERROR - {result['error']}")
break
reward = result.get("reward", 0.0)
info = result.get("info", {})
step_count += 1
rewards.append(reward)
assert isinstance(reward, float), f"Reward must be float, got {type(reward)}"
assert 0.0 < reward < 1.0, f"Reward {reward} out of range (0.0, 1.0)"
action_type = info.get("action_type", "?")
if action_type == "query":
agent = info.get("agent_queried", "?")
queried_agents.add(agent)
n_recs = info.get("n_recommendations", 0)
print(f" Step {step_count:02d}: QUERY {agent:12s} -> {n_recs} recs | reward={reward:.4f} | budget={info.get('budget_remaining','?')}")
else:
print(f" Step {step_count:02d}: APPLY {info.get('rec_type','?'):15s} -> {info.get('prev_accuracy','?'):.4f}->{info.get('new_accuracy','?'):.4f} | reward={reward:.4f} | success={info.get('success',False)}")
obs = result.get("observation", obs)
if result.get("done"):
print(f"\n Episode done. Success={info.get('success', False)}")
break
# Verify metrics endpoint
m = requests.get(f"{base_url}/metrics", timeout=10).json()
assert "sessions" in m, "Missing sessions in /metrics"
print(f"\nPhase 2 full episode: PASS")
print(f" Steps: {step_count} | Mean reward: {sum(rewards)/max(len(rewards),1):.4f}")
print(f" All {len(rewards)} rewards in (0.0, 1.0): {all(0.0 < r < 1.0 for r in rewards)}")
print(f" Active sessions: {m['sessions']['active_sessions']}")
print("\nAll automated checks passed.")
return True
if __name__ == "__main__":
url = sys.argv[1] if len(sys.argv) > 1 else BASE_URL
run_episode(url)
|