RamMAC / worker-packages /pc2-coder /worker-agent.py
Aaryan17's picture
feat: upload full MAC source (mac/, frontend/, alembic/, tests/)
9c0b225 verified
#!/usr/bin/env python3
"""
MAC Worker Agent — Enrolls with control node and sends periodic heartbeats.
Runs as a sidecar container alongside vLLM on each GPU worker PC.
"""
import asyncio
import json
import os
import socket
import sys
import time
import httpx
# --- Runtime configuration; every value except STATE_FILE can be overridden
# --- through an environment variable of the name shown.

# Control node base URL and the token presented when enrolling.
CONTROL_URL = os.getenv("CONTROL_NODE_URL", "http://192.168.1.100:8000")
ENROLLMENT_TOKEN = os.getenv("ENROLLMENT_TOKEN", "")

# Name reported at enrollment; defaults to "worker-<hostname>".
NODE_NAME = os.getenv("NODE_NAME", "worker-" + socket.gethostname())

# Local vLLM server: port that is probed and advertised, and the configured
# model name used as a fallback when the served model cannot be detected.
VLLM_PORT = int(os.getenv("VLLM_PORT", "8001"))
VLLM_MODEL = os.getenv("VLLM_MODEL", "")

# Hardware description sent in the enrollment payload.
GPU_NAME = os.getenv("GPU_NAME", "NVIDIA GPU")
GPU_VRAM_MB = int(os.getenv("GPU_VRAM_MB", "12288"))
RAM_TOTAL_MB = int(os.getenv("RAM_TOTAL_MB", "16384"))
CPU_CORES = int(os.getenv("CPU_CORES", "8"))

# Seconds to sleep between heartbeats.
HEARTBEAT_INTERVAL = int(os.getenv("HEARTBEAT_INTERVAL", "30"))

# Prefix for every control-node endpoint used by this agent.
API = CONTROL_URL + "/api/v1"

# Where the enrollment state (node id and name) is persisted between runs.
STATE_FILE = "/tmp/mac_worker_state.json"
def get_local_ip():
    """Return the local IPv4 address used for outbound traffic.

    A UDP socket is "connected" to 8.8.8.8:80 purely so the OS picks the
    outgoing interface; the address bound to that socket is then read back.

    Returns:
        The dotted-quad address as a string, or "127.0.0.1" when the socket
        cannot be created or connected (for example, no route is available).
    """
    try:
        # The context manager closes the socket even when connect() or
        # getsockname() raises, so no descriptor is leaked on failure.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
    except OSError:
        return "127.0.0.1"
