File size: 8,692 Bytes
0e76632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""
MAC Worker Agent
================
Run this on every worker PC to join it to the MAC cluster.

Usage:
    pip install httpx psutil pynvml
    python worker_agent.py

Environment variables (or edit the defaults in the Configuration section below):
    MAC_MASTER_URL      e.g. http://192.168.1.100:8000 (default: http://localhost:8000)
    MAC_ENROLL_TOKEN    one-time token from admin (/cluster/enroll-token); required
    MAC_WORKER_NAME     display name for this node (default: hostname)
    MAC_VLLM_PORT       port where vLLM is running (default: 8001)
    MAC_NOTEBOOK_PORT   port for Jupyter kernel gateway (optional; unset or 0 means none)
    MAC_TAGS            comma-separated: llm,notebook,embedding (default: llm)
    MAC_HEARTBEAT_SEC   heartbeat interval in seconds (default: 10)
    MAC_NODE_TOKEN      node token to use when the master returns none at registration
"""

import os
import sys
import time
import socket
import asyncio
import logging
import subprocess
from typing import Optional

import httpx

# Configure the root logger first (timestamped, INFO and above), then grab
# the named logger every function in this module writes to.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("mac-worker")

# ── Configuration ─────────────────────────────────────────────────────────────

# Every setting is read once at import time; a MAC_* environment variable
# overrides the built-in default shown as the second argument.
MASTER_URL: str = os.getenv("MAC_MASTER_URL", "http://localhost:8000").rstrip("/")
ENROLL_TOKEN: str = os.getenv("MAC_ENROLL_TOKEN", "")
WORKER_NAME: str = os.getenv("MAC_WORKER_NAME", socket.gethostname())
VLLM_PORT: int = int(os.getenv("MAC_VLLM_PORT", "8001"))
# 0 (the default) collapses to None, i.e. no notebook port is reported.
NOTEBOOK_PORT: Optional[int] = int(os.getenv("MAC_NOTEBOOK_PORT", "0")) or None
TAGS: str = os.getenv("MAC_TAGS", "llm")
HEARTBEAT_SEC: int = int(os.getenv("MAC_HEARTBEAT_SEC", "10"))

# Base URL shared by the master endpoints this agent calls.
API: str = f"{MASTER_URL}/api/v1"


# ── System metrics ────────────────────────────────────────────────────────────

def _local_ip() -> str:
    """Best-guess local LAN IP."""
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(("8.8.8.8", 80))
        ip = s.getsockname()[0]
        s.close()
        return ip
    except OSError:
        return "127.0.0.1"


def _gpu_metrics() -> dict:
    """Read GPU metrics via pynvml (optional β€” degrades gracefully if not installed)."""
    try:
        import pynvml
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
        mem  = pynvml.nvmlDeviceGetMemoryInfo(handle)
        name = pynvml.nvmlDeviceGetName(handle)
        if isinstance(name, bytes):
            name = name.decode()
        return {
            "gpu_name":        name,
            "gpu_vram_mb":     mem.total // (1024 * 1024),
            "gpu_util_pct":    float(util.gpu),
            "gpu_vram_used_mb": mem.used // (1024 * 1024),
        }
    except Exception:
        return {}


def _cpu_ram_metrics() -> dict:
    """Read CPU / RAM metrics via psutil (optional)."""
    try:
        import psutil
        cpu   = psutil.cpu_percent(interval=None)
        ram   = psutil.virtual_memory()
        cores = psutil.cpu_count(logical=False) or psutil.cpu_count()
        return {
            "cpu_util_pct":  float(cpu),
            "ram_total_mb":  ram.total // (1024 * 1024),
            "ram_used_mb":   ram.used  // (1024 * 1024),
            "cpu_cores":     cores,
        }
    except Exception:
        return {}


def _active_models(vllm_port: int) -> list[str]:
    """Query local vLLM /v1/models to find what's loaded."""
    try:
        import httpx as _httpx
        r = _httpx.get(f"http://localhost:{vllm_port}/v1/models", timeout=3)
        if r.status_code == 200:
            data = r.json().get("data", [])
            return [m["id"] for m in data]
    except Exception:
        pass
    return []


# ── Registration ──────────────────────────────────────────────────────────────

async def register(client: httpx.AsyncClient) -> tuple[str, str]:
    """Enroll this worker with the master.

    Args:
        client: Async HTTP client used to POST to ``{API}/cluster/register``.

    Returns:
        ``(node_id, node_token)``. The token comes from the response, or from
        ``_node_token_from_env()`` when the response carries none.

    Raises:
        SystemExit: If ``ENROLL_TOKEN`` is empty or the master replies 401.
        httpx.HTTPStatusError: For any other error status (``raise_for_status``).
    """
    if not ENROLL_TOKEN:
        log.error("MAC_ENROLL_TOKEN is not set. Get one from the admin panel.")
        sys.exit(1)

    gpu_info = _gpu_metrics()
    host_info = _cpu_ram_metrics()
    lan_ip = _local_ip()

    body = {
        "enrollment_token": ENROLL_TOKEN,
        "name": WORKER_NAME,
        "hostname": socket.gethostname(),
        "ip_address": lan_ip,
        "port": VLLM_PORT,
        "notebook_port": NOTEBOOK_PORT,
        "tags": TAGS,
    }
    # Hardware fields; a metric that could not be read is sent as None.
    for field in ("gpu_name", "gpu_vram_mb"):
        body[field] = gpu_info.get(field)
    for field in ("ram_total_mb", "cpu_cores"):
        body[field] = host_info.get(field)

    log.info(f"Registering with master at {API} as '{WORKER_NAME}' ({lan_ip}:{VLLM_PORT}) ...")
    response = await client.post(f"{API}/cluster/register", json=body, timeout=15)

    # 401 gets a dedicated message; every other error status raises below.
    if response.status_code == 401:
        log.error("Invalid or expired enrollment token. Generate a new one from the admin panel.")
        sys.exit(1)
    response.raise_for_status()

    reply = response.json()
    node_id = reply["node_id"]
    node_token = reply.get("node_token") or _node_token_from_env()
    status = reply["status"]

    log.info(f"Registered: node_id={node_id}  status={status}")
    if status == "pending":
        log.info("Waiting for admin to approve this node in the MAC admin panel ...")
    return node_id, node_token


