RamMAC / worker-agent.py
Aaryan17's picture
feat: upload full MAC source (mac/, frontend/, alembic/, tests/)
9c0b225 verified
#!/usr/bin/env python3
"""
MAC Worker Agent — Enrolls with control node and sends periodic heartbeats.
Runs as a sidecar container alongside vLLM on each GPU worker PC.
"""
import asyncio
import json
import os
import socket
import sys
import time
import httpx
# --- Configuration (every value can be overridden via environment) ---------
# Base URL of the MAC control node. A trailing slash is stripped so the
# derived API prefix never contains "//" (e.g. "http://host:8000//api/v1").
CONTROL_URL = os.environ.get(
    "CONTROL_NODE_URL", "http://192.168.1.100:8000"
).rstrip("/")
# Token sent to /nodes/enroll; when empty, enroll() refuses to enroll.
ENROLLMENT_TOKEN = os.environ.get("ENROLLMENT_TOKEN", "")
NODE_NAME = os.environ.get("NODE_NAME", f"worker-{socket.gethostname()}")
# Port of the local vLLM server this agent runs beside.
VLLM_PORT = int(os.environ.get("VLLM_PORT", 8001))
# Fallback model name when vLLM's /v1/models cannot be queried.
VLLM_MODEL = os.environ.get("VLLM_MODEL", "")
# Static hardware description reported at enrollment (not auto-detected).
GPU_NAME = os.environ.get("GPU_NAME", "NVIDIA GPU")
GPU_VRAM_MB = int(os.environ.get("GPU_VRAM_MB", 12288))
RAM_TOTAL_MB = int(os.environ.get("RAM_TOTAL_MB", 16384))
CPU_CORES = int(os.environ.get("CPU_CORES", 8))
# Seconds between heartbeats (also the tick of main()'s supervision loop).
HEARTBEAT_INTERVAL = int(os.environ.get("HEARTBEAT_INTERVAL", 30))
API = f"{CONTROL_URL}/api/v1"
# File holding the enrolled node ID across restarts. The /tmp default may not
# survive container recreation; set WORKER_STATE_FILE to a path on a
# persistent volume to keep the same node identity.
STATE_FILE = os.environ.get("WORKER_STATE_FILE", "/tmp/mac_worker_state.json")
def get_local_ip() -> str:
    """Return this host's outbound IPv4 address as seen on the network.

    Connecting a UDP socket sends no packets; it only makes the kernel pick
    the source address it would use to reach 8.8.8.8. If that lookup fails
    (e.g. no route), ``"127.0.0.1"`` is returned instead.
    """
    try:
        # The context manager closes the socket even when connect() raises;
        # a manual close() after connect() would be skipped in that case.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"
def load_state():
"""Load saved node ID from previous enrollment."""
try:
with open(STATE_FILE, "r") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
return {}
def save_state(data):
"""Save enrollment state."""
with open(STATE_FILE, "w") as f:
json.dump(data, f)
def get_resource_metrics() -> dict:
    """Collect a best-effort snapshot of current resource utilization.

    Returns a dict with ``cpu_util_pct``, ``ram_used_mb``, ``gpu_util_pct``
    and ``gpu_vram_used_mb``. Each source is optional: without ``psutil`` the
    CPU/RAM fields stay 0, and without a working ``nvidia-smi`` the GPU
    fields stay 0.

    This call blocks: ``psutil.cpu_percent(interval=1)`` sleeps ~1 s and
    ``nvidia-smi`` may take up to its 5 s timeout. Async callers should run
    it in a worker thread.
    """
    import subprocess

    metrics = {
        "cpu_util_pct": 0.0,
        "ram_used_mb": 0,
        "gpu_util_pct": 0.0,
        "gpu_vram_used_mb": 0,
    }
    try:
        import psutil
        metrics["cpu_util_pct"] = psutil.cpu_percent(interval=1)
        mem = psutil.virtual_memory()
        metrics["ram_used_mb"] = int(mem.used / 1024 / 1024)
    except ImportError:
        pass

    # GPU metrics via nvidia-smi.
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=utilization.gpu,memory.used",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=5,
        )
    except (OSError, subprocess.TimeoutExpired):
        # nvidia-smi missing (FileNotFoundError), not executable, or hung.
        return metrics
    if result.returncode != 0:
        return metrics

    # nvidia-smi prints one "util, mem" line per GPU. Only the first GPU is
    # reported, matching the single GPU_NAME/GPU_VRAM_MB sent at enrollment.
    # Splitting the whole multi-line output on "," would merge fields across
    # lines and make the parse fail on multi-GPU hosts.
    lines = result.stdout.strip().splitlines()
    if not lines:
        return metrics
    parts = lines[0].split(",")
    if len(parts) >= 2:
        try:
            util = float(parts[0].strip())
            vram = int(float(parts[1].strip()))
        except ValueError:
            # Non-numeric field (e.g. a "[N/A]" placeholder) — keep zeros.
            return metrics
        metrics["gpu_util_pct"] = util
        metrics["gpu_vram_used_mb"] = vram
    return metrics
