OpenTransformer committed on
Commit
bafb727
·
verified ·
1 Parent(s): 35a16e0

Add distributed inference harness

Browse files
distributed/inference/agillm35_distributed_infer.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Distributed inference harness for the real AGILLM3.5 transformer blocks.
3
+
4
+ Phase 1 is exact full-sequence AR inference over pipeline stages. Each stage
5
+ owns a contiguous transformer/DiffusionBlock layer range and runs the actual
6
+ AGILLM3.5 Block implementation, including MoE FFNs when enabled by the
7
+ checkpoint config. The coordinator keeps embeddings, final norm, and AR head.
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
13
+ import importlib.util
14
+ import io
15
+ import json
16
+ import math
17
+ import os
18
+ from pathlib import Path
19
+ import shutil
20
+ import ssl
21
+ import struct
22
+ import sys
23
+ import time
24
+ from typing import Any
25
+ from urllib.parse import urlparse
26
+ from urllib.request import Request, urlopen
27
+
28
+
29
def load_agillm35(path: str | Path):
    """Import the AGILLM3.5 runtime source file at *path* and return the module.

    The module is registered in ``sys.modules`` under ``"agillm35_runtime"``
    (unless an entry with that name already exists) before its code runs.
    If executing the module fails and this call created the ``sys.modules``
    entry, the entry is removed again so that a half-initialised module is
    not left behind for later imports to pick up.

    Args:
        path: Location of the runtime ``.py`` file.

    Returns:
        The freshly executed module object.

    Raises:
        RuntimeError: If no import spec or loader can be built for *path*.
    """
    path = Path(path).resolve()
    spec = importlib.util.spec_from_file_location("agillm35_runtime", path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"cannot import AGILLM3.5 runtime from {path}")
    module = importlib.util.module_from_spec(spec)
    registered = sys.modules.setdefault("agillm35_runtime", module)
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Only undo the registration made by this call; a pre-existing
        # entry belongs to an earlier, successful load.
        if registered is module:
            sys.modules.pop("agillm35_runtime", None)
        raise
    return module
38
+
39
+
40
def torch_io():
    """Return the ``torch`` package, importing it only when first requested."""
    import torch as _torch

    return _torch
43
+
44
+
45
def resolve_device(name: str):
    """Turn a device request into a concrete device name.

    ``"auto"`` becomes ``"cuda"`` when CUDA is available and ``"cpu"``
    otherwise; any other value is returned unchanged.
    """
    torch = torch_io()
    if name != "auto":
        return name
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
50
+
51
+
52
def load_ckpt(runtime: Any, ckpt_path: str | Path) -> dict[str, Any]:
    """Load a checkpoint onto the CPU and normalise it to the full layout.

    If *ckpt_path* is not an existing file, ``runtime._resolve_ckpt`` is
    asked to resolve it; the original path is used when that returns a
    falsy value.

    Checkpoints with a truthy ``"delta"`` entry are rewritten in place:
    ``cfg`` is set to a copy of ``runtime.PRESETS["large"]``, ``tie_weights``
    is set to ``False`` and the ``core``/``ar``/``sat``/``nat`` entries are
    lifted out of ``sd["weights"]``.

    When the checkpoint embeds ``"tokenizer_json"``, the runtime tokenizer's
    backend is replaced on a best-effort basis. A failure there no longer
    passes silently: a warning is written to stderr and loading continues.

    Args:
        runtime: The imported AGILLM3.5 runtime module.
        ckpt_path: Checkpoint file, or a location the runtime can resolve.

    Returns:
        The checkpoint dictionary.
    """
    torch = torch_io()
    path = Path(ckpt_path)
    resolved = path if path.is_file() else (runtime._resolve_ckpt(path) or path)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from trusted sources.
    sd = torch.load(resolved, map_location="cpu", weights_only=False)
    if sd.get("delta"):
        weights = sd["weights"]
        sd["cfg"] = runtime.PRESETS["large"].copy()
        sd["tie_weights"] = False
        sd["core"] = weights["core"]
        sd["ar"] = weights["ar"]
        sd["sat"] = weights.get("sat", {})
        if "nat" in weights:
            sd["nat"] = weights["nat"]
    if "tokenizer_json" in sd:
        try:
            from tokenizers import Tokenizer as _Tokenizer

            runtime.tok.backend_tokenizer = _Tokenizer.from_str(sd["tokenizer_json"])
        except Exception as exc:
            # Deliberately best-effort, but say so instead of hiding it: a
            # tokenizer mismatch would otherwise show up only as bad output.
            sys.stderr.write(
                "warning: could not restore tokenizer embedded in checkpoint "
                f"({type(exc).__name__}: {exc}); keeping the runtime's current tokenizer\n"
            )
    return sd
