Aswini-Kumar commited on
Commit
f02b845
·
verified ·
1 Parent(s): f89ffa8

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +68 -0
inference.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py — Required by Phase 1 automated checks.
3
+ Runs a complete episode against the local environment server.
4
+ """
5
+
6
+ import requests
7
+ import json
8
+ import sys
9
+
10
+ BASE_URL = "http://localhost:8000"
11
+
12
+
13
+ def run_episode(base_url: str = BASE_URL):
14
+ print(f"Connecting to environment at {base_url}")
15
+
16
+ # Phase 1: reset
17
+ obs = requests.post(f"{base_url}/reset", json={}, timeout=30)
18
+ assert obs.status_code == 200, f"Reset failed with status {obs.status_code}"
19
+ obs = obs.json()
20
+ assert "current_accuracy" in obs, "Missing current_accuracy in reset response"
21
+ assert "budget_remaining" in obs, "Missing budget_remaining in reset response"
22
+ assert "available_tools" in obs, "Missing available_tools in reset response"
23
+ print(f"Phase 1 reset: PASS | initial accuracy={obs['current_accuracy']}")
24
+
25
+ # Phase 2: run a full episode
26
+ tools = ["cleaner", "augmenter", "balancer", "validator"]
27
+ step_count = 0
28
+
29
+ while True:
30
+ # Simple greedy heuristic: pick tool based on dataset stats
31
+ stats = obs.get("dataset_stats", {})
32
+ if stats.get("missing_pct", 0) > 0.05:
33
+ action = {"agent": "cleaner", "target": "all", "strategy": "median_impute"}
34
+ elif stats.get("balance_ratio", 1.0) < 0.3:
35
+ action = {"agent": "balancer", "strategy": "undersample"}
36
+ else:
37
+ action = {"agent": "augmenter"}
38
+
39
+ result = requests.post(f"{base_url}/step", json=action, timeout=30)
40
+ assert result.status_code == 200, f"Step failed with status {result.status_code}"
41
+ result = result.json()
42
+
43
+ reward = result.get("reward")
44
+ done = result.get("done")
45
+ info = result.get("info", {})
46
+ step_count += 1
47
+
48
+ assert isinstance(reward, float), f"Reward must be float, got {type(reward)}"
49
+ assert 0.0 < reward < 1.0, f"Reward {reward} out of valid range (0.0, 1.0)"
50
+ assert isinstance(done, bool), f"Done must be bool, got {type(done)}"
51
+
52
+ print(f" Step {step_count:02d}: agent={action['agent']} | reward={reward:.4f} | "
53
+ f"accuracy={info.get('new_accuracy', '?')} | done={done}")
54
+
55
+ obs = result.get("observation", obs)
56
+
57
+ if done or step_count >= 10:
58
+ break
59
+
60
+ success = info.get("success", False)
61
+ print(f"\nEpisode complete — {step_count} steps | success={success}")
62
+ print("All automated checks passed.")
63
+ return True
64
+
65
+
66
+ if __name__ == "__main__":
67
+ url = sys.argv[1] if len(sys.argv) > 1 else BASE_URL
68
+ run_episode(url)