gridmind / python /validate.py
ShreeshantXD's picture
feat: add baseline scores JSON, inference script, and update Dockerfile for improved project structure
6d74982
"""
GridMind-RL Pre-Submission Validator
--------------------------------------
Validates the Go environment server against all OpenEnv spec requirements.
Run with: python python/validate.py [--env-url http://localhost:7860]
"""
import argparse
import json
import sys
import time
import traceback
from typing import Any
import requests
ENV_URL = "http://localhost:7860"
PASS = "[OK]"
FAIL = "[FAIL]"
WARN = "[WARN]"
def check(label: str, condition: bool, detail: str = "") -> bool:
icon = PASS if condition else FAIL
line = f" {icon} {label}"
if detail:
line += f" - {detail}"
print(line)
return condition
def get(url: str, timeout: int = 10) -> requests.Response:
return requests.get(url, timeout=timeout)
def post(url: str, payload: Any = None, timeout: int = 10) -> requests.Response:
return requests.post(url, json=payload, timeout=timeout)
def validate(env_url: str) -> bool:
base = env_url.rstrip("/")
results = []
print("\n" + "=" * 50)
print(" GridMind-RL OpenEnv Validation Report")
print("=" * 50 + "\n")
# ── 1. Health & ping ─────────────────────────────────────────────────────
print("1. Health & Ping")
try:
r = get(f"{base}/health")
results.append(check("GET /health returns 200", r.status_code == 200, f"got {r.status_code}"))
data = r.json()
results.append(check("Response has 'status' field", "status" in data))
rp = get(f"{base}/ping")
results.append(check("GET /ping returns 200", rp.status_code == 200, f"got {rp.status_code}"))
except Exception as e:
results.append(check("GET /health reachable", False, str(e)))
print(f"\n [FAIL] Cannot reach server at {base}. Is it running?\n")
return False
# ── 2. Reset endpoint ───────────────────────────────────────────────────
print("\n2. Reset Endpoint")
reset_resp = None
try:
r = post(f"{base}/reset", {"task_id": 1, "seed": 42, "num_buildings": 1})
results.append(check("POST /reset returns 200", r.status_code == 200, f"got {r.status_code}"))
reset_resp = r.json()
results.append(check("Response has 'observations'", "observations" in reset_resp))
results.append(check("Response has 'episode'", "episode" in reset_resp))
results.append(check("Response has 'seed'", "seed" in reset_resp))
results.append(check("Response has 'task_id'", "task_id" in reset_resp))
obs_list = reset_resp.get("observations", [])
results.append(check("observations is a list", isinstance(obs_list, list)))
results.append(check("At least 1 observation returned", len(obs_list) >= 1))
if obs_list:
obs = obs_list[0]
obs_fields = ["indoor_temperature", "thermal_storage_level", "process_demand",
"current_price", "grid_stress_signal", "carbon_intensity",
"hour_of_day", "batch_queue", "cumulative_cost", "step"]
for field in obs_fields:
results.append(check(f"obs has '{field}'", field in obs))
# Seed reproducibility
r2 = post(f"{base}/reset", {"task_id": 1, "seed": 42})
d2 = r2.json()
obs1 = reset_resp.get("observations", [{}])[0]
obs2 = d2.get("observations", [{}])[0]
same = (abs(obs1.get("indoor_temperature", 0) - obs2.get("indoor_temperature", 0)) < 1e-6)
results.append(check("Same seed produces same initial obs", same))
except Exception as e:
results.append(check("POST /reset succeeds", False, str(e)))
traceback.print_exc()
# ── 3. Step endpoint ────────────────────────────────────────────────────
print("\n3. Step Endpoint")
try:
# Reset fresh
post(f"{base}/reset", {"task_id": 1, "seed": 100})
action = {
"hvac_power_level": 0.5,
"thermal_charge_rate": 0.1,
"batch_job_slot": 1,
"load_shed_fraction": 0.0,
"building_id": 0,
}
r = post(f"{base}/step", action)
results.append(check("POST /step returns 200", r.status_code == 200))
step_resp = r.json()
step_fields = ["observation", "reward", "done", "info"]
for f in step_fields:
results.append(check(f"step response has '{f}'", f in step_resp))
results.append(check("reward is numeric", isinstance(step_resp.get("reward"), (int, float))))
results.append(check("done is boolean", isinstance(step_resp.get("done"), bool)))
info = step_resp.get("info", {})
results.append(check("info has 'reward_components'", "reward_components" in info))
results.append(check("info has 'energy_used_kwh'", "energy_used_kwh" in info))
rc = info.get("reward_components", {})
rc_fields = ["cost_savings", "temp_constraint", "grid_response",
"deadline_penalty", "efficiency_bonus", "stability_penalty",
"carbon_reward", "total"]
for f in rc_fields:
results.append(check(f"reward_components has '{f}'", f in rc))
# Test array action format
r2 = post(f"{base}/step", [action])
results.append(check("POST /step accepts array of actions", r2.status_code == 200))
except Exception as e:
results.append(check("POST /step succeeds", False, str(e)))
traceback.print_exc()
# ── 4. State endpoint ───────────────────────────────────────────────────
print("\n4. State Endpoint")
try:
r = get(f"{base}/state")
results.append(check("GET /state returns 200", r.status_code == 200))
state = r.json()
state_fields = ["buildings", "price_curve_episode", "carbon_curve_episode",
"episode", "step", "task_id", "done", "seed"]
for f in state_fields:
results.append(check(f"state has '{f}'", f in state))
curve_n = 24 # EpisodeSteps/4 (96/4) downsamples to hourly points
results.append(check("price_curve_episode has 24 entries",
len(state.get("price_curve_episode", [])) == curve_n))
results.append(check("carbon_curve_episode has 24 entries",
len(state.get("carbon_curve_episode", [])) == curve_n))
except Exception as e:
results.append(check("GET /state succeeds", False, str(e)))
# ── 5. Replay endpoint ──────────────────────────────────────────────────
print("\n5. Replay Endpoint")
try:
r = get(f"{base}/replay")
results.append(check("GET /replay returns 200", r.status_code == 200))
replay = r.json()
results.append(check("response has 'replay' list", "replay" in replay))
results.append(check("response has 'steps' count", "steps" in replay))
except Exception as e:
results.append(check("GET /replay succeeds", False, str(e)))
# ── 6. Grade endpoint ───────────────────────────────────────────────────
print("\n6. Grade Endpoint")
try:
# Run quick 10-step episode
post(f"{base}/reset", {"task_id": 1, "seed": 777})
action = {"hvac_power_level": 0.3, "thermal_charge_rate": 0.0,
"batch_job_slot": 0, "load_shed_fraction": 0.0}
done = False
while not done:
r2 = post(f"{base}/step", action)
if r2.json().get("done"):
done = True
r = get(f"{base}/grade")
results.append(check("GET /grade returns 200", r.status_code == 200))
grade = r.json()
grade_fields = ["task_id", "score", "sub_scores", "exploit_detected"]
for f in grade_fields:
results.append(check(f"grade has '{f}'", f in grade))
score = grade.get("score", -1)
results.append(check("score in [0.0, 1.0]", 0.0 <= score <= 1.0, f"score={score:.4f}"))
except Exception as e:
results.append(check("GET /grade succeeds", False, str(e)))
# ── 7. Tasks endpoint ───────────────────────────────────────────────────
print("\n7. Tasks Endpoint")
try:
r = get(f"{base}/tasks")
results.append(check("GET /tasks returns 200", r.status_code == 200))
tasks = r.json()
results.append(check("returns list of 3 tasks", len(tasks) == 3))
task_fields = ["id", "name", "description", "difficulty", "weights"]
for f in task_fields:
results.append(check(f"task has '{f}'", f in tasks[0]))
except Exception as e:
results.append(check("GET /tasks succeeds", False, str(e)))
# ── 8. Metrics endpoint ─────────────────────────────────────────────────
print("\n8. Metrics Endpoint (Prometheus)")
try:
r = get(f"{base}/metrics")
results.append(check("GET /metrics returns 200", r.status_code == 200))
content = r.text
results.append(check("metrics contain step counter",
"gridmind_steps_total" in content))
results.append(check("metrics contain latency gauge",
"gridmind_step_latency_ms_avg" in content))
except Exception as e:
results.append(check("GET /metrics succeeds", False, str(e)))
# ── 9. Grader score variation ───────────────────────────────────────────
print("\n9. Grader Score Variation (non-trivial scores)")
scores_nonzero = []
scores_nonone = []
for seed in [10, 20, 30]:
try:
post(f"{base}/reset", {"task_id": 1, "seed": seed})
# Two different policies
for a in [0.1, 0.9]:
post(f"{base}/reset", {"task_id": 1, "seed": seed})
done = False
while not done:
r2 = post(f"{base}/step", {"hvac_power_level": a, "thermal_charge_rate": 0,
"batch_job_slot": 0, "load_shed_fraction": 0})
if r2.json().get("done"):
done = True
g = requests.get(f"{base}/grade", timeout=10).json()
sc = g.get("score", 0)
scores_nonzero.append(sc > 0.01)
scores_nonone.append(sc < 0.999)
except Exception:
pass
results.append(check("Scores are not always 0.0", any(scores_nonzero)))
results.append(check("Scores are not always 1.0", any(scores_nonone)))
# ── Summary ─────────────────────────────────────────────────────────────
passed = sum(results)
total = len(results)
pct = 100 * passed // total if total > 0 else 0
print(f"\n" + "=" * 50)
print(f" Result: {passed}/{total} checks passed ({pct}%)")
if passed == total:
print(" ALL CHECKS PASSED - Ready for submission!")
else:
print(f" {total - passed} checks failed. Fix errors above.")
print("=" * 50 + "\n")
return passed == total
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--env-url", type=str, default=ENV_URL)
args = parser.parse_args()
ok = validate(args.env_url)
sys.exit(0 if ok else 1)
if __name__ == "__main__":
main()