|
|
| """
|
MAC Worker Agent — enrolls with the control node, registers the served model,
and sends periodic heartbeats.

Runs as a sidecar container alongside vLLM on each GPU worker PC.
|
| """
|
|
|
import asyncio
import json
import os
import socket
import subprocess
import sys
import time

import httpx
|
|
|
# --- Configuration, read once from the environment at import time ----------

# Base URL of the control node (kept verbatim; it is echoed in the startup log).
CONTROL_URL = os.environ.get("CONTROL_NODE_URL", "http://192.168.1.100:8000")
# Sent as "enrollment_token" when enrolling; enroll() refuses to run if empty.
ENROLLMENT_TOKEN = os.environ.get("ENROLLMENT_TOKEN", "")
# Display name reported to the control node; defaults to "worker-<hostname>".
NODE_NAME = os.environ.get("NODE_NAME", f"worker-{socket.gethostname()}")
# Port of the local vLLM server (health checks, model detection, enrollment).
VLLM_PORT = int(os.environ.get("VLLM_PORT", 8001))
# Fallback served-model name used when vLLM cannot be queried for it.
VLLM_MODEL = os.environ.get("VLLM_MODEL", "")
# Static hardware description included in the enrollment payload.
GPU_NAME = os.environ.get("GPU_NAME", "NVIDIA GPU")
GPU_VRAM_MB = int(os.environ.get("GPU_VRAM_MB", 12288))
RAM_TOTAL_MB = int(os.environ.get("RAM_TOTAL_MB", 16384))
CPU_CORES = int(os.environ.get("CPU_CORES", 8))
# Seconds between heartbeats. Clamped to at least 1 so that a zero or negative
# setting cannot turn the heartbeat loop into a busy loop.
HEARTBEAT_INTERVAL = max(1, int(os.environ.get("HEARTBEAT_INTERVAL", 30)))

# Root of the control node's REST API. Trailing slashes on CONTROL_URL are
# dropped so the result never contains "//api/v1".
API = f"{CONTROL_URL.rstrip('/')}/api/v1"
# JSON file in which the node ID obtained at enrollment is persisted.
STATE_FILE = "/tmp/mac_worker_state.json"
|
|
|
|
|
def get_local_ip() -> str:
    """Return the local IPv4 address used for outbound traffic.

    A UDP socket is connected towards 8.8.8.8:80 and the address the OS bound
    it to is read back with ``getsockname()``.

    Returns:
        The local address as a dotted-quad string, or ``"127.0.0.1"`` when the
        socket cannot be created or connected (for example, no route).
    """
    try:
        # The context manager closes the socket even if connect() fails,
        # so no file descriptor is leaked on the error path.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
    except OSError:
        return "127.0.0.1"