73
+
74
+
75
def dblock_ranges(layers: int, blocks: int) -> list[tuple[int, int]]:
    """Split ``layers`` into at most ``blocks`` contiguous ``(start, end)`` spans.

    Every span except the last covers ``layers // blocks`` layers (at least
    one); the last requested span runs to ``layers``. Spans that would start
    at or beyond ``layers`` are omitted, so fewer than ``blocks`` spans may
    come back. ``blocks`` below 1 is treated as 1.
    """
    count = max(1, int(blocks))
    width = max(1, layers // count)
    final_index = count - 1
    ranges = []
    for index in range(count):
        lo = index * width
        if lo >= layers:
            continue
        if index == final_index:
            hi = layers
        else:
            hi = min((index + 1) * width, layers)
        ranges.append((lo, hi))
    return ranges
85
+
86
+
87
def make_dense_mask(mode: str, n: int, device: Any, sat_block: int):
    """Build the additive attention mask for a sequence of length ``n``.

    * ``"ar"``  - a ``(1, 1, n, n)`` tensor with ``-inf`` strictly above the
      diagonal and zeros elsewhere.
    * ``"sat"`` - positions are grouped by ``index // sat_block``; entry
      ``[i, j]`` is ``0.0`` when the group of ``i`` is not earlier than the
      group of ``j`` and ``-inf`` otherwise. Shape ``(1, 1, n, n)``.
    * ``"nat"`` - no mask (``None``).

    Raises:
        ValueError: For any other ``mode``.
    """
    torch = torch_io()
    if mode == "nat":
        return None
    if mode == "ar":
        blocked = torch.full((1, 1, n, n), float("-inf"), device=device)
        return torch.triu(blocked, 1)
    if mode != "sat":
        raise ValueError(f"bad mode {mode!r}")
    positions = torch.arange(n, device=device)
    group = positions.unsqueeze(0) // int(sat_block)
    visible = group.T >= group
    mask = torch.where(visible, 0.0, float("-inf"))
    return mask.unsqueeze(0).unsqueeze(0)
99
+
100
+
101
class StageModule:
    """One pipeline stage: the runtime blocks for layers ``[start_layer, end_layer)``.

    Construction builds ``end_layer - start_layer`` ``runtime.Block`` modules
    from the checkpoint config, copies in the matching ``blocks.<i>.*``
    weights from ``sd["core"]`` (re-indexed from zero), loads them strictly,
    moves the module to ``device`` and switches it to eval mode.
    """

    def __init__(
        self,
        runtime: Any,
        sd: dict[str, Any],
        start_layer: int,
        end_layer: int,
        device: str,
        attn_backend: str,
    ):
        torch = torch_io()
        cfg = sd["cfg"]
        self.runtime = runtime
        self.start_layer = int(start_layer)
        self.end_layer = int(end_layer)
        self.device = torch.device(device)

        def build_block():
            # MoE settings fall back to the runtime defaults when the
            # checkpoint config does not carry them.
            return runtime.Block(
                int(cfg["d"]),
                int(cfg["heads"]),
                int(cfg["rank"]),
                attn_backend=attn_backend,
                moe_ffn=bool(cfg.get("moe_ffn", runtime.DEFAULT_MOE_FFN)),
                moe_experts=int(cfg.get("moe_experts", runtime.DEFAULT_MOE_EXPERTS)),
                moe_top_k=int(cfg.get("moe_top_k", runtime.DEFAULT_MOE_TOP_K)),
                moe_mlp_mult=int(cfg.get("moe_mlp_mult", runtime.DEFAULT_MOE_MLP_MULT)),
            )

        layer_ids = range(self.start_layer, self.end_layer)
        self.module = torch.nn.Module()
        self.module.blocks = torch.nn.ModuleList([build_block() for _ in layer_ids])

        # Re-key "blocks.<global>.*" entries to "blocks.<local>.*" so they
        # match the freshly built, zero-based ModuleList.
        full_state = runtime._strip_orig_mod_prefix(sd["core"])
        remapped = {}
        for local_idx, global_idx in enumerate(layer_ids):
            old_prefix = f"blocks.{global_idx}."
            new_prefix = f"blocks.{local_idx}."
            remapped.update(
                (new_prefix + key[len(old_prefix):], tensor)
                for key, tensor in full_state.items()
                if isinstance(key, str) and key.startswith(old_prefix)
            )
        remapped = runtime._prepare_core_state_dict_for_load(self.module, remapped)
        self.module.load_state_dict(remapped, strict=True)
        self.module.to(self.device)
        self.module.eval()

    def run(self, hidden: Any, mode: str, sat_block: int) -> tuple[Any, float]:
        """Run ``hidden`` through this stage's blocks without gradients.

        Returns the output moved back to the CPU together with the elapsed
        wall-clock seconds (measured after the copy back to the CPU).
        """
        torch = torch_io()
        began = time.time()
        activations = hidden.to(self.device)
        seq_len = int(activations.size(1))
        mask = make_dense_mask(mode, seq_len, self.device, sat_block)
        with torch.no_grad():
            for layer in self.module.blocks:
                activations = layer(activations, mask)
        result = activations.detach().cpu()
        return result, time.time() - began
156
+
157
+
158
# Magic prefix that starts every tensor wire payload; tensor_from_payload
# rejects data that does not begin with these bytes.
WIRE_MAGIC = b"AGI35INF1"
159
+
160
+
161
def _torch_dtype_name(dtype: Any) -> str:
    """Return ``str(dtype)`` with a leading ``"torch."`` removed, if present."""
    prefix = "torch."
    label = str(dtype)
    if label.startswith(prefix):
        return label[len(prefix):]
    return label
164
+
165
+
166
def _torch_dtype_from_name(name: str) -> Any:
    """Map a wire dtype name back to the ``torch`` dtype of the same name.

    Only the dtypes listed below may travel over the wire.

    Raises:
        ValueError: If ``name`` is not one of the supported dtype names.
    """
    torch = torch_io()
    supported = (
        "float64",
        "float32",
        "float16",
        "bfloat16",
        "int64",
        "int32",
        "int16",
        "int8",
        "uint8",
        "bool",
    )
    if name not in supported:
        raise ValueError(f"unsupported tensor dtype over wire: {name}")
    return getattr(torch, name)
183
+
184
+
185
def tensor_payload(data: dict[str, Any]) -> bytes:
    """Serialise ``data["hidden"]`` plus JSON metadata into one wire payload.

    Layout: ``WIRE_MAGIC``, a big-endian ``uint32`` header length, a compact
    UTF-8 JSON header (``shape``, ``dtype``, ``meta``), then the tensor's raw
    bytes in C order. Every key of ``data`` other than ``"hidden"`` goes into
    ``meta`` and therefore must be JSON-serialisable.

    bfloat16 tensors are reinterpreted as ``int16`` before conversion to
    bytes, because NumPy has no bfloat16. The resulting bytes are the same
    as a ``uint16`` view, but ``int16`` does not require a recent torch
    release.

    NOTE(review): the raw bytes are whatever ``numpy.tobytes`` yields on the
    sending machine, so sender and receiver are presumably expected to share
    a byte order - confirm before mixing architectures.

    Raises:
        ValueError: If the dtype cannot be sent over the wire, or the JSON
            header exceeds 1,000,000 bytes.
    """
    hidden = data["hidden"].detach().cpu().contiguous()
    dtype_name = _torch_dtype_name(hidden.dtype)
    # Reject dtypes the receiver would refuse before doing any heavy work.
    _torch_dtype_from_name(dtype_name)
    header = {
        "shape": list(hidden.shape),
        "dtype": dtype_name,
        "meta": {k: v for k, v in data.items() if k != "hidden"},
    }
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    if len(header_bytes) > 1_000_000:
        raise ValueError("tensor payload header is too large")
    if dtype_name == "bfloat16":
        hidden = hidden.view(torch_io().int16)
    raw = hidden.numpy().tobytes(order="C")
    return b"".join((WIRE_MAGIC, struct.pack(">I", len(header_bytes)), header_bytes, raw))
200
+
201
+
202
def tensor_from_payload(data: bytes) -> dict[str, Any]:
    """Parse a wire payload produced by ``tensor_payload``.

    The payload may come from the network, so the header structure, the
    dtype, the shape and the number of tensor bytes are all checked before a
    tensor is built. Zero-element tensors are rebuilt with ``torch.empty``
    because ``torch.frombuffer`` refuses empty buffers.

    Returns:
        The header's ``meta`` mapping with the decoded tensor added under
        ``"hidden"``.

    Raises:
        ValueError: If the magic, header, dtype, shape or byte count is
            invalid.
    """
    torch = torch_io()
    prefix_len = len(WIRE_MAGIC) + 4
    if len(data) < prefix_len or not data.startswith(WIRE_MAGIC):
        raise ValueError("bad AGILLM35 inference wire payload")
    (header_len,) = struct.unpack_from(">I", data, len(WIRE_MAGIC))
    header_end = prefix_len + header_len
    if header_len <= 0 or header_len > 1_000_000 or header_end > len(data):
        raise ValueError("bad AGILLM35 inference wire header")
    header = json.loads(data[prefix_len:header_end].decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError("bad AGILLM35 inference wire header")
    try:
        shape = tuple(int(x) for x in header["shape"])
        dtype_name = str(header["dtype"])
    except (KeyError, TypeError) as exc:
        raise ValueError("bad AGILLM35 inference wire header") from exc
    if any(dim < 0 for dim in shape):
        raise ValueError("bad AGILLM35 inference wire header")
    meta = header.get("meta", {})
    if not isinstance(meta, dict):
        raise ValueError("bad AGILLM35 inference wire header")

    dtype = _torch_dtype_from_name(dtype_name)
    # bfloat16 travels as 16-bit integers; int16 has the same width and bit
    # pattern as the sender's view and exists in every torch release.
    wire_dtype = torch.int16 if dtype_name == "bfloat16" else dtype
    itemsize = torch.empty((), dtype=wire_dtype).element_size()
    expected = math.prod(shape) * itemsize
    # A memoryview avoids an extra bytes copy of a potentially huge body.
    body = memoryview(data)[header_end:]
    if len(body) != expected:
        raise ValueError(
            f"AGILLM35 inference wire payload has {len(body)} tensor bytes, expected {expected}"
        )
    if expected == 0:
        hidden = torch.empty(shape, dtype=dtype)
    else:
        flat = torch.frombuffer(bytearray(body), dtype=wire_dtype).clone()
        if wire_dtype is not dtype:
            flat = flat.view(dtype)
        hidden = flat.reshape(shape)
    out = dict(meta)
    out["hidden"] = hidden
    return out
223
+
224
+
225
def bearer(headers: Any) -> str:
    """Extract the token from an ``Authorization: Bearer <token>`` header.

    Returns the stripped token, or ``""`` when the header is missing or does
    not start with ``"Bearer "``.
    """
    scheme = "Bearer "
    value = headers.get("Authorization", "")
    if not value.startswith(scheme):
        return ""
    return value[len(scheme):].strip()
228
+
229
+
230
class WorkerHandler(BaseHTTPRequestHandler):
    """HTTP handler for a pipeline-stage worker.

    ``GET /health`` reports the stage's layer range and device without
    authentication. ``POST /run`` takes a tensor wire payload, runs it
    through ``self.server.stage`` and answers with a tensor wire payload.
    The handler expects the server object to carry ``stage``, ``token`` and
    ``max_bytes`` attributes.

    Bad requests and stage failures are answered with a JSON error body
    (400 or 500) instead of an unhandled exception that drops the connection.
    """

    server_version = "AGILLM35DistributedInferWorker/1"

    def send_json(self, code: int, data: Any) -> None:
        """Send ``data`` as an indented JSON response with status ``code``."""
        body = json.dumps(data, indent=2).encode("utf-8")
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def check_auth(self) -> bool:
        """Return True if the request may proceed; otherwise send a 401.

        When the server has no token configured, every request is allowed.
        """
        import hmac

        token = getattr(self.server, "token", "")  # type: ignore[attr-defined]
        if not token:
            return True
        supplied = bearer(self.headers)
        # Constant-time comparison so response timing does not leak how much
        # of the token matched. Encoding to bytes keeps non-ASCII tokens legal.
        if hmac.compare_digest(supplied.encode("utf-8"), str(token).encode("utf-8")):
            return True
        self.send_json(401, {"error": "bad bearer token"})
        return False

    def do_GET(self) -> None:
        """Serve ``/health``; every other path gets a 404."""
        if self.path == "/health":
            stage = self.server.stage  # type: ignore[attr-defined]
            self.send_json(
                200,
                {
                    "ok": True,
                    "start_layer": stage.start_layer,
                    "end_layer": stage.end_layer,
                    "device": str(stage.device),
                },
            )
            return
        self.send_json(404, {"error": "not found"})

    def do_POST(self) -> None:
        """Serve ``/run``: decode the payload, run the stage, send the result."""
        if self.path != "/run":
            self.send_json(404, {"error": "not found"})
            return
        if not self.check_auth():
            return
        try:
            n = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            self.send_json(400, {"error": "bad Content-Length"})
            return
        if n <= 0:
            self.send_json(400, {"error": "empty or missing request body", "bytes": n})
            return
        if n > int(getattr(self.server, "max_bytes", 2_000_000_000)):  # type: ignore[attr-defined]
            self.send_json(413, {"error": "payload too large", "bytes": n})
            return
        try:
            payload = tensor_from_payload(self.rfile.read(n))
            hidden_in = payload["hidden"]
            mode = str(payload.get("mode", "ar"))
            sat_block = int(payload.get("sat_block", 8))
        except (ValueError, KeyError, TypeError, RuntimeError) as exc:
            self.send_json(400, {"error": f"bad payload: {exc}"})
            return
        stage = self.server.stage  # type: ignore[attr-defined]
        try:
            hidden, sec = stage.run(hidden_in, mode, sat_block)
        except ValueError as exc:
            # e.g. an unknown attention mode requested by the client.
            self.send_json(400, {"error": f"bad request: {exc}"})
            return
        except Exception as exc:
            # Request boundary: report the failure rather than dropping the
            # connection, and leave a trace in the worker log.
            self.log_message("stage run failed: %s: %s", type(exc).__name__, exc)
            self.send_json(500, {"error": f"stage failed: {type(exc).__name__}: {exc}"})
            return
        body = tensor_payload(
            {
                "hidden": hidden,
                "stage_sec": sec,
                "start_layer": stage.start_layer,
                "end_layer": stage.end_layer,
            }
        )
        self.send_response(200)
        self.send_header("Content-Type", "application/octet-stream")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, fmt: str, *args: Any) -> None:
        """Write a UTC-timestamped log line to stderr."""
        # Spelled-out directives: "%F" and "%T" are not available on every
        # platform's strftime.
        stamp = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        sys.stderr.write("[%s] %s\n" % (stamp, fmt % args))
297
+
298
+
299
def cmd_worker(args: argparse.Namespace) -> None:
    """Run a pipeline-stage worker until the process is stopped.

    Loads the runtime and checkpoint, builds the ``StageModule`` for
    ``[args.start_layer, args.end_layer)``, and serves it over HTTP, or
    HTTPS when both ``--tls-cert`` and ``--tls-key`` are given. A
    ``worker_ready`` JSON line is printed once the server is set up.

    Raises:
        SystemExit: If only one of ``--tls-cert`` / ``--tls-key`` is given.
            Previously that combination silently served plain HTTP.
    """
    # Check the TLS flags first so a typo fails before the checkpoint loads.
    if bool(args.tls_cert) != bool(args.tls_key):
        raise SystemExit("--tls-cert and --tls-key must be given together")
    runtime = load_agillm35(args.agillm35_path)
    sd = load_ckpt(runtime, args.ckpt)
    args.device = resolve_device(args.device)
    stage = StageModule(runtime, sd, args.start_layer, args.end_layer, args.device, args.attn_backend)
    httpd = ThreadingHTTPServer((args.host, args.port), WorkerHandler)
    try:
        httpd.stage = stage  # type: ignore[attr-defined]
        httpd.token = args.token  # type: ignore[attr-defined]
        httpd.max_bytes = args.max_payload_bytes  # type: ignore[attr-defined]
        if args.tls_cert and args.tls_key:
            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ctx.load_cert_chain(args.tls_cert, args.tls_key)
            httpd.socket = ctx.wrap_socket(httpd.socket, server_side=True)
        print(
            json.dumps(
                {
                    "event": "worker_ready",
                    "bind": [args.host, args.port],
                    "layers": [args.start_layer, args.end_layer],
                    "device": args.device,
                }
            ),
            flush=True,
        )
        httpd.serve_forever()
    finally:
        # Release the listening socket however the serve loop ends.
        httpd.server_close()
324
+
325
+
326
class LocalStageClient:
    """Pipeline-stage client that calls an in-process ``StageModule`` directly."""

    def __init__(self, stage: StageModule, name: str):
        self.name = name
        self.stage = stage

    def run(self, hidden: Any, mode: str, sat_block: int) -> tuple[Any, dict[str, Any]]:
        """Run the wrapped stage; return its output and a stats record.

        The stats record holds this client's name, the seconds reported by
        the stage and the stage's ``[start_layer, end_layer]``.
        """
        stage = self.stage
        output, seconds = stage.run(hidden, mode, sat_block)
        stats = {
            "name": self.name,
            "sec": seconds,
            "layers": [stage.start_layer, stage.end_layer],
        }
        return output, stats
334
+
335
+
336
class RemoteStageClient:
    """Pipeline-stage client that posts hidden states to a remote worker's ``/run``.

    Args:
        url: Base URL of the worker; trailing slashes are removed.
        token: Bearer token to send, or ``""`` to send no Authorization header.
        name: Label used in the returned stats record.
        insecure: When true, TLS certificate and hostname verification are
            disabled for HTTPS requests.
        timeout: Per-request timeout in seconds (default 600, the value
            that used to be hard-coded).
    """

    def __init__(self, url: str, token: str, name: str, insecure: bool, timeout: float = 600.0):
        self.url = url.rstrip("/")
        self.token = token
        self.name = name
        self.insecure = insecure
        self.timeout = timeout
        # Built once and reused for every request.
        self._context = self._make_context(insecure)

    @staticmethod
    def _make_context(insecure: bool) -> ssl.SSLContext | None:
        """Return a non-verifying TLS context when ``insecure``, else ``None``."""
        if not insecure:
            return None
        # Public-API equivalent of the private ssl._create_unverified_context().
        context = ssl.create_default_context()
        # check_hostname has to be switched off before verify_mode may be
        # set to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        return context

    def run(self, hidden: Any, mode: str, sat_block: int) -> tuple[Any, dict[str, Any]]:
        """Send ``hidden`` to the worker and return its output plus stats.

        The stats record holds this client's name, the worker-reported stage
        seconds (``sec``), the round-trip wall time (``wall_sec``) and the
        layer range the worker reports.
        """
        payload = tensor_payload({"hidden": hidden.detach().cpu(), "mode": mode, "sat_block": sat_block})
        headers = {"Content-Type": "application/octet-stream"}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        req = Request(self.url + "/run", data=payload, method="POST", headers=headers)
        start = time.time()
        with urlopen(req, timeout=self.timeout, context=self._context) as r:
            result = tensor_from_payload(r.read())
        wall = time.time() - start
        return result["hidden"], {
            "name": self.name,
            "sec": float(result.get("stage_sec", 0.0)),
            "wall_sec": wall,
            "layers": [result.get("start_layer"), result.get("end_layer")],
        }
360
+
361
+
362
def parse_stage_specs(args: argparse.Namespace, runtime: Any, sd: dict[str, Any]) -> list[Any]:
    """Turn the ``--stage`` specs into stage clients, in pipeline order.

    A spec is ``local:START:END`` (an in-process ``StageModule``) or
    ``URL,START,END`` (a remote worker). With no specs, one local stage
    covering every layer of the checkpoint is used.

    For remote specs START/END only label the client and feed the coverage
    check below; the worker itself decides which layers it runs.

    Malformed specs and non-integer layer bounds now end in ``SystemExit``
    with a usage message rather than an unpacking error. If the ranges do
    not line up back to back from layer 0 to the checkpoint's layer count, a
    warning is written to stderr (not an error, so deliberate partial runs
    remain possible).

    Raises:
        SystemExit: If a spec is malformed.
    """
    usage = "remote stage syntax: URL,START,END or local:START:END"
    cfg = sd["cfg"]
    specs = args.stage or [f"local:0:{int(cfg['layers'])}"]
    out = []
    next_layer = 0
    for idx, spec in enumerate(specs):
        is_local = spec.startswith("local:")
        if is_local:
            parts = spec.split(":", 2)
        else:
            parts = [x.strip() for x in spec.split(",", 2)]
        if len(parts) != 3:
            raise SystemExit(usage)
        head, a, b = parts
        try:
            start, end = int(a), int(b)
        except ValueError:
            raise SystemExit(f"bad layer range in stage spec {spec!r}; {usage}") from None
        if start != next_layer:
            sys.stderr.write(
                f"warning: stage {idx} ({spec!r}) starts at layer {start}, "
                f"but the previous stage ended at layer {next_layer}\n"
            )
        next_layer = end
        if is_local:
            stage = StageModule(runtime, sd, start, end, args.device, args.attn_backend)
            out.append(LocalStageClient(stage, f"local-{a}-{b}"))
        else:
            out.append(RemoteStageClient(head, args.token, f"remote-{idx}-{a}-{b}", args.insecure))
    total_layers = cfg.get("layers")
    if total_layers is not None and next_layer != int(total_layers):
        sys.stderr.write(
            f"warning: stages end at layer {next_layer}, "
            f"but the checkpoint has {int(total_layers)} layers\n"
        )
    return out
379
+
380
+
381
def restore_heads(runtime: Any, sd: dict[str, Any], device: str):
    """Rebuild the coordinator-side modules from the checkpoint.

    Creates the token embedding, the final LayerNorm and the runtime's
    ``ARHead`` on ``device``, loads their weights from ``sd["core"]`` and
    ``sd["ar"]``, and switches all three to eval mode. When the checkpoint
    sets ``tie_weights``, the head is given the embedding weight.

    Returns:
        ``(emb, ln, ar_h)``.
    """
    torch = torch_io()
    width = int(sd["cfg"]["d"])
    tied = bool(sd.get("tie_weights", False))

    emb = torch.nn.Embedding(runtime.VOCAB, width).to(device)
    ln = torch.nn.LayerNorm(width).to(device)

    core_sd = runtime._strip_orig_mod_prefix(sd["core"])
    emb.weight.data.copy_(core_sd["emb.weight"].to(device))
    norm_state = {"weight": core_sd["ln.weight"], "bias": core_sd["ln.bias"]}
    ln.load_state_dict(norm_state)

    shared_weight = emb.weight if tied else None
    ar_h = runtime.ARHead(width, tie_weights=tied, embedding_weight=shared_weight).to(device)
    ar_h.load_state_dict(sd["ar"])

    for module in (emb, ln, ar_h):
        module.eval()
    return emb, ln, ar_h
396
+
397
+
398
def cmd_infer(args: argparse.Namespace) -> None:
    """Generate ``args.max_new`` tokens through the configured pipeline stages.

    Each step re-embeds the whole sequence, passes the hidden states through
    every stage in order, applies the final norm and AR head to the last
    position, applies the runtime's penalties and sampler, and appends the
    sampled token. Afterwards a summary with per-stage call counts and
    timings is printed; with ``--json`` the prompt and completion go inside
    that summary, otherwise the text is printed first.

    Raises:
        SystemExit: If the checkpoint config enables ``anchor_memory``.
    """
    torch = torch_io()
    runtime = load_agillm35(args.agillm35_path)
    sd = load_ckpt(runtime, args.ckpt)
    args.device = resolve_device(args.device)
    if bool(sd["cfg"].get("anchor_memory", False)):
        raise SystemExit("distributed phase-1 does not support anchor_memory yet")
    stages = parse_stage_specs(args, runtime, sd)
    emb, ln, ar_h = restore_heads(runtime, sd, args.device)

    # An empty prompt is replaced by a single EOS token.
    prompt_tokens = runtime.tok.encode(args.prompt) or [runtime.EOS]
    ids = torch.tensor([prompt_tokens], dtype=torch.long)
    prompt_len = ids.size(1)

    per_call_stats: list[dict[str, Any]] = []
    began = time.time()
    with torch.no_grad():
        for _step in range(int(args.max_new)):
            hidden = emb(ids.to(args.device)).detach().cpu()
            for stage in stages:
                hidden, stat = stage.run(hidden, args.mode, args.sat_block)
                per_call_stats.append(stat)
            normed = ln(hidden.to(args.device))
            logits = ar_h(normed)[:, -1]
            logits = runtime._apply_penalties(
                logits,
                ids.to(args.device),
                args.penalty_last_n,
                args.repetition_penalty,
                args.presence_penalty,
                args.frequency_penalty,
            )
            nxt = runtime._sample(
                logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy
            )
            ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
    elapsed = time.time() - began

    token_ids = ids[0].tolist()
    prompt = runtime.tok.decode(token_ids[:prompt_len], skip_special_tokens=True)
    completion = runtime.tok.decode(token_ids[prompt_len:], skip_special_tokens=True)

    # Fold the per-call records into one entry per stage name.
    by_stage: dict[str, dict[str, Any]] = {}
    for stat in per_call_stats:
        name = stat["name"]
        if name not in by_stage:
            by_stage[name] = {"calls": 0, "sec": 0.0, "wall_sec": 0.0, "layers": stat.get("layers")}
        entry = by_stage[name]
        entry["calls"] += 1
        entry["sec"] += float(stat.get("sec", 0.0))
        # Local stages report no wall time, so their stage time stands in.
        entry["wall_sec"] += float(stat.get("wall_sec", stat.get("sec", 0.0)))

    result = {
        "event": "distributed_infer_done",
        "mode": args.mode,
        "tokens": int(args.max_new),
        "elapsed_sec": round(elapsed, 3),
        "tok_per_sec": round(int(args.max_new) / max(elapsed, 1e-9), 3),
        "stages": by_stage,
    }
    if not args.json:
        print(prompt + completion)
        print(json.dumps(result, indent=2))
        return
    result["prompt"] = prompt
    result["completion"] = completion
    print(json.dumps(result, indent=2))
457
+
458
+
459
def cmd_plan(args: argparse.Namespace) -> None:
    """Print, as JSON, how the checkpoint's layers split into DiffusionBlock ranges."""
    runtime = load_agillm35(args.agillm35_path)
    checkpoint = load_ckpt(runtime, args.ckpt)
    n_layers = int(checkpoint["cfg"]["layers"])
    plan = {
        "layers": n_layers,
        "dblock_blocks": args.dblock_blocks,
        "ranges": dblock_ranges(n_layers, args.dblock_blocks),
    }
    print(json.dumps(plan, indent=2))
465
+
466
+
467
def main() -> int:
    """Parse the command line and dispatch to ``plan``, ``worker`` or ``infer``.

    Returns:
        ``0`` once the selected sub-command has finished.
    """
    parser = argparse.ArgumentParser(
        description="AGILLM3.5 distributed transformer/MoE/DiffusionBlock inference"
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # Options shared by every sub-command.
    shared = argparse.ArgumentParser(add_help=False)
    shared.add_argument("--agillm35-path", default=os.environ.get("AGILLM35_RUNTIME", "./agillm35.py"))
    shared.add_argument("--ckpt", required=True)
    shared.add_argument("--attn-backend", choices=["manual", "sdpa"], default="manual")
    shared.add_argument("--device", default="auto")

    plan = commands.add_parser("plan", parents=[shared])
    plan.add_argument("--dblock-blocks", type=int, default=8)
    plan.set_defaults(func=cmd_plan)

    worker = commands.add_parser("worker", parents=[shared])
    worker.add_argument("--start-layer", type=int, required=True)
    worker.add_argument("--end-layer", type=int, required=True)
    worker.add_argument("--host", default="127.0.0.1")
    worker.add_argument("--port", type=int, default=9100)
    worker.add_argument("--token", default=os.environ.get("AGILLM35_INFER_TOKEN", ""))
    worker.add_argument("--max-payload-bytes", type=int, default=2_000_000_000)
    worker.add_argument("--tls-cert")
    worker.add_argument("--tls-key")
    worker.set_defaults(func=cmd_worker)

    infer = commands.add_parser("infer", parents=[shared])
    infer.add_argument("--prompt", required=True)
    infer.add_argument("--max-new", type=int, default=16)
    infer.add_argument("--mode", choices=["ar"], default="ar")
    infer.add_argument(
        "--stage",
        action="append",
        help="local:START:END or URL,START,END. Repeat in pipeline order.",
    )
    infer.add_argument("--token", default=os.environ.get("AGILLM35_INFER_TOKEN", ""))
    infer.add_argument("--insecure", action="store_true")
    infer.add_argument("--temperature", type=float, default=0.7)
    infer.add_argument("--greedy", action="store_true")
    # The remaining sampling knobs are plain typed options with a default.
    numeric_options = (
        ("--top-k", int, 0),
        ("--top-p", float, 0.9),
        ("--min-p", float, 0.0),
        ("--repetition-penalty", float, 1.3),
        ("--presence-penalty", float, 0.0),
        ("--frequency-penalty", float, 0.3),
        ("--penalty-last-n", int, 128),
        ("--sat-block", int, 8),
    )
    for flag, kind, default in numeric_options:
        infer.add_argument(flag, type=kind, default=default)
    infer.add_argument("--json", action="store_true")
    infer.set_defaults(func=cmd_infer)

    parsed = parser.parse_args()
    parsed.func(parsed)
    return 0
514
+
515
+
516
# Script entry point: run the CLI and use main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())