async def enroll(client: httpx.AsyncClient) -> str | None:
    """Enroll this node with the control server and return its node ID.

    A node ID saved by a previous run is reused without contacting the
    server. Otherwise this node's hardware description is POSTed to
    ``{API}/nodes/enroll`` and the returned ``id`` is persisted.

    Returns:
        The node ID, or ``None`` when no ENROLLMENT_TOKEN is configured, the
        request fails, or the response carries no usable ``id``. main()
        retries on ``None``.
    """
    state = load_state()
    if state.get("node_id"):
        print(f"[AGENT] Already enrolled as node {state['node_id']}")
        return state["node_id"]
    if not ENROLLMENT_TOKEN:
        print("[AGENT] ERROR: No ENROLLMENT_TOKEN set. Cannot enroll.")
        return None

    payload = {
        "enrollment_token": ENROLLMENT_TOKEN,
        "name": NODE_NAME,
        "hostname": socket.gethostname(),
        "ip_address": get_local_ip(),
        "port": VLLM_PORT,
        "gpu_name": GPU_NAME,
        "gpu_vram_mb": GPU_VRAM_MB,
        "ram_total_mb": RAM_TOTAL_MB,
        "cpu_cores": CPU_CORES,
    }
    try:
        resp = await client.post(f"{API}/nodes/enroll", json=payload)
    except httpx.RequestError as e:
        print(f"[AGENT] Connection error during enrollment: {e}")
        return None

    if resp.status_code != 200:
        print(f"[AGENT] Enrollment failed: {resp.status_code} {resp.text}")
        return None
    try:
        data = resp.json()
    except ValueError:
        # A non-JSON body would otherwise raise out of enroll() and kill main().
        print(f"[AGENT] Enrollment response is not JSON: {resp.text[:200]}")
        return None

    node_id = data.get("id") if isinstance(data, dict) else None
    if not node_id:
        # Don't persist or report a missing ID as a successful enrollment.
        print(f"[AGENT] Enrollment response has no node id: {data!r}")
        return None
    save_state({"node_id": node_id, "name": NODE_NAME})
    print(f"[AGENT] Enrolled successfully! Node ID: {node_id}")
    return node_id
async def heartbeat_loop(client: httpx.AsyncClient, node_id: str):
    """Send a heartbeat with resource metrics every HEARTBEAT_INTERVAL seconds.

    Runs until the control node answers 404 (node unknown). At that point the
    saved state is cleared and the coroutine returns so the caller can
    re-enroll. Any other failed heartbeat is counted. After ``max_failures``
    consecutive failures the loop waits 60 s before resuming.
    """
    consecutive_failures = 0
    max_failures = 10
    while True:
        try:
            # get_resource_metrics() blocks (psutil sampling + nvidia-smi).
            # Running it in a thread keeps the shared event loop responsive.
            metrics = await asyncio.to_thread(get_resource_metrics)
            resp = await client.post(
                f"{API}/nodes/heartbeat/{node_id}",
                json=metrics,
            )
            if resp.status_code == 200:
                consecutive_failures = 0
                try:
                    data = resp.json()
                except ValueError:
                    # Heartbeat was accepted; a non-JSON body just has no warnings.
                    data = {}
                warnings = data.get("warnings", []) if isinstance(data, dict) else []
                if warnings:
                    print(f"[AGENT] Resource warnings: {warnings}")
            elif resp.status_code == 404:
                print("[AGENT] Node not found — re-enrollment needed")
                save_state({})
                return
            else:
                consecutive_failures += 1
                print(f"[AGENT] Heartbeat failed: {resp.status_code}")
        except httpx.RequestError as e:
            consecutive_failures += 1
            print(f"[AGENT] Heartbeat connection error: {e}")

        if consecutive_failures >= max_failures:
            print(f"[AGENT] {max_failures} consecutive failures. Waiting 60s before retry.")
            await asyncio.sleep(60)
            consecutive_failures = 0
        else:
            await asyncio.sleep(HEARTBEAT_INTERVAL)
