Spaces:
Sleeping
Sleeping
Upload evaluate.py with huggingface_hub
Browse files- evaluate.py +78 -0
evaluate.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
evaluate.py — Baseline vs trained agent comparison
|
| 3 |
+
Run before and after training to measure improvement.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python evaluate.py --url http://localhost:8000
|
| 7 |
+
python evaluate.py --url https://your-hf-username-datacentric-env.hf.space
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import requests
|
| 11 |
+
import json
|
| 12 |
+
import random
|
| 13 |
+
import argparse
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
|
| 16 |
+
# Command-line interface: where the environment server lives and how many
# evaluation episodes to run against it.
parser = argparse.ArgumentParser(description="Evaluate DataCentric-Env agent")
parser.add_argument("--url", default="http://localhost:8000", help="Environment server URL")
parser.add_argument("--episodes", type=int, default=20, help="Number of evaluation episodes")
args = parser.parse_args()

# The summary statistics computed later divide by the number of episodes, so
# reject non-positive counts here with a usage error instead of letting the
# script die with a ZeroDivisionError after the run.
if args.episodes < 1:
    parser.error("--episodes must be at least 1")

# Drop trailing slashes so f"{ENV_URL}/reset" never produces a double slash.
ENV_URL = args.url.rstrip("/")
N_EPISODES = args.episodes
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def random_agent_episode(max_steps=10, timeout=30):
    """Run one baseline episode that picks a tool uniformly at random each step.

    The environment is reset with ``POST {ENV_URL}/reset`` and then up to
    ``max_steps`` actions are posted to ``{ENV_URL}/step``. The episode stops
    early as soon as a step response has a truthy ``"done"`` field.

    Args:
        max_steps: Maximum number of actions to send (defaults to 10).
        timeout: Seconds to wait for each HTTP response before giving up, so
            an unresponsive server cannot hang the evaluation forever.

    Returns:
        A ``(total_reward, success)`` tuple. ``total_reward`` is the sum of
        the ``"reward"`` fields of the step responses (missing rewards count
        as 0). ``success`` is ``info["success"]`` from the step that reported
        ``done``, or ``False`` if no step did or the flag was absent.

    Raises:
        requests.RequestException: On connection failure, timeout, or an HTTP
            error status from the server.
    """
    tools = ("cleaner", "augmenter", "balancer", "validator")

    # The reset payload is not needed by a random policy; only make sure the
    # server actually accepted the reset.
    requests.post(f"{ENV_URL}/reset", timeout=timeout).raise_for_status()

    total_reward = 0.0
    for _ in range(max_steps):
        action = {"agent": random.choice(tools), "target": "all"}
        response = requests.post(f"{ENV_URL}/step", json=action, timeout=timeout)
        # Fail loudly on 4xx/5xx rather than treating an error body as a step.
        response.raise_for_status()
        result = response.json()
        total_reward += result.get("reward", 0)
        if result.get("done"):
            return total_reward, result.get("info", {}).get("success", False)

    # Step budget exhausted without the environment reporting completion.
    return total_reward, False
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ─── Run baseline ─────────────────────────────────────────────────────────────
|
| 42 |
+
# Play N_EPISODES random-policy episodes, echoing each outcome as it arrives,
# then report the mean reward and the fraction of successful episodes.
print(f"Running {N_EPISODES} baseline (random) episodes against {ENV_URL}...")
baseline_rewards, baseline_successes = [], []
for episode_no in range(1, N_EPISODES + 1):
    episode_reward, episode_success = random_agent_episode()
    baseline_rewards.append(episode_reward)
    baseline_successes.append(episode_success)
    print(f" Episode {episode_no:02d}: reward={episode_reward:.3f} success={episode_success}")

episode_count = len(baseline_rewards)
mean_baseline = sum(baseline_rewards) / episode_count
# Summing the success flags counts the truthy ones.
success_rate_baseline = sum(baseline_successes) / len(baseline_successes)
print(f"\nBaseline mean reward: {mean_baseline:.3f}")
print(f"Baseline success rate: {success_rate_baseline:.1%}")
|
| 55 |
+
|
| 56 |
+
# ─── Plot reward curve ────────────────────────────────────────────────────────
|
| 57 |
+
# Two-panel summary figure: per-episode baseline rewards on the left and a
# mean-reward comparison bar chart on the right, written to results.png.
fig, (curve_ax, bar_ax) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: reward obtained in each baseline episode (1-based on the x axis).
episode_numbers = range(1, N_EPISODES + 1)
curve_ax.plot(episode_numbers, baseline_rewards, marker="o", color="#5B8FF9", label="Random baseline")
curve_ax.set_xlabel("Episode")
curve_ax.set_ylabel("Total Reward")
curve_ax.set_title("Baseline Reward per Episode")
curve_ax.legend()
curve_ax.grid(alpha=0.3)

# Right panel: baseline mean next to the trained-agent mean.
mean_trained = mean_baseline * 1.0  # placeholder — replace with trained agent result
bar_labels = ["Random baseline", "Trained agent"]
bar_heights = [mean_baseline, mean_trained]
bar_ax.bar(bar_labels, bar_heights, color=["#5B8FF9", "#5AD8A6"])
bar_ax.set_ylabel("Mean Episode Reward")
bar_ax.set_title("Baseline vs Trained Agent")
bar_ax.grid(alpha=0.3, axis="y")

fig.tight_layout()
fig.savefig("results.png", dpi=150)
print("\nSaved results.png")
|