def _node_token_from_env() -> str:
    """Fallback: read token from env (for re-registration after reboot)."""
    t = os.environ.get("MAC_NODE_TOKEN", "")
    if not t:
        log.error(
            "No node_token returned from master and MAC_NODE_TOKEN not set.\n"
            "On first registration the token is the sha256 of your enrollment token.\n"
            "Set MAC_NODE_TOKEN in your environment after the first run."
        )
        sys.exit(1)
    return t


# ── Heartbeat loop ────────────────────────────────────────────────────────────

async def heartbeat_loop(client: httpx.AsyncClient, node_id: str, node_token: str) -> None:
    """Send a heartbeat to the master every ``HEARTBEAT_SEC`` seconds, forever.

    Each heartbeat POSTs the current GPU/CPU/RAM usage and the models reported
    by the local vLLM to ``{API}/cluster/heartbeat``.

    Response handling:
        403  node not yet approved; wait ``3 * HEARTBEAT_SEC`` and retry.
        401  authentication failed; log an error and exit the process.
        other error status, or an unparsable body: logged as a warning.

    Args:
        client: Async HTTP client used for the POST.
        node_id: Identifier returned by registration.
        node_token: Token sent with every heartbeat.

    Raises:
        SystemExit: When the master answers 401.
    """
    consecutive_failures = 0
    # Consecutive transport failures after which warnings escalate to errors.
    failure_alert_threshold = 12

    while True:
        try:
            gpu = _gpu_metrics()
            sys_info = _cpu_ram_metrics()
            models = _active_models(VLLM_PORT)

            payload = {
                "node_id":        node_id,
                "node_token":     node_token,
                "active_models":  models,
                "queue_depth":    0,
                **{k: gpu.get(k) for k in ("gpu_util_pct", "gpu_vram_used_mb")},
                **{k: sys_info.get(k) for k in ("cpu_util_pct", "ram_used_mb")},
            }

            resp = await client.post(f"{API}/cluster/heartbeat", json=payload, timeout=10)

            if resp.status_code == 403:
                # Not yet approved — keep trying silently. The master did
                # answer, so this does not count as a connectivity failure.
                consecutive_failures = 0
                await asyncio.sleep(HEARTBEAT_SEC * 3)
                continue

            if resp.status_code == 401:
                log.error("Heartbeat auth failed — node_token may be wrong.")
                sys.exit(1)

            resp.raise_for_status()
            data = resp.json()
            if data.get("status") != "active":
                log.warning(f"Node status: {data.get('status')}")

            consecutive_failures = 0

        except httpx.RequestError as e:
            consecutive_failures += 1
            log.warning(f"Heartbeat failed ({consecutive_failures}): {e}")
            if consecutive_failures >= failure_alert_threshold:
                # Derive the outage length from the configured interval
                # instead of assuming the 10-second default.
                outage_sec = consecutive_failures * HEARTBEAT_SEC
                log.error(
                    f"Master unreachable for roughly {outage_sec}s "
                    f"({consecutive_failures} consecutive failed heartbeats). Check network."
                )

        except Exception as e:
            # Covers error statuses from raise_for_status() and bad response
            # bodies; the loop keeps running.
            log.warning(f"Heartbeat error: {e}")

        await asyncio.sleep(HEARTBEAT_SEC)


# ── Entry point ───────────────────────────────────────────────────────────────

async def main() -> None:
    """Run the agent: log the configuration, register, then heartbeat forever.

    A single ``httpx.AsyncClient`` is shared by registration and the
    heartbeat loop and is closed when this coroutine exits.
    """
    # The separator in the banner was a mis-encoded em dash; restored here.
    log.info(f"MAC Worker Agent starting — master={MASTER_URL}  name={WORKER_NAME}")
    log.info(f"vLLM port={VLLM_PORT}  notebook_port={NOTEBOOK_PORT}  tags={TAGS}")

    async with httpx.AsyncClient() as client:
        node_id, node_token = await register(client)
        log.info("Starting heartbeat loop ...")
        await heartbeat_loop(client, node_id, node_token)


# Script entry point: only runs when the file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    try:
        # Drive the async main() to completion on a fresh event loop.
        asyncio.run(main())
    except KeyboardInterrupt:
        # An interrupt ends the agent with a log line instead of a traceback.
        log.info("Worker agent stopped.")