Spaces:
Sleeping
Sleeping
"""
inference.py — Phase 1/2 automated check entry point (v0.3).

Demonstrates the full query/apply episode with session management against a
running DataCentric-Env server:

* Phase 1 calls ``/reset`` and validates the returned observation.
* Phase 2 drives ``/step``. It continues until the budget is spent, the
  episode is done, or the strategy runs out of actions. It then checks
  ``/metrics``.

Usage:
    python inference.py [BASE_URL]    # default: http://localhost:8000
"""
| import requests | |
| import sys | |
| import json | |
# Default server root; the ``__main__`` entry point lets the first CLI argument override it.
BASE_URL: str = "http://localhost:8000"
def _check(condition, message):
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    Unlike a bare ``assert`` statement, this is not stripped under
    ``python -O``. The automated checks therefore cannot silently "pass"
    when optimizations are enabled.
    """
    if not condition:
        raise AssertionError(message)


def _fmt_num(value, spec=".4f"):
    """Format *value* with *spec* if it is a real number, else return ``str(value)``.

    Server responses may omit numeric fields, in which case callers pass the
    ``'?'`` placeholder. Formatting ``'?'`` with ``.4f`` would raise ``ValueError``.
    """
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return format(value, spec)
    return str(value)


def _choose_action(session_id, obs, queried_agents):
    """Return the next ``/step`` payload, or ``None`` when the strategy is exhausted.

    The strategy works in this order:

    1. Query the analyst first.
    2. Apply the pending recommendation with the lowest ``priority``
       (missing priority counts as 99).
    3. With nothing pending, query an agent not yet asked. The cleaner is
       chosen if missing_pct > 0.05, the balancer if balance_ratio < 0.45,
       otherwise the augmenter for class 1.
    """
    pending = obs.get("pending_recommendations") or {}
    stats = obs.get("dataset_stats") or {}
    if not queried_agents:
        return {"session_id": session_id, "action": "query_analyst"}
    if pending:
        best_rec_id = min(pending, key=lambda k: pending[k].get("priority", 99))
        return {"session_id": session_id, "action": "apply", "rec_id": best_rec_id}
    if stats.get("missing_pct", 0) > 0.05 and "cleaner" not in queried_agents:
        return {"session_id": session_id, "action": "query_cleaner"}
    if stats.get("balance_ratio", 1.0) < 0.45 and "balancer" not in queried_agents:
        return {"session_id": session_id, "action": "query_balancer"}
    if "augmenter" not in queried_agents:
        return {"session_id": session_id, "action": "query_augmenter", "target_class": 1}
    return None


def run_episode(base_url: str = BASE_URL, max_steps: int = 100):
    """Run one full query/apply episode against the environment server.

    Phase 1 resets the environment and validates the observation keys.
    Phase 2 drives the episode with a fixed strategy and validates every
    reward. The ``/metrics`` endpoint is checked last.

    Args:
        base_url: Server root URL; a trailing slash is ignored.
        max_steps: Safety cap on ``/step`` calls. Without it the loop stops
            only on server-reported exhaustion/completion or when the strategy
            runs out of actions. A response lacking ``observation`` (which
            keeps the previous, stale observation) could then loop forever.

    Returns:
        ``True`` when all checks pass.

    Raises:
        AssertionError: if any check fails (also raised under ``python -O``).
        requests.RequestException: on transport errors.
    """
    base_url = base_url.rstrip("/")
    print(f"DataCentric-Env v0.3 | {base_url}\n")

    # Phase 1: reset — get session_id
    resp = requests.post(f"{base_url}/reset", json={}, timeout=30)
    _check(resp.status_code == 200, f"Reset failed: {resp.status_code} {resp.text}")
    obs = resp.json()
    for key, message in (
        ("session_id", "Missing session_id in reset response"),
        ("current_accuracy", "Missing current_accuracy"),
        ("budget_remaining", "Missing budget_remaining"),
        ("available_actions", "Missing available_actions"),
    ):
        _check(key in obs, message)
    session_id = obs["session_id"]
    # These fields are not validated above, so read them defensively.
    reset_stats = obs.get("dataset_stats") or {}
    print("Phase 1 reset: PASS")
    print(f" session_id = {session_id}")
    print(f" accuracy = {obs['current_accuracy']} -> target {obs.get('target_accuracy', '?')}")
    print(f" budget = {obs['budget_remaining']}")
    print(f" missing_pct = {reset_stats.get('missing_pct', '?')}")
    print(f" balance_ratio = {reset_stats.get('balance_ratio', '?')}\n")

    # Phase 2: run a full episode
    rewards = []
    step_count = 0
    queried_agents = set()
    while step_count < max_steps:
        if (obs.get("budget_remaining") or 0) <= 0 or obs.get("done"):
            break
        action = _choose_action(session_id, obs, queried_agents)
        if action is None:
            break  # out of ideas

        result = requests.post(f"{base_url}/step", json=action, timeout=30)
        _check(result.status_code == 200, f"Step failed: {result.status_code} {result.text}")
        result = result.json()
        # Exploit detections come back with an "error" key but are still scored steps.
        if "error" in result and "exploit_detected" not in result:
            print(f" Step {step_count+1}: ERROR - {result['error']}")
            break

        reward = result.get("reward", 0.0)
        info = result.get("info") or {}
        step_count += 1
        rewards.append(reward)
        _check(isinstance(reward, float), f"Reward must be float, got {type(reward)}")
        _check(0.0 < reward < 1.0, f"Reward {reward} out of range (0.0, 1.0)")

        action_name = action["action"]
        is_query = action_name.startswith("query_")
        if is_query:
            # Record the agent we asked for so the strategy advances even if
            # the server labels it differently or omits ``agent_queried``.
            queried_agents.add(action_name[len("query_"):])

        action_type = info.get("action_type", "query" if is_query else "apply")
        if action_type == "query":
            agent = info.get("agent_queried", "?")
            queried_agents.add(agent)
            n_recs = info.get("n_recommendations", 0)
            print(
                f" Step {step_count:02d}: QUERY {agent!s:12} -> {n_recs} recs "
                f"| reward={reward:.4f} | budget={info.get('budget_remaining', '?')}"
            )
        else:
            prev_acc = _fmt_num(info.get("prev_accuracy", "?"))
            new_acc = _fmt_num(info.get("new_accuracy", "?"))
            print(
                f" Step {step_count:02d}: APPLY {info.get('rec_type', '?')!s:15} -> "
                f"{prev_acc}->{new_acc} "
                f"| reward={reward:.4f} | success={info.get('success', False)}"
            )

        new_obs = result.get("observation")
        if isinstance(new_obs, dict):
            obs = new_obs
        if result.get("done"):
            print(f"\n Episode done. Success={info.get('success', False)}")
            break
    else:
        # Runs only when the loop ended by hitting the cap, not via ``break``.
        print(f"\n Stopped: reached max_steps={max_steps}")

    # Verify metrics endpoint
    metrics_resp = requests.get(f"{base_url}/metrics", timeout=10)
    _check(
        metrics_resp.status_code == 200,
        f"Metrics failed: {metrics_resp.status_code} {metrics_resp.text}",
    )
    m = metrics_resp.json()
    _check("sessions" in m, "Missing sessions in /metrics")
    print("\nPhase 2 full episode: PASS")
    print(f" Steps: {step_count} | Mean reward: {sum(rewards) / max(len(rewards), 1):.4f}")
    print(f" All {len(rewards)} rewards in (0.0, 1.0): {all(0.0 < r < 1.0 for r in rewards)}")
    print(f" Active sessions: {m['sessions'].get('active_sessions', '?')}")
    print("\nAll automated checks passed.")
    return True
if __name__ == "__main__":
    # An optional first command-line argument replaces the default server URL.
    cli_args = sys.argv[1:]
    target_url = cli_args[0] if cli_args else BASE_URL
    run_episode(target_url)