async def wait_for_vllm(max_wait: float = 600.0, poll_interval: float = 5.0) -> bool:
    """Wait for the local vLLM server on VLLM_PORT to report healthy.

    Polls ``GET http://localhost:{VLLM_PORT}/health`` until it answers 200
    or ``max_wait`` seconds (default 10 minutes) have passed. At least one
    check is always made.

    Args:
        max_wait: Total time budget in seconds.
        poll_interval: Pause between checks in seconds.

    Returns:
        True once healthy, False on timeout.
    """
    print(f"[AGENT] Waiting for vLLM on port {VLLM_PORT}...")
    # A wall-clock deadline bounds the total wait. A fixed attempt count would
    # also add each request's own timeout (up to 5 s) to the budget.
    deadline = time.monotonic() + max_wait
    async with httpx.AsyncClient(timeout=5) as client:
        while True:
            try:
                resp = await client.get(f"http://localhost:{VLLM_PORT}/health")
                if resp.status_code == 200:
                    print("[AGENT] vLLM is ready!")
                    return True
            except httpx.RequestError:
                pass  # not listening yet
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            await asyncio.sleep(min(poll_interval, remaining))
    print("[AGENT] WARNING: vLLM did not become ready in time")
    return False
async def detect_vllm_model() -> str:
    """Return the name of the model the local vLLM server is serving.

    Queries ``http://localhost:{VLLM_PORT}/v1/models`` and returns the first
    entry's ``id``. Falls back to the configured ``VLLM_MODEL`` (possibly
    empty) when the query fails or yields no non-empty id.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(f"http://localhost:{VLLM_PORT}/v1/models")
            if resp.status_code == 200:
                models = resp.json().get("data", [])
                if models:
                    model_id = models[0].get("id", "")
                    if model_id:
                        return model_id
            else:
                print(f"[AGENT] /v1/models returned HTTP {resp.status_code}")
    except Exception as e:
        # Best-effort: any failure (network, bad JSON, unexpected shape)
        # falls back to VLLM_MODEL, but is logged rather than hidden.
        print(f"[AGENT] Could not query vLLM for its model: {e!r}")
    return VLLM_MODEL
# HuggingFace repository name (as served by vLLM) -> MAC catalogue model_id.
# AWQ-quantized and full-precision variants share one catalogue entry.
_MODEL_ID_BY_SERVED_NAME = {
    "Qwen/Qwen2.5-7B-Instruct-AWQ": "qwen2.5:7b",
    "Qwen/Qwen2.5-7B-Instruct": "qwen2.5:7b",
    "Qwen/Qwen2.5-Coder-7B-Instruct-AWQ": "qwen2.5-coder:7b",
    "Qwen/Qwen2.5-Coder-7B-Instruct": "qwen2.5-coder:7b",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": "deepseek-r1:14b",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B": "deepseek-r1:7b",
    "google/gemma-3-27b-it": "gemma3:27b",
}


def _model_id_from_served_name(served_name: str) -> str:
    """Translate a vLLM served model name into a MAC model_id.

    Names outside the known table (e.g. community-submitted models) are
    passed through unchanged and used as their own model_id.
    """
    try:
        return _MODEL_ID_BY_SERVED_NAME[served_name]
    except KeyError:
        return served_name
async def register_model(client: httpx.AsyncClient, node_id: str):
    """Tell the control node which model this worker's vLLM is serving.

    Detects the served model, maps it to a MAC model_id and POSTs it to
    ``{API}/nodes/register-model/{node_id}``. Every failure is logged and
    otherwise ignored, so the agent keeps running unregistered.
    """
    served_name = await detect_vllm_model()
    if not served_name:
        print("[AGENT] Could not detect model — skipping registration")
        return
    model_id = _model_id_from_served_name(served_name)
    print(f"[AGENT] Registering model: {model_id} ({served_name})")
    payload = {
        "model_id": model_id,
        "served_name": served_name,
        # "Org/Name" -> "Name"; a name without "/" is kept as-is.
        "model_name": served_name.split("/")[-1],
        "vllm_port": VLLM_PORT,
    }
    try:
        resp = await client.post(
            f"{API}/nodes/register-model/{node_id}",
            json=payload,
        )
    except httpx.RequestError as e:
        print(f"[AGENT] Model registration error: {e}")
        return

    if resp.status_code == 200:
        # resp.json() raises on a non-JSON body, which would escape to main();
        # fall back to the raw text for the log line.
        try:
            body = resp.json()
        except ValueError:
            body = resp.text
        print(f"[AGENT] Model registered: {body}")
    else:
        print(f"[AGENT] Model registration failed: {resp.status_code} {resp.text}")
# Seconds between polls for newly assigned model deployments. main() checks
# this in steps of HEARTBEAT_INTERVAL, so the effective period is rounded up
# to a multiple of it.
DEPLOY_POLL_INTERVAL = int(os.getenv("DEPLOY_POLL_INTERVAL", "60"))
def _start_vllm_process(model_id, port, gpu_mem_util, max_model_len) -> "subprocess.Popen":
    """Launch a detached vLLM OpenAI-compatible server and return its handle."""
    import subprocess

    cmd = [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model_id,
        "--port", str(port),
        "--gpu-memory-utilization", str(gpu_mem_util),
        "--max-model-len", str(max_model_len),
        "--host", "0.0.0.0",
    ]
    print(f"[AGENT] Starting vLLM: {' '.join(cmd)}")
    # The child inherits the log file descriptor, so the parent's copy is
    # closed as soon as Popen returns instead of leaking one per deployment.
    with open(f"/tmp/vllm_{port}.log", "w") as log_file:
        return subprocess.Popen(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # detach from the agent's session
        )


async def _wait_for_deployment_health(
    client: httpx.AsyncClient, port: int, proc, attempts: int = 120, interval: float = 5.0
) -> bool:
    """Poll localhost:{port}/health until it answers 200; give up early if *proc* exits."""
    for _ in range(attempts):
        if proc.poll() is not None:
            # The server died (bad model, OOM, ...); waiting longer is pointless.
            return False
        try:
            check = await client.get(f"http://localhost:{port}/health")
            if check.status_code == 200:
                return True
        except httpx.RequestError:
            pass  # not listening yet
        await asyncio.sleep(interval)
    return False


async def _report_deployment_status(
    client: httpx.AsyncClient, deployment_id, status: str, error_message: str | None = None
) -> None:
    """POST a deployment status update; log instead of raising on failure."""
    payload = {"status": status}
    if error_message is not None:
        payload["error_message"] = error_message[:500]
    try:
        resp = await client.post(
            f"{API}/nodes/deployment/{deployment_id}/status",
            json=payload,
        )
    except httpx.RequestError as e:
        print(f"[AGENT] Could not report status '{status}' for deployment {deployment_id}: {e}")
        return
    if resp.status_code >= 400:
        print(f"[AGENT] Status report for deployment {deployment_id} rejected: {resp.status_code}")


async def poll_pending_deployments(client: httpx.AsyncClient, node_id: str):
    """Poll the control node for model deployments assigned to this worker.

    For each pending entry, a separate vLLM server is started on the
    requested port. The agent then waits for its /health endpoint (up to 120
    checks 5 s apart, ~10 min) and reports "ready" or "failed" back. A
    server that is still running when it is declared failed is terminated,
    so it does not keep holding the port and GPU memory. Entries are handled
    one after another, which delays the caller but not the heartbeat task.
    Malformed responses or entries are logged and skipped and never raise.
    """
    try:
        resp = await client.get(f"{API}/nodes/pending-deployments/{node_id}")
    except httpx.RequestError as e:
        print(f"[AGENT] Deploy poll error: {e}")
        return
    if resp.status_code != 200:
        return
    try:
        pending = resp.json().get("pending", [])
    except (ValueError, AttributeError):
        print("[AGENT] Deploy poll returned an unexpected payload")
        return

    for deploy in pending or []:
        try:
            deployment_id = deploy["deployment_id"]
            model_id = deploy["model_id"]
            vllm_port = deploy["vllm_port"]
        except (KeyError, TypeError):
            # A single malformed entry must not abort polling (or the agent).
            print(f"[AGENT] Skipping malformed deployment entry: {deploy!r}")
            continue
        gpu_mem_util = deploy.get("gpu_memory_util", 0.85)
        max_model_len = deploy.get("max_model_len", 8192)
        print(f"[AGENT] New deployment: {model_id} on port {vllm_port}")

        proc = None
        try:
            proc = _start_vllm_process(model_id, vllm_port, gpu_mem_util, max_model_len)
            ready = await _wait_for_deployment_health(client, vllm_port, proc)
        except Exception as e:
            print(f"[AGENT] Error deploying {model_id}: {e}")
            if proc is not None and proc.poll() is None:
                proc.terminate()
            await _report_deployment_status(client, deployment_id, "failed", str(e))
            continue

        if ready:
            await _report_deployment_status(client, deployment_id, "ready")
            print(f"[AGENT] Deployment {model_id} ready on port {vllm_port}")
            continue

        if proc.poll() is None:
            proc.terminate()
            error = "vLLM did not start in time"
        else:
            error = f"vLLM exited with code {proc.returncode}"
        await _report_deployment_status(client, deployment_id, "failed", error)
        print(f"[AGENT] Deployment {model_id} FAILED — {error}")
async def _enroll_until_success(client: httpx.AsyncClient) -> str:
    """Call enroll() every 30 s until it returns a node ID, then return it."""
    while True:
        node_id = await enroll(client)
        if node_id:
            return node_id
        print("[AGENT] Retrying enrollment in 30s...")
        await asyncio.sleep(30)


async def main():
    """Agent entry point: wait for vLLM, enroll, register, then supervise.

    After start-up the loop restarts the heartbeat task whenever it ends
    (re-enrolling and re-registering first). It also polls for pending
    deployments roughly every DEPLOY_POLL_INTERVAL seconds.
    """
    print(f"[AGENT] MAC Worker Agent starting — {NODE_NAME}")
    print(f"[AGENT] Control node: {CONTROL_URL}")
    if VLLM_MODEL:
        print(f"[AGENT] Configured model: {VLLM_MODEL}")
    await wait_for_vllm()

    async with httpx.AsyncClient(timeout=30) as client:
        node_id = await _enroll_until_success(client)
        await register_model(client, node_id)

        print(f"[AGENT] Starting heartbeat loop (interval: {HEARTBEAT_INTERVAL}s)")
        print(f"[AGENT] Deploy poll interval: {DEPLOY_POLL_INTERVAL}s")
        heartbeat_task = asyncio.create_task(heartbeat_loop(client, node_id))
        deploy_poll_counter = 0
        while True:
            if heartbeat_task.done():
                # heartbeat_loop only returns on a 404 (node unknown). Anything
                # else is a crash; surface it instead of silently dropping it.
                if not heartbeat_task.cancelled() and heartbeat_task.exception() is not None:
                    print(f"[AGENT] Heartbeat loop crashed: {heartbeat_task.exception()!r}")
                print("[AGENT] Heartbeat loop exited. Re-enrolling...")
                node_id = await _enroll_until_success(client)
                await register_model(client, node_id)
                heartbeat_task = asyncio.create_task(heartbeat_loop(client, node_id))
                deploy_poll_counter = 0

            # Poll for new deployments every DEPLOY_POLL_INTERVAL seconds,
            # measured in ticks of HEARTBEAT_INTERVAL.
            deploy_poll_counter += HEARTBEAT_INTERVAL
            if deploy_poll_counter >= DEPLOY_POLL_INTERVAL:
                try:
                    await poll_pending_deployments(client, node_id)
                except Exception as e:
                    # Deployment polling is best-effort; it must not kill the agent.
                    print(f"[AGENT] Deploy poll failed: {e!r}")
                deploy_poll_counter = 0
            await asyncio.sleep(HEARTBEAT_INTERVAL)
def _run_agent() -> None:
    """Run main() to completion; on Ctrl-C, print a notice and exit with status 0."""
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("[AGENT] Shutting down...")
        sys.exit(0)


if __name__ == "__main__":
    _run_agent()