|
|
|
|
|
| def load_state():
|
| """Load saved node ID from previous enrollment."""
|
| try:
|
| with open(STATE_FILE, "r") as f:
|
| return json.load(f)
|
| except (FileNotFoundError, json.JSONDecodeError):
|
| return {}
|
|
|
|
|
| def save_state(data):
|
| """Save enrollment state."""
|
| with open(STATE_FILE, "w") as f:
|
| json.dump(data, f)
|
|
|
|
|
def get_resource_metrics(cpu_sample_interval: float = 1) -> dict:
    """Collect current resource utilisation metrics.

    CPU and RAM figures come from ``psutil`` when it is importable; GPU
    figures come from ``nvidia-smi``. Any source that is unavailable leaves
    its fields at zero.

    Args:
        cpu_sample_interval: Passed to ``psutil.cpu_percent(interval=...)``.
            The default of 1 matches the previous hard-coded value.

    Returns:
        A dict with the keys ``cpu_util_pct``, ``ram_used_mb``,
        ``gpu_util_pct`` and ``gpu_vram_used_mb``.
    """
    metrics = {
        "cpu_util_pct": 0.0,
        "ram_used_mb": 0,
        "gpu_util_pct": 0.0,
        "gpu_vram_used_mb": 0,
    }

    try:
        import psutil  # optional dependency
    except ImportError:
        pass
    else:
        metrics["cpu_util_pct"] = psutil.cpu_percent(interval=cpu_sample_interval)
        metrics["ram_used_mb"] = int(psutil.virtual_memory().used / 1024 / 1024)

    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=utilization.gpu,memory.used",
                "--format=csv,noheader,nounits",
            ],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.TimeoutExpired):
        # OSError covers a missing or non-executable nvidia-smi binary.
        return metrics

    if result.returncode != 0:
        return metrics

    # nvidia-smi prints one CSV row per GPU. Only the first row is reported;
    # splitting the whole output on "," would merge rows on multi-GPU hosts.
    rows = result.stdout.strip().splitlines()
    fields = rows[0].split(",") if rows else []
    if len(fields) >= 2:
        try:
            gpu_util = float(fields[0].strip())
            vram_used = int(float(fields[1].strip()))
        except ValueError:
            # e.g. "[N/A]" — keep both GPU fields at their defaults.
            pass
        else:
            # Assign only after both values parsed, so the result is never
            # half-updated.
            metrics["gpu_util_pct"] = gpu_util
            metrics["gpu_vram_used_mb"] = vram_used

    return metrics
|
|
|
|
|
async def enroll(client: httpx.AsyncClient) -> str | None:
    """Enroll this node with the control server.

    If a node ID was saved by a previous enrollment it is returned without
    contacting the server. Otherwise the node description is POSTed to
    ``{API}/nodes/enroll`` and the returned ID is persisted via
    ``save_state``.

    Args:
        client: HTTP client used for the enrollment request.

    Returns:
        The node ID, or ``None`` when no token is configured, the request
        fails, the server does not answer 200, or the response carries no
        usable ``id``.
    """
    state = load_state()
    saved_id = state.get("node_id") if isinstance(state, dict) else None
    if saved_id:
        print(f"[AGENT] Already enrolled as node {saved_id}")
        return saved_id

    if not ENROLLMENT_TOKEN:
        print("[AGENT] ERROR: No ENROLLMENT_TOKEN set. Cannot enroll.")
        return None

    payload = {
        "enrollment_token": ENROLLMENT_TOKEN,
        "name": NODE_NAME,
        "hostname": socket.gethostname(),
        "ip_address": get_local_ip(),
        "port": VLLM_PORT,
        "gpu_name": GPU_NAME,
        "gpu_vram_mb": GPU_VRAM_MB,
        "ram_total_mb": RAM_TOTAL_MB,
        "cpu_cores": CPU_CORES,
    }

    try:
        resp = await client.post(f"{API}/nodes/enroll", json=payload)
    except httpx.RequestError as e:
        print(f"[AGENT] Connection error during enrollment: {e}")
        return None

    if resp.status_code != 200:
        print(f"[AGENT] Enrollment failed: {resp.status_code} {resp.text}")
        return None

    try:
        data = resp.json()
    except ValueError:
        # A 200 with a non-JSON body must not take the agent down.
        print(f"[AGENT] Enrollment failed: response is not valid JSON: {resp.text}")
        return None

    node_id = data.get("id") if isinstance(data, dict) else None
    if not node_id:
        # Never persist or announce an enrollment that has no node ID.
        print(f"[AGENT] Enrollment failed: no node id in response: {resp.text}")
        return None

    try:
        save_state({"node_id": node_id, "name": NODE_NAME})
    except OSError as e:
        # The enrollment itself succeeded; only persistence failed, so keep
        # going with the ID for this run.
        print(f"[AGENT] WARNING: could not save enrollment state: {e}")

    print(f"[AGENT] Enrolled successfully! Node ID: {node_id}")
    return node_id
