Commit ·
30aa472
1
Parent(s): 891cc5b
fix: auto-start env server in inference.py
Browse files- inference.py +161 -53
inference.py
CHANGED
|
@@ -28,9 +28,10 @@ from __future__ import annotations
|
|
| 28 |
import argparse
|
| 29 |
import json
|
| 30 |
import os
|
|
|
|
| 31 |
import sys
|
| 32 |
import time
|
| 33 |
-
from typing import Any
|
| 34 |
|
| 35 |
import requests
|
| 36 |
from openai import OpenAI
|
|
@@ -427,6 +428,97 @@ def run_episode(
|
|
| 427 |
|
| 428 |
# ── Main ─────────────────────────────────────────────────────────────────────
|
| 429 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
def main() -> None:
|
| 432 |
parser = argparse.ArgumentParser(description="GridMind-RL baseline inference")
|
|
@@ -459,58 +551,74 @@ def main() -> None:
|
|
| 459 |
print("HF_TOKEN is required.", file=sys.stderr)
|
| 460 |
sys.exit(1)
|
| 461 |
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
task_avgs[
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
|
| 515 |
|
| 516 |
if __name__ == "__main__":
|
|
|
|
| 28 |
import argparse
|
| 29 |
import json
|
| 30 |
import os
|
| 31 |
+
import subprocess
|
| 32 |
import sys
|
| 33 |
import time
|
| 34 |
+
from typing import Any, Optional
|
| 35 |
|
| 36 |
import requests
|
| 37 |
from openai import OpenAI
|
|
|
|
| 428 |
|
| 429 |
# ── Main ─────────────────────────────────────────────────────────────────────
|
| 430 |
|
| 431 |
+
def _launch_server(
    cmd: list[str],
    env: dict[str, str],
    settle_seconds: float,
    cwd: Optional[str] = None,
) -> Optional[subprocess.Popen]:
    """Spawn ``cmd`` in the background and return it only if it stays alive.

    The child's stdout/stderr are sent to DEVNULL rather than PIPE: nobody
    reads the server's output, so a pipe would eventually fill up and block
    the server on its next write.

    Args:
        cmd: Argument vector to execute (no shell involved).
        env: Environment passed to the child process.
        settle_seconds: How long to wait before checking the child is alive.
        cwd: Working directory for the child, or None to inherit ours.

    Returns:
        The running ``Popen`` handle, or None if the child exited during the
        settle period.

    Raises:
        OSError: If the executable cannot be started.
    """
    proc = subprocess.Popen(
        cmd,
        env=env,
        cwd=cwd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    time.sleep(settle_seconds)
    if proc.poll() is None:
        return proc
    print(
        f"[DEBUG] {cmd[0]} exited early with code {proc.returncode}",
        file=sys.stderr,
    )
    return None


def start_environment_server(port: int = 7860) -> Optional[subprocess.Popen]:
    """Start the GridMind-RL environment server as a background process.

    Launch strategies are tried in order: an existing Go binary, a freshly
    compiled Go binary (``go build``), and finally the Python module
    ``server.app``. The first one that is still alive after a short settle
    period wins.

    Args:
        port: Port probed for an existing server and exported to the child
            process as the ``PORT`` environment variable.

    Returns:
        A ``Popen`` handle if this call started the server. None if a server
        already answers ``/health`` on ``port`` or if no launch strategy
        succeeded; callers should still poll the health endpoint themselves.
    """
    # Skip launching if something already answers the health endpoint.
    try:
        r = requests.get(f"http://localhost:{port}/health", timeout=2)
        if r.status_code == 200:
            print(
                f"[INFO] Environment server already running on port {port}",
                file=sys.stderr,
            )
            return None
    except requests.RequestException:
        pass  # Server not running, we'll start it

    print(f"[INFO] Starting environment server on port {port}...", file=sys.stderr)

    # Child environment: tell the server its port and make the current
    # directory importable for the Python fallback.
    env = os.environ.copy()
    env["PORT"] = str(port)
    if "PYTHONPATH" in env:
        env["PYTHONPATH"] = "." + os.pathsep + env["PYTHONPATH"]
    else:
        env["PYTHONPATH"] = "."

    # 1) Look for an already-compiled Go binary.
    binary_paths = [
        "/usr/local/bin/gridmind-server",  # Docker path
        "./gridmind-server",  # Local Linux/Mac
        "./gridmind-server.exe",  # Local Windows
    ]
    for binary_path in binary_paths:
        if not os.path.exists(binary_path):
            continue
        try:
            print(f"[INFO] Running Go binary: {binary_path}", file=sys.stderr)
            proc = _launch_server([binary_path], env, settle_seconds=2)
            if proc is not None:
                return proc
        except (OSError, subprocess.SubprocessError) as e:
            print(f"[DEBUG] Failed with {binary_path}: {e}", file=sys.stderr)

    # 2) Try to compile the Go binary if 'go' is available.
    try:
        print("[INFO] Attempting to compile Go executable...", file=sys.stderr)
        result = subprocess.run(
            ["go", "build", "-o", "gridmind-server", "main.go"],
            capture_output=True,
            timeout=60,
            cwd=".",
        )
        if result.returncode == 0:
            print("[INFO] Compilation successful, starting server...", file=sys.stderr)
            proc = _launch_server(["./gridmind-server"], env, settle_seconds=2)
            if proc is not None:
                return proc
    except (OSError, subprocess.SubprocessError) as e:
        # Covers a missing 'go' toolchain (FileNotFoundError) and build timeouts.
        print(f"[DEBUG] Could not compile: {e}", file=sys.stderr)

    # 3) Fallback: run the Python server module.
    try:
        print("[INFO] Attempting Python server module...", file=sys.stderr)
        return _launch_server(
            [sys.executable, "-m", "server.app"],
            env,
            settle_seconds=3,
            cwd=".",
        )
    except (OSError, subprocess.SubprocessError) as e:
        print(f"[WARNING] Could not start environment server: {e}", file=sys.stderr)
    return None
|
| 521 |
+
|
| 522 |
|
| 523 |
def main() -> None:
|
| 524 |
parser = argparse.ArgumentParser(description="GridMind-RL baseline inference")
|
|
|
|
| 551 |
print("HF_TOKEN is required.", file=sys.stderr)
|
| 552 |
sys.exit(1)
|
| 553 |
|
| 554 |
+
# Start the environment server if not already running
|
| 555 |
+
server_proc = start_environment_server(port=7860)
|
| 556 |
+
|
| 557 |
+
try:
|
| 558 |
+
env_client = GridMindEnvClient(base_url=args.env_url)
|
| 559 |
+
|
| 560 |
+
for attempt in range(30):
|
| 561 |
+
if env_client.health():
|
| 562 |
+
break
|
| 563 |
+
time.sleep(2)
|
| 564 |
+
if attempt == 29:
|
| 565 |
+
print("Environment server not reachable.", file=sys.stderr)
|
| 566 |
+
sys.exit(1)
|
| 567 |
+
|
| 568 |
+
agent = LLMAgent()
|
| 569 |
+
all_results: list[dict[str, Any]] = []
|
| 570 |
+
|
| 571 |
+
for task_id in [1, 2, 3]:
|
| 572 |
+
task_scores: list[float] = []
|
| 573 |
+
for ep in range(args.episodes):
|
| 574 |
+
seed = DEFAULT_SEED_BASE + task_id * 100 + ep
|
| 575 |
+
result = run_episode(
|
| 576 |
+
env_client,
|
| 577 |
+
agent,
|
| 578 |
+
task_id=task_id,
|
| 579 |
+
seed=seed,
|
| 580 |
+
fast_mode=args.fast_mode,
|
| 581 |
+
llm_every=args.llm_every,
|
| 582 |
+
max_steps=args.max_steps,
|
| 583 |
+
verbose=args.verbose,
|
| 584 |
+
)
|
| 585 |
+
task_scores.append(float(result["score"]))
|
| 586 |
+
all_results.append(result)
|
| 587 |
+
_ = sum(task_scores) / len(task_scores)
|
| 588 |
+
|
| 589 |
+
task_avgs: dict[int, float] = {}
|
| 590 |
+
for task_id in [1, 2, 3]:
|
| 591 |
+
scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
|
| 592 |
+
avg = sum(scores) / len(scores) if scores else 0.0
|
| 593 |
+
task_avgs[task_id] = avg
|
| 594 |
+
overall = sum(task_avgs.values()) / len(task_avgs)
|
| 595 |
+
|
| 596 |
+
output = {
|
| 597 |
+
"model": MODEL_NAME,
|
| 598 |
+
"api_base": API_BASE_URL,
|
| 599 |
+
"episodes_per_task": args.episodes,
|
| 600 |
+
"seed_base": DEFAULT_SEED_BASE,
|
| 601 |
+
"fast_mode": args.fast_mode,
|
| 602 |
+
"llm_every": args.llm_every,
|
| 603 |
+
"max_steps": args.max_steps,
|
| 604 |
+
"task_averages": {str(k): v for k, v in task_avgs.items()},
|
| 605 |
+
"overall_average": overall,
|
| 606 |
+
"all_results": all_results,
|
| 607 |
+
}
|
| 608 |
+
with open(args.output, "w", encoding="utf-8") as f:
|
| 609 |
+
json.dump(output, f, indent=2)
|
| 610 |
+
finally:
|
| 611 |
+
# Clean up the server process if we started it
|
| 612 |
+
if server_proc:
|
| 613 |
+
try:
|
| 614 |
+
server_proc.terminate()
|
| 615 |
+
server_proc.wait(timeout=5)
|
| 616 |
+
except Exception as e:
|
| 617 |
+
print(f"[WARNING] Failed to terminate server: {e}", file=sys.stderr)
|
| 618 |
+
try:
|
| 619 |
+
server_proc.kill()
|
| 620 |
+
except Exception:
|
| 621 |
+
pass
|
| 622 |
|
| 623 |
|
| 624 |
if __name__ == "__main__":
|