def load_state():
    """Load the enrollment state saved by a previous run.

    Returns:
        The dict stored in STATE_FILE, or an empty dict when the file is
        missing, is not valid JSON/UTF-8, or holds a JSON value that is not
        an object (callers use ``.get`` on the result).
    """
    try:
        with open(STATE_FILE, "r", encoding="utf-8") as f:
            state = json.load(f)
    except (FileNotFoundError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    # A file containing e.g. `null` or `[]` parses fine but is not usable state.
    return state if isinstance(state, dict) else {}
def save_state(data):
    """Persist enrollment state to STATE_FILE as JSON.

    Args:
        data: JSON-serializable value (the agent stores a dict).

    Raises:
        TypeError / ValueError: if ``data`` cannot be serialized. The
            existing file is left untouched in that case.
        OSError: if the file cannot be opened or written.
    """
    # Serialize before opening: open(..., "w") truncates immediately, so a
    # serialization error after that point would wipe the previous state.
    serialized = json.dumps(data)
    with open(STATE_FILE, "w", encoding="utf-8") as f:
        f.write(serialized)
def get_resource_metrics():
    """Collect current CPU, RAM and GPU utilization.

    CPU and RAM figures come from ``psutil`` when it is installed; GPU
    figures come from ``nvidia-smi`` when it is available. Any source that is
    missing or fails leaves its fields at zero, so the function is
    best-effort.

    Note: this call blocks (a 1 s CPU sample plus a subprocess with a 5 s
    timeout).

    Returns:
        dict with keys ``cpu_util_pct``, ``ram_used_mb``, ``gpu_util_pct``
        and ``gpu_vram_used_mb``.
    """
    metrics = {
        "cpu_util_pct": 0.0,
        "ram_used_mb": 0,
        "gpu_util_pct": 0.0,
        "gpu_vram_used_mb": 0,
    }

    # psutil is optional; without it CPU/RAM stay at their zero defaults.
    try:
        import psutil
    except ImportError:
        psutil = None
    if psutil is not None:
        metrics["cpu_util_pct"] = psutil.cpu_percent(interval=1)
        metrics["ram_used_mb"] = int(psutil.virtual_memory().used / 1024 / 1024)

    # Try nvidia-smi for GPU metrics
    import subprocess
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=utilization.gpu,memory.used",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=5
        )
    except (OSError, subprocess.TimeoutExpired):
        # OSError covers a missing binary as well as e.g. permission errors.
        return metrics
    if result.returncode != 0:
        return metrics

    # nvidia-smi prints one CSV line per GPU. Splitting the whole output on
    # commas would merge fields across lines on multi-GPU hosts, so only the
    # first non-empty line (the first GPU listed) is parsed.
    gpu_lines = [line for line in result.stdout.splitlines() if line.strip()]
    if not gpu_lines:
        return metrics
    fields = gpu_lines[0].split(",")
    if len(fields) >= 2:
        try:
            gpu_util = float(fields[0].strip())
            vram_used = int(float(fields[1].strip()))
        except ValueError:
            # Non-numeric field: keep the zero defaults.
            pass
        else:
            metrics["gpu_util_pct"] = gpu_util
            metrics["gpu_vram_used_mb"] = vram_used
    return metrics
async def enroll(client: httpx.AsyncClient) -> str | None:
    """Enroll this node with the control server.

    If a node id is already stored in the state file it is reused without
    contacting the server. Otherwise the node's description is POSTed to
    ``{API}/nodes/enroll`` and the returned id is persisted.

    Args:
        client: HTTP client used for the enrollment request.

    Returns:
        The node id, or None when no token is configured, the request fails,
        the server does not answer 200, or the response carries no id.
    """
    state = load_state()
    if state.get("node_id"):
        print(f"[AGENT] Already enrolled as node {state['node_id']}")
        return state["node_id"]

    if not ENROLLMENT_TOKEN:
        print("[AGENT] ERROR: No ENROLLMENT_TOKEN set. Cannot enroll.")
        return None

    payload = {
        "enrollment_token": ENROLLMENT_TOKEN,
        "name": NODE_NAME,
        "hostname": socket.gethostname(),
        "ip_address": get_local_ip(),
        "port": VLLM_PORT,
        "gpu_name": GPU_NAME,
        "gpu_vram_mb": GPU_VRAM_MB,
        "ram_total_mb": RAM_TOTAL_MB,
        "cpu_cores": CPU_CORES,
    }

    try:
        resp = await client.post(f"{API}/nodes/enroll", json=payload)
    except httpx.RequestError as e:
        print(f"[AGENT] Connection error during enrollment: {e}")
        return None

    if resp.status_code != 200:
        print(f"[AGENT] Enrollment failed: {resp.status_code} {resp.text}")
        return None

    # A 200 with a non-JSON body or without an "id" is not a usable
    # enrollment; report it instead of raising or storing a null id.
    try:
        data = resp.json()
    except ValueError:
        data = None
    node_id = data.get("id") if isinstance(data, dict) else None
    if not node_id:
        print(f"[AGENT] Enrollment failed: response has no node id: {resp.text}")
        return None

    # The server has accepted the node; failing to persist the id only means
    # a later restart will enroll again, so warn rather than crash.
    try:
        save_state({"node_id": node_id, "name": NODE_NAME})
    except OSError as e:
        print(f"[AGENT] WARNING: could not save enrollment state: {e}")

    print(f"[AGENT] Enrolled successfully! Node ID: {node_id}")
    return node_id