|
|
|
|
|
async def heartbeat_loop(client: httpx.AsyncClient, node_id: str):
    """Send periodic heartbeats with resource metrics.

    Every ``HEARTBEAT_INTERVAL`` seconds the current metrics are POSTed to
    ``{API}/nodes/heartbeat/{node_id}``. After 10 consecutive failures the
    loop waits 60 seconds and resets the failure counter.

    Args:
        client: HTTP client used for the heartbeat requests.
        node_id: ID of this node as returned by enrollment.

    Returns:
        ``None``, and only when the control node answers 404; the saved
        enrollment state is cleared first so the caller can re-enroll.
    """
    consecutive_failures = 0
    max_failures = 10

    while True:
        try:
            # Metric collection sleeps and runs a subprocess; run it in a
            # worker thread so the event loop is not blocked meanwhile.
            metrics = await asyncio.to_thread(get_resource_metrics)
            resp = await client.post(
                f"{API}/nodes/heartbeat/{node_id}",
                json=metrics,
            )
        except httpx.RequestError as e:
            consecutive_failures += 1
            print(f"[AGENT] Heartbeat connection error: {e}")
        else:
            if resp.status_code == 200:
                consecutive_failures = 0
                try:
                    data = resp.json()
                except ValueError:
                    # The heartbeat was accepted; an unparsable body only
                    # means there are no warnings to show.
                    data = {}
                warnings = data.get("warnings", []) if isinstance(data, dict) else []
                if warnings:
                    print(f"[AGENT] Resource warnings: {warnings}")
            elif resp.status_code == 404:
                print("[AGENT] Node not found — re-enrollment needed")
                save_state({})
                return
            else:
                consecutive_failures += 1
                print(f"[AGENT] Heartbeat failed: {resp.status_code}")

        if consecutive_failures >= max_failures:
            print(f"[AGENT] {max_failures} consecutive failures. Waiting 60s before retry.")
            await asyncio.sleep(60)
            consecutive_failures = 0
        else:
            await asyncio.sleep(HEARTBEAT_INTERVAL)
|
|
|
|
|
async def wait_for_vllm():
    """Poll the local vLLM ``/health`` endpoint until it answers 200.

    Up to 120 requests are made, with a 5 second pause after each
    unsuccessful one. Connection errors count as "not ready yet".

    Returns:
        ``True`` as soon as a 200 response is seen, ``False`` if every
        attempt was used up.
    """
    print(f"[AGENT] Waiting for vLLM on port {VLLM_PORT}...")
    attempts_left = 120
    async with httpx.AsyncClient(timeout=5) as probe:
        while attempts_left > 0:
            attempts_left -= 1
            try:
                reply = await probe.get(f"http://localhost:{VLLM_PORT}/health")
                ready = reply.status_code == 200
            except httpx.RequestError:
                ready = False
            if ready:
                print("[AGENT] vLLM is ready!")
                return True
            await asyncio.sleep(5)
    print("[AGENT] WARNING: vLLM did not become ready in time")
    return False
