# datacentric-env / inference.py
# Committed by Aswini-Kumar: "Upload inference.py with huggingface_hub"
# (commit 1341a55, verified)
"""
inference.py — Phase 1/2 automated check entry point (v0.3).
Demonstrates the full query/apply episode with session management.
"""
import requests
import sys
import json
# Default server root used when no URL is given on the command line.
BASE_URL: str = "http://localhost:8000"
def _check(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Used instead of bare ``assert`` statements so the checks still run when
    the interpreter is started with ``-O`` (which strips ``assert``); the
    exception type seen by callers stays ``AssertionError``.
    """
    if not condition:
        raise AssertionError(message)


def _post_json(url, payload, what):
    """POST *payload* as JSON to *url*, require HTTP 200, return the decoded body.

    *what* names the call in the failure message (e.g. ``"Reset"``, ``"Step"``).
    """
    resp = requests.post(url, json=payload, timeout=30)
    _check(resp.status_code == 200, f"{what} failed: {resp.status_code} {resp.text}")
    return resp.json()


def _fmt_num(value, spec=".4f"):
    """Format *value* with *spec* when it is numeric, else fall back to ``str``.

    Server ``info`` fields may be missing (the caller substitutes ``"?"``) or
    null; applying ``.4f`` to those would raise instead of logging them.
    """
    if isinstance(value, (int, float)):
        return format(value, spec)
    return str(value)


def _choose_action(session_id, obs, queried_agents):
    """Return the next ``/step`` payload for the demo policy, or ``None`` when out of ideas.

    Policy: query the analyst first; while recommendations are pending, apply
    the one with the smallest ``priority`` value (a missing priority counts as
    99); otherwise query the cleaner / balancer / augmenter based on
    ``dataset_stats``, each at most once.
    """
    pending = obs.get("pending_recommendations") or {}
    stats = obs.get("dataset_stats") or {}
    if not queried_agents:
        return {"session_id": session_id, "action": "query_analyst"}
    if pending:
        best_rec_id = min(pending, key=lambda k: pending[k].get("priority", 99))
        return {"session_id": session_id, "action": "apply", "rec_id": best_rec_id}
    if stats.get("missing_pct", 0) > 0.05 and "cleaner" not in queried_agents:
        return {"session_id": session_id, "action": "query_cleaner"}
    if stats.get("balance_ratio", 1.0) < 0.45 and "balancer" not in queried_agents:
        return {"session_id": session_id, "action": "query_balancer"}
    if "augmenter" not in queried_agents:
        return {"session_id": session_id, "action": "query_augmenter", "target_class": 1}
    return None


def run_episode(base_url: str = BASE_URL, *, max_steps: int = 1000):
    """Run one scripted episode against a DataCentric-Env server and verify it.

    Phase 1 calls ``POST /reset`` and checks the initial observation fields.
    Phase 2 calls ``POST /step`` with a simple heuristic policy until the
    budget is spent, the server reports ``done``, the policy runs out of
    actions, or the server returns an error. Finally ``GET /metrics`` is
    checked.

    Args:
        base_url: Server root URL; a trailing ``/`` is ignored.
        max_steps: Safety cap on ``/step`` calls so a server that never
            decrements the budget or sets ``done`` cannot hang the check.

    Returns:
        True when all checks pass.

    Raises:
        AssertionError: A response had a non-200 status, lacked a required
            field, carried a reward that is not a float in (0.0, 1.0), or the
            episode did not finish within *max_steps* steps.
        Exceptions from ``requests`` (e.g. connection errors, timeouts)
        propagate unchanged.
    """
    base_url = base_url.rstrip("/")
    print(f"DataCentric-Env v0.3 | {base_url}\n")

    # Phase 1: reset -- obtain a session_id and validate the initial observation.
    obs = _post_json(f"{base_url}/reset", {}, "Reset")
    required_fields = (
        ("session_id", "Missing session_id in reset response"),
        ("current_accuracy", "Missing current_accuracy"),
        ("budget_remaining", "Missing budget_remaining"),
        ("available_actions", "Missing available_actions"),
        # The summary printed below also reads these two.
        ("target_accuracy", "Missing target_accuracy"),
        ("dataset_stats", "Missing dataset_stats"),
    )
    for key, message in required_fields:
        _check(key in obs, message)
    for key in ("missing_pct", "balance_ratio"):
        _check(key in obs["dataset_stats"], f"Missing dataset_stats.{key}")

    session_id = obs["session_id"]
    print("Phase 1 reset: PASS")
    print(f" session_id = {session_id}")
    print(f" accuracy = {obs['current_accuracy']} -> target {obs['target_accuracy']}")
    print(f" budget = {obs['budget_remaining']}")
    print(f" missing_pct = {obs['dataset_stats']['missing_pct']}")
    print(f" balance_ratio = {obs['dataset_stats']['balance_ratio']}\n")

    # Phase 2: run a full episode.
    rewards = []
    step_count = 0
    queried_agents = set()
    while True:
        if obs.get("budget_remaining", 0) <= 0 or obs.get("done"):
            break
        action = _choose_action(session_id, obs, queried_agents)
        if action is None:
            break  # every applicable agent was queried and nothing is pending
        _check(step_count < max_steps, f"Episode did not finish within {max_steps} steps")

        result = _post_json(f"{base_url}/step", action, "Step")
        # A server error ends the loop without failing the check, unless the
        # response also carries ``exploit_detected``; then its reward is still
        # read and validated below.
        if "error" in result and "exploit_detected" not in result:
            print(f" Step {step_count+1}: ERROR - {result['error']}")
            break

        reward = result.get("reward", 0.0)
        info = result.get("info") or {}
        step_count += 1
        rewards.append(reward)
        _check(isinstance(reward, float), f"Reward must be float, got {type(reward)}")
        _check(0.0 < reward < 1.0, f"Reward {reward} out of range (0.0, 1.0)")

        sent = action["action"]
        is_query = sent.startswith("query_")
        if is_query:
            # Track agents by the action we sent rather than the server's echo,
            # so a missing/renamed ``agent_queried`` cannot cause repeat queries.
            queried_agents.add(sent[len("query_"):])

        # Fall back to the kind of action we sent if the server omits it.
        action_type = info.get("action_type") or ("query" if is_query else "apply")
        if action_type == "query":
            agent = str(info.get("agent_queried", "?"))
            n_recs = info.get("n_recommendations", 0)
            print(f" Step {step_count:02d}: QUERY {agent:12s} -> {n_recs} recs | reward={reward:.4f} | budget={info.get('budget_remaining','?')}")
        else:
            rec_type = str(info.get("rec_type", "?"))
            prev_acc = _fmt_num(info.get("prev_accuracy", "?"))
            new_acc = _fmt_num(info.get("new_accuracy", "?"))
            print(f" Step {step_count:02d}: APPLY {rec_type:15s} -> {prev_acc}->{new_acc} | reward={reward:.4f} | success={info.get('success',False)}")

        obs = result.get("observation", obs)
        if result.get("done"):
            print(f"\n Episode done. Success={info.get('success', False)}")
            break

    # Verify the metrics endpoint.
    resp = requests.get(f"{base_url}/metrics", timeout=10)
    _check(resp.status_code == 200, f"Metrics failed: {resp.status_code} {resp.text}")
    m = resp.json()
    _check("sessions" in m, "Missing sessions in /metrics")
    _check("active_sessions" in m["sessions"], "Missing sessions.active_sessions in /metrics")

    print("\nPhase 2 full episode: PASS")
    print(f" Steps: {step_count} | Mean reward: {sum(rewards)/max(len(rewards),1):.4f}")
    print(f" All {len(rewards)} rewards in (0.0, 1.0): {all(0.0 < r < 1.0 for r in rewards)}")
    print(f" Active sessions: {m['sessions']['active_sessions']}")
    print("\nAll automated checks passed.")
    return True
if __name__ == "__main__":
    # The optional first command-line argument overrides the default server URL.
    cli_args = sys.argv[1:]
    target_url = cli_args[0] if cli_args else BASE_URL
    run_episode(target_url)