async def heartbeat_loop(client: httpx.AsyncClient, node_id: str):
    """Send periodic heartbeats with resource metrics.

    Each iteration collects metrics and POSTs them to
    ``{API}/nodes/heartbeat/{node_id}``, then sleeps HEARTBEAT_INTERVAL
    seconds. After 10 consecutive failures it sleeps 60 s instead and resets
    the failure counter.

    Args:
        client: HTTP client used for the heartbeat requests.
        node_id: Id returned by enrollment.

    Returns:
        Only when the server answers 404 for this node; the saved state is
        cleared first so the caller can enroll again.
    """
    consecutive_failures = 0
    max_failures = 10

    while True:
        try:
            # Metric collection blocks (1 s CPU sample + subprocess), so run
            # it in a worker thread instead of stalling the event loop.
            metrics = await asyncio.to_thread(get_resource_metrics)
            resp = await client.post(
                f"{API}/nodes/heartbeat/{node_id}",
                json=metrics
            )
        except httpx.RequestError as e:
            consecutive_failures += 1
            print(f"[AGENT] Heartbeat connection error: {e}")
        else:
            if resp.status_code == 200:
                consecutive_failures = 0
                # The heartbeat was accepted; a missing or malformed body
                # just means there are no warnings to show.
                try:
                    data = resp.json()
                except ValueError:
                    data = None
                warnings = data.get("warnings", []) if isinstance(data, dict) else []
                if warnings:
                    print(f"[AGENT] Resource warnings: {warnings}")
            elif resp.status_code == 404:
                print("[AGENT] Node not found — re-enrollment needed")
                save_state({})
                return
            else:
                consecutive_failures += 1
                print(f"[AGENT] Heartbeat failed: {resp.status_code}")

        if consecutive_failures >= max_failures:
            print(f"[AGENT] {max_failures} consecutive failures. Waiting 60s before retry.")
            await asyncio.sleep(60)
            consecutive_failures = 0
        else:
            await asyncio.sleep(HEARTBEAT_INTERVAL)
async def wait_for_vllm():
    """Poll the local vLLM ``/health`` endpoint until it answers 200.

    Makes at most 120 probes with a 5 s request timeout, sleeping 5 s after
    every unsuccessful one (about 10 minutes of sleeping, plus whatever time
    the probes themselves take).

    Returns:
        True as soon as a probe returns HTTP 200, False if none did.
    """
    health_url = f"http://localhost:{VLLM_PORT}/health"
    print(f"[AGENT] Waiting for vLLM on port {VLLM_PORT}...")

    async with httpx.AsyncClient(timeout=5) as probe:
        probes_left = 120
        while probes_left > 0:
            probes_left -= 1
            try:
                reply = await probe.get(health_url)
            except httpx.RequestError:
                # Server not reachable yet; treat like a non-200 answer.
                reply = None
            if reply is not None and reply.status_code == 200:
                print("[AGENT] vLLM is ready!")
                return True
            await asyncio.sleep(5)

    print("[AGENT] WARNING: vLLM did not become ready in time")
    return False