|
|
|
|
|
async def detect_vllm_model() -> str:
    """Query the local vLLM server for the model it is serving.

    Reads ``/v1/models`` on ``localhost:VLLM_PORT`` and takes the ``id`` of
    the first entry under ``data``.

    Returns:
        The detected model name, or the configured ``VLLM_MODEL`` (possibly
        an empty string) when the query fails or yields no usable ``id``.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(f"http://localhost:{VLLM_PORT}/v1/models")
            if resp.status_code == 200:
                models = resp.json().get("data", [])
                if models:
                    detected = models[0].get("id")
                    # An empty or missing id falls through to VLLM_MODEL
                    # instead of being returned as "".
                    if detected:
                        return detected
    except Exception as e:
        # Detection is best-effort: report the problem and fall back to the
        # configured model rather than failing.
        print(f"[AGENT] Could not detect model from vLLM: {e}")
    return VLLM_MODEL
|
|
|
|
|
| def _model_id_from_served_name(served_name: str) -> str:
|
| """Map HuggingFace model name to MAC model_id."""
|
| mapping = {
|
| "Qwen/Qwen2.5-7B-Instruct-AWQ": "qwen2.5:7b",
|
| "Qwen/Qwen2.5-7B-Instruct": "qwen2.5:7b",
|
| "Qwen/Qwen2.5-Coder-7B-Instruct-AWQ": "qwen2.5-coder:7b",
|
| "Qwen/Qwen2.5-Coder-7B-Instruct": "qwen2.5-coder:7b",
|
| "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": "deepseek-r1:14b",
|
| "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B": "deepseek-r1:7b",
|
| "google/gemma-3-27b-it": "gemma3:27b",
|
| }
|
| return mapping.get(served_name, served_name)
|
|
|
|
|
async def register_model(client: httpx.AsyncClient, node_id: str):
    """Register the served model with the control node.

    The served model name is obtained from ``detect_vllm_model``; when that
    yields nothing, registration is skipped. Otherwise the model is POSTed to
    ``{API}/nodes/register-model/{node_id}``. Failures are logged, not raised.

    Args:
        client: HTTP client used for the registration request.
        node_id: ID of this node as returned by enrollment.
    """
    served_name = await detect_vllm_model()
    if not served_name:
        print("[AGENT] Could not detect model — skipping registration")
        return

    model_id = _model_id_from_served_name(served_name)
    print(f"[AGENT] Registering model: {model_id} ({served_name})")

    body = {
        "model_id": model_id,
        "served_name": served_name,
        # Part after the last "/", or the whole name when there is none.
        "model_name": served_name.rsplit("/", 1)[-1],
        "vllm_port": VLLM_PORT,
    }

    try:
        resp = await client.post(f"{API}/nodes/register-model/{node_id}", json=body)
    except httpx.RequestError as e:
        print(f"[AGENT] Model registration error: {e}")
        return

    if resp.status_code != 200:
        print(f"[AGENT] Model registration failed: {resp.status_code} {resp.text}")
        return

    try:
        details = resp.json()
    except ValueError:
        # Registration succeeded; fall back to the raw body for the log line.
        details = resp.text
    print(f"[AGENT] Model registered: {details}")
|
|
|
|
|
async def _enroll_until_success(client: httpx.AsyncClient) -> str:
    """Call ``enroll`` every 30 seconds until it returns a node ID."""
    while True:
        node_id = await enroll(client)
        if node_id:
            return node_id
        print("[AGENT] Retrying enrollment in 30s...")
        await asyncio.sleep(30)


async def main():
    """Run the agent: wait for vLLM, enroll, register the model, heartbeat.

    Whenever ``heartbeat_loop`` returns, the node is enrolled and its model
    registered again before heartbeats resume. This coroutine does not return
    on its own.
    """
    print(f"[AGENT] MAC Worker Agent starting — {NODE_NAME}")
    print(f"[AGENT] Control node: {CONTROL_URL}")
    if VLLM_MODEL:
        print(f"[AGENT] Configured model: {VLLM_MODEL}")

    # The result is not checked: the agent continues even if vLLM never
    # reported healthy.
    await wait_for_vllm()

    async with httpx.AsyncClient(timeout=30) as client:
        node_id = await _enroll_until_success(client)
        await register_model(client, node_id)

        print(f"[AGENT] Starting heartbeat loop (interval: {HEARTBEAT_INTERVAL}s)")
        while True:
            await heartbeat_loop(client, node_id)

            print("[AGENT] Heartbeat loop exited. Re-enrolling...")
            node_id = await _enroll_until_success(client)
            await register_model(client, node_id)
|
|
|
|
|
# Script entry point: run the agent until interrupted, then exit with status 0.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("[AGENT] Shutting down...")
        raise SystemExit(0)
|
|
|