ShreeshantXD committed on
Commit
30aa472
·
1 Parent(s): 891cc5b

fix: auto-start env server in inference.py

Browse files
Files changed (1) hide show
  1. inference.py +161 -53
inference.py CHANGED
@@ -28,9 +28,10 @@ from __future__ import annotations
28
  import argparse
29
  import json
30
  import os
 
31
  import sys
32
  import time
33
- from typing import Any
34
 
35
  import requests
36
  from openai import OpenAI
@@ -427,6 +428,97 @@ def run_episode(
427
 
428
  # ── Main ─────────────────────────────────────────────────────────────────────
429
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  def main() -> None:
432
  parser = argparse.ArgumentParser(description="GridMind-RL baseline inference")
@@ -459,58 +551,74 @@ def main() -> None:
459
  print("HF_TOKEN is required.", file=sys.stderr)
460
  sys.exit(1)
461
 
462
- env_client = GridMindEnvClient(base_url=args.env_url)
463
-
464
- for attempt in range(30):
465
- if env_client.health():
466
- break
467
- time.sleep(2)
468
- if attempt == 29:
469
- print("Environment server not reachable.", file=sys.stderr)
470
- sys.exit(1)
471
-
472
- agent = LLMAgent()
473
- all_results: list[dict[str, Any]] = []
474
-
475
- for task_id in [1, 2, 3]:
476
- task_scores: list[float] = []
477
- for ep in range(args.episodes):
478
- seed = DEFAULT_SEED_BASE + task_id * 100 + ep
479
- result = run_episode(
480
- env_client,
481
- agent,
482
- task_id=task_id,
483
- seed=seed,
484
- fast_mode=args.fast_mode,
485
- llm_every=args.llm_every,
486
- max_steps=args.max_steps,
487
- verbose=args.verbose,
488
- )
489
- task_scores.append(float(result["score"]))
490
- all_results.append(result)
491
- _ = sum(task_scores) / len(task_scores)
492
-
493
- task_avgs: dict[int, float] = {}
494
- for task_id in [1, 2, 3]:
495
- scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
496
- avg = sum(scores) / len(scores) if scores else 0.0
497
- task_avgs[task_id] = avg
498
- overall = sum(task_avgs.values()) / len(task_avgs)
499
-
500
- output = {
501
- "model": MODEL_NAME,
502
- "api_base": API_BASE_URL,
503
- "episodes_per_task": args.episodes,
504
- "seed_base": DEFAULT_SEED_BASE,
505
- "fast_mode": args.fast_mode,
506
- "llm_every": args.llm_every,
507
- "max_steps": args.max_steps,
508
- "task_averages": {str(k): v for k, v in task_avgs.items()},
509
- "overall_average": overall,
510
- "all_results": all_results,
511
- }
512
- with open(args.output, "w", encoding="utf-8") as f:
513
- json.dump(output, f, indent=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
 
516
  if __name__ == "__main__":
 
28
  import argparse
29
  import json
30
  import os
31
+ import subprocess
32
  import sys
33
  import time
34
+ from typing import Any, Optional
35
 
36
  import requests
37
  from openai import OpenAI
 
428
 
429
  # ── Main ─────────────────────────────────────────────────────────────────────
430
 
431
def start_environment_server(port: int = 7860) -> Optional[subprocess.Popen]:
    """Start the GridMind-RL environment server as a background process.

    Launch strategies are tried in order: a prebuilt Go binary, compiling the
    Go server from ``main.go`` (if a ``go`` toolchain is available), and
    finally the Python ``server.app`` module.

    Args:
        port: TCP port the server should listen on (exported to the child as
            the ``PORT`` environment variable).

    Returns:
        A ``Popen`` handle for the server we launched, or ``None`` if a server
        is already answering the health check or none could be started.
    """
    # First check if a server is already running on this port; if so, reuse it.
    try:
        r = requests.get(f"http://localhost:{port}/health", timeout=2)
        if r.status_code == 200:
            print(f"[INFO] Environment server already running on port {port}", file=sys.stderr)
            return None
    except Exception:
        pass  # Server not running, we'll start it

    print(f"[INFO] Starting environment server on port {port}...", file=sys.stderr)

    try:
        # Child environment: advertise the port and make the repo root importable.
        env = os.environ.copy()
        env["PORT"] = str(port)
        if "PYTHONPATH" in env:
            env["PYTHONPATH"] = "." + os.pathsep + env["PYTHONPATH"]
        else:
            env["PYTHONPATH"] = "."

        # 1) Prebuilt Go binary — preferred, needs no toolchain.
        binary_paths = [
            "/usr/local/bin/gridmind-server",  # Docker path
            "./gridmind-server",               # Local Linux/Mac
            "./gridmind-server.exe",           # Local Windows
        ]
        for binary_path in binary_paths:
            if os.path.exists(binary_path):
                try:
                    print(f"[INFO] Running Go binary: {binary_path}", file=sys.stderr)
                    # DEVNULL, not PIPE: nothing ever drains the server's
                    # output, and an undrained PIPE can block the child once
                    # the OS pipe buffer fills.
                    proc = subprocess.Popen(
                        [binary_path],
                        env=env,
                        stdout=subprocess.DEVNULL,
                        stderr=subprocess.DEVNULL,
                    )
                    time.sleep(2)  # Give the server a moment to come up.
                    if proc.poll() is None:
                        return proc
                except Exception as e:
                    print(f"[DEBUG] Failed with {binary_path}: {e}", file=sys.stderr)

        # 2) Compile the Go binary if the 'go' toolchain is available.
        try:
            print("[INFO] Attempting to compile Go executable...", file=sys.stderr)
            result = subprocess.run(
                ["go", "build", "-o", "gridmind-server", "main.go"],
                capture_output=True,
                timeout=60,
                cwd=".",
            )
            if result.returncode == 0:
                print("[INFO] Compilation successful, starting server...", file=sys.stderr)
                proc = subprocess.Popen(
                    ["./gridmind-server"],
                    env=env,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
                time.sleep(2)
                if proc.poll() is None:
                    return proc
        except Exception as e:
            # Typically FileNotFoundError (no `go` installed) or TimeoutExpired.
            print(f"[DEBUG] Could not compile: {e}", file=sys.stderr)

        # 3) Fallback: run the server as a Python module.
        print("[INFO] Attempting Python server module...", file=sys.stderr)
        proc = subprocess.Popen(
            [sys.executable, "-m", "server.app"],
            env=env,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            cwd=".",
        )
        time.sleep(3)
        if proc.poll() is None:
            return proc

    except Exception as e:
        print(f"[WARNING] Could not start environment server: {e}", file=sys.stderr)
    return None

522
 
523
  def main() -> None:
524
  parser = argparse.ArgumentParser(description="GridMind-RL baseline inference")
 
551
  print("HF_TOKEN is required.", file=sys.stderr)
552
  sys.exit(1)
553
 
554
+ # Start the environment server if not already running
555
+ server_proc = start_environment_server(port=7860)
556
+
557
+ try:
558
+ env_client = GridMindEnvClient(base_url=args.env_url)
559
+
560
+ for attempt in range(30):
561
+ if env_client.health():
562
+ break
563
+ time.sleep(2)
564
+ if attempt == 29:
565
+ print("Environment server not reachable.", file=sys.stderr)
566
+ sys.exit(1)
567
+
568
+ agent = LLMAgent()
569
+ all_results: list[dict[str, Any]] = []
570
+
571
+ for task_id in [1, 2, 3]:
572
+ task_scores: list[float] = []
573
+ for ep in range(args.episodes):
574
+ seed = DEFAULT_SEED_BASE + task_id * 100 + ep
575
+ result = run_episode(
576
+ env_client,
577
+ agent,
578
+ task_id=task_id,
579
+ seed=seed,
580
+ fast_mode=args.fast_mode,
581
+ llm_every=args.llm_every,
582
+ max_steps=args.max_steps,
583
+ verbose=args.verbose,
584
+ )
585
+ task_scores.append(float(result["score"]))
586
+ all_results.append(result)
587
+ _ = sum(task_scores) / len(task_scores)
588
+
589
+ task_avgs: dict[int, float] = {}
590
+ for task_id in [1, 2, 3]:
591
+ scores = [float(r["score"]) for r in all_results if r["task_id"] == task_id]
592
+ avg = sum(scores) / len(scores) if scores else 0.0
593
+ task_avgs[task_id] = avg
594
+ overall = sum(task_avgs.values()) / len(task_avgs)
595
+
596
+ output = {
597
+ "model": MODEL_NAME,
598
+ "api_base": API_BASE_URL,
599
+ "episodes_per_task": args.episodes,
600
+ "seed_base": DEFAULT_SEED_BASE,
601
+ "fast_mode": args.fast_mode,
602
+ "llm_every": args.llm_every,
603
+ "max_steps": args.max_steps,
604
+ "task_averages": {str(k): v for k, v in task_avgs.items()},
605
+ "overall_average": overall,
606
+ "all_results": all_results,
607
+ }
608
+ with open(args.output, "w", encoding="utf-8") as f:
609
+ json.dump(output, f, indent=2)
610
+ finally:
611
+ # Clean up the server process if we started it
612
+ if server_proc:
613
+ try:
614
+ server_proc.terminate()
615
+ server_proc.wait(timeout=5)
616
+ except Exception as e:
617
+ print(f"[WARNING] Failed to terminate server: {e}", file=sys.stderr)
618
+ try:
619
+ server_proc.kill()
620
+ except Exception:
621
+ pass
622
 
623
 
624
  if __name__ == "__main__":