async def detect_vllm_model():
    """Query the local vLLM server for the model it is serving.

    Reads ``/v1/models`` and takes the ``id`` of the first entry in the
    response's ``data`` list.

    Returns:
        The served model id, or the configured VLLM_MODEL when the request
        fails, the response is not 200, or no non-empty id can be found.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(f"http://localhost:{VLLM_PORT}/v1/models")
            if resp.status_code == 200:
                data = resp.json()
                models = data.get("data") if isinstance(data, dict) else None
                if isinstance(models, list) and models and isinstance(models[0], dict):
                    served = models[0].get("id")
                    # An empty id is useless to the caller; fall through to
                    # the configured model instead of returning "".
                    if served:
                        return served
    except (httpx.HTTPError, ValueError) as e:
        # ValueError covers a body that is not valid JSON.
        print(f"[AGENT] Could not query vLLM for its model: {e}")
    return VLLM_MODEL
def _model_id_from_served_name(served_name: str) -> str:
    """Translate a HuggingFace model name into its MAC model_id.

    Names that are not in the table are returned unchanged.
    """
    # (MAC model_id, HuggingFace names that map to it)
    aliases = (
        ("qwen2.5:7b", (
            "Qwen/Qwen2.5-7B-Instruct-AWQ",
            "Qwen/Qwen2.5-7B-Instruct",
        )),
        ("qwen2.5-coder:7b", (
            "Qwen/Qwen2.5-Coder-7B-Instruct-AWQ",
            "Qwen/Qwen2.5-Coder-7B-Instruct",
        )),
        ("deepseek-r1:14b", ("deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",)),
        ("deepseek-r1:7b", ("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",)),
        ("gemma3:27b", ("google/gemma-3-27b-it",)),
    )
    for mac_id, hf_names in aliases:
        if served_name in hf_names:
            return mac_id
    return served_name
async def register_model(client: httpx.AsyncClient, node_id: str):
    """Register the served model with the control node.

    Detects the model served by the local vLLM instance, maps it to a MAC
    model id and POSTs it to ``{API}/nodes/register-model/{node_id}``.
    Failures are logged, never raised for connection errors or non-200
    answers.

    Args:
        client: HTTP client used for the registration request.
        node_id: Id returned by enrollment.
    """
    served_name = await detect_vllm_model()
    if not served_name:
        print("[AGENT] Could not detect model — skipping registration")
        return

    model_id = _model_id_from_served_name(served_name)
    print(f"[AGENT] Registering model: {model_id} ({served_name})")

    body = {
        "model_id": model_id,
        "served_name": served_name,
        # Last path component, e.g. "org/name" -> "name"; names without a
        # slash are passed through unchanged.
        "model_name": served_name.rsplit("/", 1)[-1],
        "vllm_port": VLLM_PORT,
    }

    try:
        resp = await client.post(
            f"{API}/nodes/register-model/{node_id}",
            json=body
        )
    except httpx.RequestError as e:
        print(f"[AGENT] Model registration error: {e}")
        return

    if resp.status_code != 200:
        print(f"[AGENT] Model registration failed: {resp.status_code} {resp.text}")
        return

    # The body is only logged; fall back to the raw text so a non-JSON
    # success response cannot raise here.
    try:
        detail = resp.json()
    except ValueError:
        detail = resp.text
    print(f"[AGENT] Model registered: {detail}")
async def _enroll_until_success(client: httpx.AsyncClient) -> str:
    """Call enroll() every 30 s until it returns a node id."""
    while True:
        node_id = await enroll(client)
        if node_id:
            return node_id
        print("[AGENT] Retrying enrollment in 30s...")
        await asyncio.sleep(30)


async def main():
    """Run the agent: wait for vLLM, then enroll, register and heartbeat.

    The enroll -> register model -> heartbeat cycle repeats forever;
    heartbeat_loop() only returns when the node must be enrolled again.
    """
    print(f"[AGENT] MAC Worker Agent starting — {NODE_NAME}")
    print(f"[AGENT] Control node: {CONTROL_URL}")
    if VLLM_MODEL:
        print(f"[AGENT] Configured model: {VLLM_MODEL}")

    # The result is intentionally not checked: the agent proceeds to enroll
    # even if vLLM never reported healthy.
    await wait_for_vllm()

    async with httpx.AsyncClient(timeout=30) as client:
        while True:
            node_id = await _enroll_until_success(client)

            # Register model with control node
            await register_model(client, node_id)

            print(f"[AGENT] Starting heartbeat loop (interval: {HEARTBEAT_INTERVAL}s)")
            await heartbeat_loop(client, node_id)

            # If heartbeat loop exits (node not found), re-enroll
            print("[AGENT] Heartbeat loop exited. Re-enrolling...")
# Script entry point: run the agent until it is interrupted.
if __name__ == "__main__":
    interrupted = False
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        interrupted = True
    if interrupted:
        # Ctrl+C is a normal way to stop the agent, so exit with status 0.
        print("[AGENT] Shutting down...")
        sys.exit(0)