Buckets:
| #!/usr/bin/env python3 | |
| """Distributed inference harness for the real AGILLM4.1 transformer blocks. | |
| Phase 1 is exact full-sequence AR inference over pipeline stages. Each stage | |
| owns a contiguous transformer/DiffusionBlock layer range and runs the actual | |
| AGILLM4.1 Block implementation, including MoE FFNs when enabled by the | |
| checkpoint config. The coordinator keeps embeddings, final norm, and AR head. | |
| """ | |
from __future__ import annotations
import argparse
import gc
import hmac
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
import importlib.util
import io
import json
import math
import os
from pathlib import Path
import shutil
import ssl
import struct
import sys
import time
import uuid
from typing import Any
from urllib.parse import urlparse
from urllib.request import Request, urlopen
def load_agillm41(path: str | Path):
    """Import the AGILLM4.1 runtime script at *path* and return it as a module.

    The module is registered in ``sys.modules`` under ``agillm41_runtime`` and
    the legacy ``agillm35_runtime`` name (only where those names are not
    already taken) before it is executed.

    Raises:
        RuntimeError: if no import spec/loader can be created for *path*.
    """
    script = Path(path).resolve()
    spec = importlib.util.spec_from_file_location("agillm41_runtime", script)
    loader = None if spec is None else spec.loader
    if loader is None:
        raise RuntimeError(f"cannot import AGILLM4.1 runtime from {script}")
    module = importlib.util.module_from_spec(spec)
    for alias in ("agillm41_runtime", "agillm35_runtime"):
        sys.modules.setdefault(alias, module)
    loader.exec_module(module)
    return module
def torch_io():
    """Return the ``torch`` module, importing it lazily on first use."""
    import torch as _torch_module
    return _torch_module
def resolve_device(name: str):
    """Map ``"auto"`` to ``"cuda"`` when CUDA is available, else ``"cpu"``.

    Any other device name is returned unchanged.
    """
    torch = torch_io()
    if name != "auto":
        return name
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_ckpt(runtime: Any, ckpt_path: str | Path) -> dict[str, Any]:
    """Load a checkpoint dict onto the CPU and normalize delta checkpoints.

    If *ckpt_path* is not a file, ``runtime._resolve_ckpt`` is asked to resolve
    it (falling back to the path as given).

    Delta checkpoints (truthy ``"delta"``) are reshaped as follows:
    - ``cfg`` comes from ``runtime.PRESETS["large"]``;
    - ``tie_weights`` is set to False;
    - ``core``/``ar``/``sat``/``nat`` are lifted out of ``sd["weights"]``.

    An embedded ``tokenizer_json`` replaces ``runtime.tok.backend_tokenizer``
    when the ``tokenizers`` package can load it. If it cannot, a warning is
    written to stderr and the runtime's own tokenizer is kept.
    """
    torch = torch_io()
    path = Path(ckpt_path)
    resolved = path if path.is_file() else (runtime._resolve_ckpt(path) or path)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    sd = torch.load(resolved, map_location="cpu", weights_only=False)
    if sd.get("delta"):
        cfg = runtime.PRESETS["large"].copy()
        sd["cfg"] = cfg
        sd["tie_weights"] = False
        sd["core"] = sd["weights"]["core"]
        sd["ar"] = sd["weights"]["ar"]
        sd["sat"] = sd["weights"].get("sat", {})
        if "nat" in sd["weights"]:
            sd["nat"] = sd["weights"]["nat"]
    if "tokenizer_json" in sd:
        try:
            from tokenizers import Tokenizer as _Tokenizer
            runtime.tok.backend_tokenizer = _Tokenizer.from_str(sd["tokenizer_json"])
        except Exception as exc:
            # Best effort: keep going with the runtime's tokenizer, but say so.
            # A silent mismatch would corrupt both encoding and decoding.
            # stderr keeps stdout clean for the JSON results.
            print(
                json.dumps({"event": "tokenizer_restore_failed", "error": repr(exc)}),
                file=sys.stderr,
                flush=True,
            )
    return sd
def dblock_ranges(layers: int, blocks: int) -> list[tuple[int, int]]:
    """Split ``layers`` into up to ``blocks`` contiguous ``(start, end)`` ranges.

    Each range spans ``layers // blocks`` layers (at least 1). The last
    range absorbs any remainder. Ranges that would start at or past
    ``layers`` are dropped. ``blocks`` is clamped to at least 1.
    """
    n_blocks = max(1, int(blocks))
    step = max(1, layers // n_blocks)
    ranges: list[tuple[int, int]] = []
    for i in range(n_blocks):
        lo = i * step
        if lo >= layers:
            continue
        hi = layers if i == n_blocks - 1 else (i + 1) * step
        ranges.append((lo, min(hi, layers)))
    return ranges
def make_dense_mask(mode: str, n: int, device: Any, sat_block: int):
    """Build an additive (1, 1, n, n) attention mask for *mode*.

    Entries are 0.0 where a query (row) may attend to a key (column) and
    -inf where it may not.

    - ``"ar"``: causal; row i sees columns 0..i.
    - ``"sat"``: block-causal; positions are grouped as ``index // sat_block``
      and a row sees every column whose group is <= its own.
    - ``"nat"``: no mask; returns None.

    Raises:
        ValueError: for any other mode.
    """
    torch = torch_io()
    if mode == "ar":
        neg_inf = torch.full((1, 1, n, n), float("-inf"), device=device)
        return torch.triu(neg_inf, 1)
    if mode == "sat":
        group = (torch.arange(n, device=device) // int(sat_block)).unsqueeze(0)
        visible = group.T >= group
        return torch.where(visible, 0.0, float("-inf"))[None, None]
    if mode == "nat":
        return None
    raise ValueError(f"bad mode {mode!r}")
def make_cached_mask(mode: str, q_len: int, total_seq_len: int, device: Any, sat_block: int):
    """Mask for ``q_len`` new queries attending over ``total_seq_len`` keys.

    The queries are the last ``q_len`` positions of the sequence.

    - ``"nat"`` never masks. A single ``"ar"`` query needs no mask either
      (it may see every key), so both return None.
    - Multi-query ``"ar"`` returns a (1, 1, q_len, total_seq_len) causal mask.
    - Other modes take the last ``q_len`` rows of ``make_dense_mask``.
    """
    torch = torch_io()
    if mode == "nat":
        return None
    if mode != "ar":
        full = make_dense_mask(mode, int(total_seq_len), device, sat_block)
        return full[..., -int(q_len):, :]
    if q_len == 1:
        return None
    n_keys = int(total_seq_len)
    first_query = n_keys - int(q_len)
    query_pos = torch.arange(first_query, n_keys, device=device).view(q_len, 1)
    key_pos = torch.arange(n_keys, device=device).view(1, n_keys)
    future = key_pos > query_pos
    return torch.where(future, float("-inf"), 0.0).view(1, 1, q_len, n_keys)
class StageModule:
    """One pipeline stage owning checkpoint layers ``[start_layer, end_layer)``.

    The layers are built with ``runtime.Block`` from the checkpoint config,
    including the MoE FFN settings. Their weights are loaded strictly from the
    checkpoint's ``core`` state dict after renumbering ``blocks.<global>.``
    keys to ``blocks.<local>.``, and the module is moved to ``device`` in eval
    mode.

    ``run`` does a full-sequence pass. ``run_cached`` does incremental passes
    against a per-session KV cache; at most ``max_cache_sessions`` sessions
    are kept, and the least recently used are evicted.
    """

    def __init__(
        self,
        runtime: Any,
        sd: dict[str, Any],
        start_layer: int,
        end_layer: int,
        device: str,
        attn_backend: str,
        max_cache_sessions: int = 64,
    ):
        torch = torch_io()
        nn = torch.nn
        cfg = sd["cfg"]
        self.runtime = runtime
        self.start_layer = int(start_layer)
        self.end_layer = int(end_layer)
        total_layers = cfg.get("layers")
        # Reject impossible ranges up front. A reversed range would otherwise
        # build an empty pass-through stage. An end past the model would only
        # surface later as a strict load_state_dict "missing keys" error.
        if (
            self.start_layer < 0
            or self.end_layer < self.start_layer
            or (total_layers is not None and self.end_layer > int(total_layers))
        ):
            raise ValueError(
                f"invalid stage layer range [{self.start_layer}, {self.end_layer}) "
                f"for a checkpoint with {total_layers} layers"
            )
        self.device = torch.device(device)
        # session_id -> per-block KV entries, plus last-use time for LRU eviction.
        self.cache: dict[str, list[Any]] = {}
        self.cache_last_used: dict[str, float] = {}
        # At least one session must survive pruning, or every decode step
        # would find its cache already evicted.
        self.max_cache_sessions = max(1, int(max_cache_sessions))
        self.module = nn.Module()
        self.module.blocks = nn.ModuleList(
            [
                runtime.Block(
                    int(cfg["d"]),
                    int(cfg["heads"]),
                    int(cfg["rank"]),
                    attn_backend=attn_backend,
                    moe_ffn=bool(cfg.get("moe_ffn", runtime.DEFAULT_MOE_FFN)),
                    moe_experts=int(cfg.get("moe_experts", runtime.DEFAULT_MOE_EXPERTS)),
                    moe_top_k=int(cfg.get("moe_top_k", runtime.DEFAULT_MOE_TOP_K)),
                    moe_mlp_mult=int(cfg.get("moe_mlp_mult", runtime.DEFAULT_MOE_MLP_MULT)),
                )
                for _ in range(self.end_layer - self.start_layer)
            ]
        )
        core_sd = runtime._strip_orig_mod_prefix(sd["core"])
        # Copy only this stage's layers, renumbered so local block 0 is
        # global block start_layer.
        local_sd = {}
        for local_i, global_i in enumerate(range(self.start_layer, self.end_layer)):
            src_prefix = f"blocks.{global_i}."
            dst_prefix = f"blocks.{local_i}."
            for key, value in core_sd.items():
                if isinstance(key, str) and key.startswith(src_prefix):
                    local_sd[dst_prefix + key[len(src_prefix):]] = value
        local_sd = runtime._prepare_core_state_dict_for_load(self.module, local_sd)
        self.module.load_state_dict(local_sd, strict=True)
        del local_sd, core_sd
        gc.collect()
        self.module.to(self.device)
        self.module.eval()

    def run(self, hidden: Any, mode: str, sat_block: int) -> tuple[Any, float]:
        """Run the whole sequence through this stage without a cache.

        Returns the output hidden states moved to the CPU and the elapsed
        seconds.
        """
        torch = torch_io()
        start = time.time()
        x = hidden.to(self.device)
        mask = make_dense_mask(mode, int(x.size(1)), self.device, sat_block)
        with torch.no_grad():
            for block in self.module.blocks:
                x = block(x, mask)
        return x.detach().cpu(), time.time() - start

    def _prune_cache(self) -> None:
        """Evict least-recently-used sessions beyond ``max_cache_sessions``."""
        excess = len(self.cache) - self.max_cache_sessions
        if excess <= 0:
            return
        for session_id, _ in sorted(self.cache_last_used.items(), key=lambda kv: kv[1])[:excess]:
            self.cache.pop(session_id, None)
            self.cache_last_used.pop(session_id, None)

    def clear_cache(self, session_id: str) -> None:
        """Drop any cached KV state for *session_id* (no-op if absent)."""
        self.cache.pop(session_id, None)
        self.cache_last_used.pop(session_id, None)

    def run_cached(
        self,
        hidden: Any,
        mode: str,
        sat_block: int,
        session_id: str,
        total_seq_len: int,
        reset_cache: bool = False,
    ) -> tuple[Any, float]:
        """Run ``hidden`` (the newest positions) against the session's KV cache.

        ``total_seq_len`` is the sequence length including ``hidden``. With no
        usable cache, ``hidden`` must be the entire sequence (a prefill). The
        updated KV entries are stored under *session_id*.

        Returns the output hidden states moved to the CPU and the elapsed
        seconds.

        Raises:
            RuntimeError: if no usable cache exists for the session but only
                part of the sequence was sent, for example after LRU eviction
                or a worker restart. The caller must re-prefill.
        """
        torch = torch_io()
        start = time.time()
        if reset_cache:
            self.clear_cache(session_id)
        x = hidden.to(self.device)
        q_len = int(x.size(1))
        kvs = self.cache.get(session_id)
        if kvs is not None and len(kvs) != len(self.module.blocks):
            kvs = None
        if kvs is None and q_len != int(total_seq_len):
            # Without the cached prefix the blocks would attend over only the
            # new tokens and silently return wrong activations.
            raise RuntimeError(
                f"no KV cache for session {session_id!r} on layers "
                f"[{self.start_layer}, {self.end_layer}) (q_len={q_len}, "
                f"total_seq_len={int(total_seq_len)}); resend the full sequence with reset_cache"
            )
        mask = make_cached_mask(mode, q_len, int(total_seq_len), self.device, sat_block)
        new_kvs = []
        with torch.no_grad():
            for idx, block in enumerate(self.module.blocks):
                kv = None if kvs is None else kvs[idx]
                x, new_kv = block(x, mask, kv=kv, use_cache=True, total_seq_len=int(total_seq_len))
                if isinstance(new_kv, tuple):
                    new_kv = tuple(t.detach() for t in new_kv)
                new_kvs.append(new_kv)
        self.cache[session_id] = new_kvs
        self.cache_last_used[session_id] = time.time()
        self._prune_cache()
        return x.detach().cpu(), time.time() - start
# Magic prefix of every tensor wire payload: tensor_payload writes it first and
# tensor_from_payload rejects data that does not start with it.
WIRE_MAGIC = b"AGI35INF1"
| def _torch_dtype_name(dtype: Any) -> str: | |
| text = str(dtype) | |
| return text.split(".", 1)[1] if text.startswith("torch.") else text | |
def _torch_dtype_from_name(name: str) -> Any:
    """Map a wire dtype name (as made by ``_torch_dtype_name``) to a torch dtype.

    Raises:
        ValueError: if *name* is not one of the supported wire dtypes.
    """
    torch = torch_io()
    supported = frozenset(
        (
            "float64",
            "float32",
            "float16",
            "bfloat16",
            "int64",
            "int32",
            "int16",
            "int8",
            "uint8",
            "bool",
        )
    )
    if name not in supported:
        raise ValueError(f"unsupported tensor dtype over wire: {name}")
    return getattr(torch, name)
def tensor_payload(data: dict[str, Any]) -> bytes:
    """Serialize ``data["hidden"]`` and its metadata into the wire format.

    Layout is: ``WIRE_MAGIC``, then a big-endian u32 header length, then the
    UTF-8 JSON header, then the raw C-order tensor bytes. The header holds
    ``shape``, ``dtype`` and every other key of *data* under ``meta``; those
    values must be JSON-serializable.

    Raises:
        ValueError: if the encoded header exceeds 1,000,000 bytes.
    """
    hidden = data["hidden"].detach().cpu().contiguous()
    header = {
        "shape": list(hidden.shape),
        "dtype": _torch_dtype_name(hidden.dtype),
        "meta": {k: v for k, v in data.items() if k != "hidden"},
    }
    if header["dtype"] == "bfloat16":
        # numpy has no bfloat16, so ship the raw 16-bit patterns instead.
        # torch.int16 (unlike torch.uint16) exists on older torch releases
        # and yields identical bytes.
        raw = hidden.view(torch_io().int16).numpy().tobytes(order="C")
    else:
        raw = hidden.numpy().tobytes(order="C")
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    if len(header_bytes) > 1_000_000:
        raise ValueError("tensor payload header is too large")
    return WIRE_MAGIC + struct.pack(">I", len(header_bytes)) + header_bytes + raw
def tensor_from_payload(data: bytes) -> dict[str, Any]:
    """Decode a wire payload produced by ``tensor_payload``.

    Returns the header's ``meta`` dict with the decoded CPU tensor added
    under ``"hidden"``. bfloat16 tensors arrive as raw 16-bit patterns and
    are reinterpreted after loading.

    Raises:
        ValueError: on a bad magic prefix or header, negative dimensions,
            an unsupported dtype, or a tensor byte count that does not match
            shape x element size.
    """
    torch = torch_io()
    magic_len = len(WIRE_MAGIC)
    if len(data) < magic_len + 4 or not data.startswith(WIRE_MAGIC):
        raise ValueError("bad AGILLM41/AGILLM35 inference wire payload")
    header_len = struct.unpack(">I", data[magic_len:magic_len + 4])[0]
    header_start = magic_len + 4
    header_end = header_start + header_len
    if header_len <= 0 or header_len > 1_000_000 or header_end > len(data):
        raise ValueError("bad AGILLM41/AGILLM35 inference wire header")
    header = json.loads(data[header_start:header_end].decode("utf-8"))
    shape = tuple(int(x) for x in header["shape"])
    if any(dim < 0 for dim in shape):
        raise ValueError(f"bad tensor shape over wire: {list(shape)}")
    dtype_name = str(header["dtype"])
    is_bf16 = dtype_name == "bfloat16"
    wire_dtype = torch.int16 if is_bf16 else _torch_dtype_from_name(dtype_name)
    count = math.prod(shape)
    expected = count * torch.empty(0, dtype=wire_dtype).element_size()
    # memoryview slicing avoids copying the (possibly multi-GB) tensor bytes.
    raw = memoryview(data)[header_end:]
    if len(raw) != expected:
        raise ValueError(
            f"wire payload has {len(raw)} tensor bytes, expected {expected} "
            f"for shape {list(shape)} {dtype_name}"
        )
    if count == 0:
        # torch.frombuffer rejects empty buffers.
        flat = torch.empty(0, dtype=wire_dtype)
    else:
        # bytearray makes the one private, writable copy; frombuffer shares
        # its memory and keeps it alive, so no further clone is needed.
        flat = torch.frombuffer(bytearray(raw), dtype=wire_dtype)
    hidden = flat.reshape(shape)
    if is_bf16:
        hidden = hidden.view(torch.bfloat16)
    out = dict(header.get("meta", {}))
    out["hidden"] = hidden
    return out
def bearer(headers: Any) -> str:
    """Return the stripped token from an ``Authorization: Bearer <token>`` header.

    Returns ``""`` when the header is absent or uses another scheme.
    """
    value = headers.get("Authorization", "")
    if not value.startswith("Bearer "):
        return ""
    _, _, token = value.partition(" ")
    return token.strip()
class WorkerHandler(BaseHTTPRequestHandler):
    """HTTP front end for one pipeline stage.

    Expects the server object to carry ``stage`` (a StageModule), ``token``
    (an empty string disables auth) and ``max_bytes``.

    - ``GET /health`` reports the stage's layer range and device, with no auth.
    - ``POST /run`` takes a tensor wire payload, runs the stage (cached when
      ``use_cache`` is set) and replies with a tensor wire payload carrying
      ``stage_sec``, ``start_layer`` and ``end_layer``.

    Malformed requests get 4xx JSON errors and stage failures get a 500 JSON
    error, instead of a dropped connection.
    """

    server_version = "AGILLM41DistributedInferWorker/1"

    def send_json(self, code: int, data: Any) -> None:
        """Send *data* as a JSON response with status *code*."""
        body = json.dumps(data, indent=2).encode("utf-8")
        self.send_response(code)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def check_auth(self) -> bool:
        """Return True if authorized. Otherwise send a 401 and return False."""
        token = getattr(self.server, "token", "")  # type: ignore[attr-defined]
        if not token:
            return True
        # Constant-time comparison so response timing does not leak how much
        # of the token matched.
        presented = bearer(self.headers).encode("utf-8")
        if hmac.compare_digest(presented, str(token).encode("utf-8")):
            return True
        self.send_json(401, {"error": "bad bearer token"})
        return False

    def do_GET(self) -> None:
        if self.path == "/health":
            stage = self.server.stage  # type: ignore[attr-defined]
            self.send_json(
                200,
                {
                    "ok": True,
                    "start_layer": stage.start_layer,
                    "end_layer": stage.end_layer,
                    "device": str(stage.device),
                },
            )
            return
        self.send_json(404, {"error": "not found"})

    def do_POST(self) -> None:
        if self.path != "/run":
            self.send_json(404, {"error": "not found"})
            return
        if not self.check_auth():
            return
        try:
            n = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            self.send_json(400, {"error": "bad Content-Length header"})
            return
        if n <= 0:
            self.send_json(411, {"error": "missing or empty request body", "bytes": n})
            return
        if n > int(getattr(self.server, "max_bytes", 2_000_000_000)):  # type: ignore[attr-defined]
            self.send_json(413, {"error": "payload too large", "bytes": n})
            return
        raw = self.rfile.read(n)
        if len(raw) != n:
            self.send_json(400, {"error": "truncated request body", "bytes": len(raw), "expected": n})
            return
        try:
            payload = tensor_from_payload(raw)
            hidden_in = payload["hidden"]
            mode = str(payload.get("mode", "ar"))
            sat_block = int(payload.get("sat_block", 8))
            use_cache = bool(payload.get("use_cache", False))
            session_id = str(payload.get("session_id", ""))
            total_seq_len = int(payload.get("total_seq_len", int(hidden_in.size(1))))
            reset_cache = bool(payload.get("reset_cache", False))
        except (ValueError, KeyError, TypeError, IndexError, RuntimeError) as exc:
            self.send_json(400, {"error": f"bad payload: {exc}"})
            return
        stage = self.server.stage  # type: ignore[attr-defined]
        try:
            if use_cache:
                hidden, sec = stage.run_cached(
                    hidden_in,
                    mode,
                    sat_block,
                    session_id,
                    total_seq_len,
                    reset_cache,
                )
            else:
                hidden, sec = stage.run(hidden_in, mode, sat_block)
        except Exception as exc:
            # Request boundary: report the failure to the coordinator instead
            # of letting the connection drop without a response.
            self.log_error("stage run failed: %r", exc)
            self.send_json(500, {"error": f"stage run failed: {exc}"})
            return
        body = tensor_payload(
            {
                "hidden": hidden,
                "stage_sec": sec,
                "start_layer": stage.start_layer,
                "end_layer": stage.end_layer,
            }
        )
        self.send_response(200)
        self.send_header("Content-Type", "application/octet-stream")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, fmt: str, *args: Any) -> None:
        """Write access/error log lines to stderr with a UTC timestamp."""
        sys.stderr.write("[%s] %s\n" % (time.strftime("%FT%TZ", time.gmtime()), fmt % args))
def cmd_worker(args: argparse.Namespace) -> None:
    """Serve one pipeline stage (layers ``[start_layer, end_layer)``) over HTTP(S).

    Loads the runtime and checkpoint, builds the StageModule, frees the
    checkpoint dict, and serves ``WorkerHandler`` until interrupted. TLS is
    enabled when both ``--tls-cert`` and ``--tls-key`` are given.

    Raises:
        SystemExit: if only one of ``--tls-cert`` / ``--tls-key`` is given.
    """
    # Check TLS flags before the slow checkpoint load. Accepting just one of
    # them would fall back to plain HTTP and send the bearer token in cleartext.
    if bool(args.tls_cert) != bool(args.tls_key):
        raise SystemExit("--tls-cert and --tls-key must be given together")
    runtime = load_agillm41(args.agillm35_path)
    sd = load_ckpt(runtime, args.ckpt)
    args.device = resolve_device(args.device)
    stage = StageModule(runtime, sd, args.start_layer, args.end_layer, args.device, args.attn_backend)
    del sd
    gc.collect()
    ThreadingHTTPServer.allow_reuse_address = True  # avoid TIME_WAIT bind failures on relaunch
    httpd = ThreadingHTTPServer((args.host, args.port), WorkerHandler)
    httpd.stage = stage  # type: ignore[attr-defined]
    httpd.token = args.token  # type: ignore[attr-defined]
    httpd.max_bytes = args.max_payload_bytes  # type: ignore[attr-defined]
    if args.tls_cert and args.tls_key:
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(args.tls_cert, args.tls_key)
        httpd.socket = ctx.wrap_socket(httpd.socket, server_side=True)
    print(
        json.dumps(
            {
                "event": "worker_ready",
                "bind": [args.host, args.port],
                "layers": [args.start_layer, args.end_layer],
                "device": args.device,
            }
        ),
        flush=True,
    )
    try:
        httpd.serve_forever()
    finally:
        # Release the listening socket even on Ctrl-C or an unexpected error.
        httpd.server_close()
class LocalStageClient:
    """In-process stage client with the same ``run``/``run_cached`` API as RemoteStageClient.

    Each call returns ``(hidden, stat)``. ``stat`` holds the client name, the
    stage's compute seconds and its layer range; cached calls also set
    ``"cached": True``.
    """

    def __init__(self, stage: StageModule, name: str):
        self.stage, self.name = stage, name

    def _stat(self, sec: float) -> dict[str, Any]:
        # Shared stat record; the key order matches the remote client's.
        return {"name": self.name, "sec": sec, "layers": [self.stage.start_layer, self.stage.end_layer]}

    def run(self, hidden: Any, mode: str, sat_block: int) -> tuple[Any, dict[str, Any]]:
        result, elapsed = self.stage.run(hidden, mode, sat_block)
        return result, self._stat(elapsed)

    def run_cached(
        self,
        hidden: Any,
        mode: str,
        sat_block: int,
        session_id: str,
        total_seq_len: int,
        reset_cache: bool,
    ) -> tuple[Any, dict[str, Any]]:
        result, elapsed = self.stage.run_cached(hidden, mode, sat_block, session_id, total_seq_len, reset_cache)
        return result, {**self._stat(elapsed), "cached": True}
class RemoteStageClient:
    """Client for a stage served by ``WorkerHandler`` at ``<url>/run``.

    Each call POSTs one tensor wire payload and returns ``(hidden, stat)``.
    ``stat`` holds the client name, the worker-reported ``sec``, the
    client-side ``wall_sec`` and the worker's layer range.

    Args:
        url: worker base URL; trailing slashes are stripped.
        token: bearer token sent when non-empty.
        name: label used in stats.
        insecure: disable TLS certificate and hostname verification.
        timeout: per-request timeout in seconds (default 600).
    """

    def __init__(self, url: str, token: str, name: str, insecure: bool, timeout: float = 600.0):
        self.url = url.rstrip("/")
        self.token = token
        self.name = name
        self.insecure = insecure
        self.timeout = float(timeout)

    def _headers(self) -> dict[str, str]:
        """Request headers: the octet-stream content type plus optional bearer auth."""
        headers = {"Content-Type": "application/octet-stream"}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        return headers

    def _post(self, meta: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
        """POST *meta* (which must include ``hidden``) to ``/run``; return (hidden, stat)."""
        payload = tensor_payload(meta)
        req = Request(self.url + "/run", data=payload, method="POST", headers=self._headers())
        # NOTE: insecure mode skips certificate and hostname verification entirely.
        context = ssl._create_unverified_context() if self.insecure else None
        start = time.time()
        with urlopen(req, timeout=self.timeout, context=context) as r:
            result = tensor_from_payload(r.read())
        wall = time.time() - start
        return result["hidden"], {
            "name": self.name,
            "sec": float(result.get("stage_sec", 0.0)),
            "wall_sec": wall,
            "layers": [result.get("start_layer"), result.get("end_layer")],
        }

    def run(self, hidden: Any, mode: str, sat_block: int) -> tuple[Any, dict[str, Any]]:
        """Run the full sequence remotely, with no cache."""
        return self._post({"hidden": hidden.detach().cpu(), "mode": mode, "sat_block": sat_block})

    def run_cached(
        self,
        hidden: Any,
        mode: str,
        sat_block: int,
        session_id: str,
        total_seq_len: int,
        reset_cache: bool,
    ) -> tuple[Any, dict[str, Any]]:
        """Run new positions remotely against the worker's KV cache for *session_id*."""
        out, stat = self._post(
            {
                "hidden": hidden.detach().cpu(),
                "mode": mode,
                "sat_block": sat_block,
                "use_cache": True,
                "session_id": session_id,
                "total_seq_len": int(total_seq_len),
                "reset_cache": bool(reset_cache),
            }
        )
        stat["cached"] = True
        return out, stat
def _parse_stage_spec(spec: str, layers: int) -> tuple[str | None, int, int]:
    """Parse one ``--stage`` value into ``(url, start, end)``; ``url`` is None for local stages.

    Raises:
        SystemExit: on malformed syntax, non-integer bounds, a non-http(s)
            URL, or a range outside ``0 <= start < end <= layers``.
    """
    if spec.startswith("local:"):
        parts = spec.split(":", 2)
        if len(parts) != 3:
            raise SystemExit(f"bad local stage {spec!r}; syntax: local:START:END")
        url = None
        raw_start, raw_end = parts[1], parts[2]
    else:
        parts = [x.strip() for x in spec.split(",", 2)]
        if len(parts) != 3:
            raise SystemExit("remote stage syntax: URL,START,END or local:START:END")
        url, raw_start, raw_end = parts
        parsed_url = urlparse(url)
        if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
            raise SystemExit(f"bad remote stage URL {url!r}; expected http(s)://HOST:PORT")
    try:
        start, end = int(raw_start), int(raw_end)
    except ValueError:
        raise SystemExit(f"bad stage {spec!r}: START and END must be integers") from None
    if not 0 <= start < end <= layers:
        raise SystemExit(f"bad stage {spec!r}: need 0 <= START < END <= {layers}")
    return url, start, end


def parse_stage_specs(args: argparse.Namespace, runtime: Any, sd: dict[str, Any]) -> list[Any]:
    """Build stage clients, in pipeline order, from the repeated ``--stage`` specs.

    Spec forms are ``local:START:END`` (weights loaded in this process from
    *sd*) and ``URL,START,END`` (a remote worker). With no specs, a single
    local stage covers every layer. A warning goes to stderr when the ranges
    do not tile ``[0, layers)`` exactly.

    Raises:
        SystemExit: on any malformed spec. Every spec is checked before any
            local stage loads its weights.
    """
    layers = int(sd["cfg"]["layers"])
    specs = args.stage or [f"local:0:{layers}"]
    parsed = [_parse_stage_spec(spec, layers) for spec in specs]
    contiguous = all(prev[2] == cur[1] for prev, cur in zip(parsed, parsed[1:]))
    if not contiguous or parsed[0][1] != 0 or parsed[-1][2] != layers:
        ranges = [[start, end] for _, start, end in parsed]
        sys.stderr.write(
            f"warning: stage ranges {ranges} do not tile layers [0, {layers}) exactly; "
            "the pipeline will not run every layer exactly once\n"
        )
    out = []
    for idx, (url, start, end) in enumerate(parsed):
        if url is None:
            stage = StageModule(runtime, sd, start, end, args.device, args.attn_backend)
            out.append(LocalStageClient(stage, f"local-{start}-{end}"))
        else:
            out.append(RemoteStageClient(url, args.token, f"remote-{idx}-{start}-{end}", args.insecure))
    return out
def restore_heads(runtime: Any, sd: dict[str, Any], device: str):
    """Rebuild the coordinator-side embedding, final LayerNorm and AR head.

    Weights come from ``sd["core"]`` (``emb.weight``, ``ln.weight``,
    ``ln.bias``) and ``sd["ar"]``. When ``sd["tie_weights"]`` is truthy, the
    AR head is constructed with the embedding weight.

    Returns ``(emb, ln, ar_h)``, all on *device* and in eval mode.
    """
    torch = torch_io()
    cfg = sd["cfg"]
    d_model = int(cfg["d"])
    tied = bool(sd.get("tie_weights", False))
    emb = torch.nn.Embedding(runtime.VOCAB, d_model).to(device)
    ln = torch.nn.LayerNorm(d_model).to(device)
    core = runtime._strip_orig_mod_prefix(sd["core"])
    emb.weight.data.copy_(core["emb.weight"].to(device))
    ln.load_state_dict({"weight": core["ln.weight"], "bias": core["ln.bias"]})
    shared_weight = emb.weight if tied else None
    ar_h = runtime.ARHead(d_model, tie_weights=tied, embedding_weight=shared_weight).to(device)
    ar_h.load_state_dict(sd["ar"])
    for part in (emb, ln, ar_h):
        part.eval()
    return emb, ln, ar_h
def run_stage_pipeline(
    stages: list[Any],
    hidden: Any,
    args: argparse.Namespace,
    use_cache: bool = False,
    session_id: str = "",
    total_seq_len: int = 0,
    reset_cache: bool = False,
) -> tuple[Any, list[dict[str, Any]]]:
    """Feed *hidden* through *stages* in order, each stage's output feeding the next.

    Uses ``run_cached`` (with the session arguments) when *use_cache* is set,
    and ``run`` otherwise. Mode and SAT block size come from ``args``.

    Returns the final hidden states and the per-stage stat dicts, in order.
    """
    current = hidden
    collected: list[dict[str, Any]] = []
    for stage in stages:
        if not use_cache:
            current, stat = stage.run(current, args.mode, args.sat_block)
        else:
            current, stat = stage.run_cached(
                current,
                args.mode,
                args.sat_block,
                session_id,
                int(total_seq_len),
                bool(reset_cache),
            )
        collected.append(stat)
    return current, collected
def sample_next(runtime: Any, ar_h: Any, hidden: Any, ids: Any, args: argparse.Namespace) -> Any:
    """Sample the next token from the last position of *hidden*.

    The steps are:
    1. Take AR-head logits at the last position.
    2. Apply ``runtime._apply_penalties`` over *ids*, using the penalty
       settings in *args*.
    3. Call ``runtime._sample`` with the sampling settings in *args*.

    Returns whatever ``runtime._sample`` returns. Callers concatenate it onto
    *ids* along dim 1.
    """
    last_logits = ar_h(hidden)[:, -1]
    history = ids.to(last_logits.device)
    penalized = runtime._apply_penalties(
        last_logits,
        history,
        args.penalty_last_n,
        args.repetition_penalty,
        args.presence_penalty,
        args.frequency_penalty,
    )
    return runtime._sample(
        penalized,
        args.temperature,
        args.top_k,
        args.top_p,
        args.min_p,
        args.greedy,
    )
def cmd_infer(args: argparse.Namespace) -> None:
    """Generate a completion for ``args.prompt`` through the stage pipeline and print it.

    The coordinator loads the checkpoint, builds the stage clients from the
    ``--stage`` specs, and keeps the embedding, final LayerNorm and AR head
    locally. Hidden states move between stages as CPU tensors.

    Cache modes:
    - ``--cache-mode kv`` prefills the prompt once with ``reset_cache=True``,
      then sends only the newest token each step.
    - ``--cache-mode full`` re-runs the whole sequence through every stage
      each step.

    Generation stops after ``--max-new`` tokens or when ``runtime.EOS`` is
    sampled. Output is the prompt+completion text followed by a JSON summary,
    or with ``--json`` a single JSON object that includes the text.

    Raises:
        SystemExit: if the checkpoint config enables ``anchor_memory``.
    """
    torch = torch_io()
    runtime = load_agillm41(args.agillm35_path)
    sd = load_ckpt(runtime, args.ckpt)
    args.device = resolve_device(args.device)
    if bool(sd["cfg"].get("anchor_memory", False)):
        raise SystemExit("distributed phase-1 does not support anchor_memory yet")
    stages = parse_stage_specs(args, runtime, sd)
    emb, ln, ar_h = restore_heads(runtime, sd, args.device)
    # Local stages and the heads have already loaded what they need; free the
    # checkpoint dict.
    del sd
    gc.collect()
    prompt_tokens = runtime.tok.encode(args.prompt)
    # An empty prompt is seeded with EOS so the sequence is never empty.
    if not prompt_tokens:
        prompt_tokens = [runtime.EOS]
    ids = torch.tensor([prompt_tokens], dtype=torch.long)  # single-row batch
    prompt_len = ids.size(1)
    stage_stats: list[dict[str, Any]] = []
    # One id per run keys every stage's KV cache; --session-id can pin it.
    session_id = args.session_id or f"agillm41-{uuid.uuid4().hex}"
    eos_id = getattr(runtime, "EOS", None)
    generated_tokens = 0
    start = time.time()
    with torch.no_grad():
        if args.cache_mode == "kv":
            # Prefill: run the whole prompt once, resetting each stage's cache
            # for this session.
            hidden = emb(ids.to(args.device)).detach().cpu()
            hidden, stats = run_stage_pipeline(
                stages,
                hidden,
                args,
                use_cache=True,
                session_id=session_id,
                total_seq_len=int(ids.size(1)),
                reset_cache=True,
            )
            stage_stats.extend(stats)
            for step in range(int(args.max_new)):
                h = ln(hidden.to(args.device))
                nxt = sample_next(runtime, ar_h, h, ids, args)
                ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
                generated_tokens += 1
                # Only batch element 0 is checked (ids is built as one row).
                if eos_id is not None and int(nxt.reshape(-1)[0].item()) == int(eos_id):
                    break
                # The last token was just sampled; a further pipeline pass
                # would produce output that is never used.
                if step + 1 >= int(args.max_new):
                    break
                # Decode step: send only the new token; total_seq_len now
                # includes it.
                hidden = emb(nxt.to(args.device)).detach().cpu()
                hidden, stats = run_stage_pipeline(
                    stages,
                    hidden,
                    args,
                    use_cache=True,
                    session_id=session_id,
                    total_seq_len=int(ids.size(1)),
                    reset_cache=False,
                )
                stage_stats.extend(stats)
            # NOTE(review): stage caches for session_id are never explicitly
            # cleared here. Remote workers hold them until StageModule's LRU
            # pruning evicts them.
        else:
            # Full mode: re-embed and re-run the entire sequence every step.
            for _ in range(int(args.max_new)):
                hidden = emb(ids.to(args.device)).detach().cpu()
                hidden, stats = run_stage_pipeline(stages, hidden, args, use_cache=False)
                stage_stats.extend(stats)
                h = ln(hidden.to(args.device))
                nxt = sample_next(runtime, ar_h, h, ids, args)
                ids = torch.cat([ids, nxt.detach().cpu()], dim=1)
                generated_tokens += 1
                if eos_id is not None and int(nxt.reshape(-1)[0].item()) == int(eos_id):
                    break
    elapsed = time.time() - start
    all_ids = ids[0].tolist()
    prompt = runtime.tok.decode(all_ids[:prompt_len], skip_special_tokens=True)
    completion = runtime.tok.decode(all_ids[prompt_len:], skip_special_tokens=True)
    # Aggregate calls and time per stage name. Local stats have no wall_sec,
    # so their compute time stands in for it.
    by_stage: dict[str, dict[str, Any]] = {}
    for stat in stage_stats:
        item = by_stage.setdefault(stat["name"], {"calls": 0, "sec": 0.0, "wall_sec": 0.0, "layers": stat.get("layers")})
        item["calls"] += 1
        item["sec"] += float(stat.get("sec", 0.0))
        item["wall_sec"] += float(stat.get("wall_sec", stat.get("sec", 0.0)))
    result = {
        "event": "distributed_infer_done",
        "mode": args.mode,
        "cache_mode": args.cache_mode,
        "session_id": session_id if args.cache_mode == "kv" else None,
        "tokens": generated_tokens,
        "elapsed_sec": round(elapsed, 3),
        "tok_per_sec": round(generated_tokens / max(elapsed, 1e-9), 3),
        "stages": by_stage,
    }
    if args.json:
        result["prompt"] = prompt
        result["completion"] = completion
        print(json.dumps(result, indent=2))
    else:
        print(prompt + completion)
        print(json.dumps(result, indent=2))
def cmd_plan(args: argparse.Namespace) -> None:
    """Print, as JSON, how the checkpoint's layers split into ``--dblock-blocks`` ranges."""
    runtime = load_agillm41(args.agillm35_path)
    cfg = load_ckpt(runtime, args.ckpt)["cfg"]
    n_layers = int(cfg["layers"])
    plan = {
        "layers": n_layers,
        "dblock_blocks": args.dblock_blocks,
        "ranges": dblock_ranges(n_layers, args.dblock_blocks),
    }
    print(json.dumps(plan, indent=2))
def main() -> int:
    """Parse the command line and dispatch to the ``plan``, ``worker`` or ``infer`` subcommand.

    Returns 0 after the subcommand finishes. Subcommands report failure by
    raising, e.g. ``SystemExit``.
    """
    env = os.environ
    default_runtime = env.get("AGILLM41_RUNTIME") or env.get("AGILLM35_RUNTIME", "./agillm41.py")
    default_token = env.get("AGILLM41_INFER_TOKEN") or env.get("AGILLM35_INFER_TOKEN", "")

    parser = argparse.ArgumentParser(description="AGILLM4.1 distributed transformer/MoE/DiffusionBlock inference")
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    # Options shared by every subcommand.
    shared = argparse.ArgumentParser(add_help=False)
    shared.add_argument(
        "--agillm35-path",
        "--agillm41-path",
        dest="agillm35_path",
        default=default_runtime,
    )
    shared.add_argument("--ckpt", required=True)
    shared.add_argument("--attn-backend", choices=["manual", "sdpa"], default="manual")
    shared.add_argument("--device", default="auto")

    plan = subcommands.add_parser("plan", parents=[shared])
    plan.add_argument("--dblock-blocks", type=int, default=8)
    plan.set_defaults(func=cmd_plan)

    worker = subcommands.add_parser("worker", parents=[shared])
    worker.add_argument("--start-layer", type=int, required=True)
    worker.add_argument("--end-layer", type=int, required=True)
    worker.add_argument("--host", default="127.0.0.1")
    worker.add_argument("--port", type=int, default=9100)
    worker.add_argument("--token", default=default_token)
    worker.add_argument("--max-payload-bytes", type=int, default=2_000_000_000)
    worker.add_argument("--tls-cert")
    worker.add_argument("--tls-key")
    worker.set_defaults(func=cmd_worker)

    infer = subcommands.add_parser("infer", parents=[shared])
    infer.add_argument("--prompt", required=True)
    infer.add_argument("--max-new", type=int, default=16)
    infer.add_argument("--mode", choices=["ar"], default="ar")
    infer.add_argument("--cache-mode", choices=["kv", "full"], default="kv")
    infer.add_argument("--session-id", default="")
    infer.add_argument("--stage", action="append", help="local:START:END or URL,START,END. Repeat in pipeline order.")
    infer.add_argument("--token", default=default_token)
    infer.add_argument("--insecure", action="store_true")
    infer.add_argument("--temperature", type=float, default=0.7)
    infer.add_argument("--greedy", action="store_true")
    # Remaining numeric sampling/penalty knobs, kept in the original order.
    for flag, value_type, default in (
        ("--top-k", int, 0),
        ("--top-p", float, 0.9),
        ("--min-p", float, 0.0),
        ("--repetition-penalty", float, 1.3),
        ("--presence-penalty", float, 0.0),
        ("--frequency-penalty", float, 0.3),
        ("--penalty-last-n", int, 128),
        ("--sat-block", int, 8),
    ):
        infer.add_argument(flag, type=value_type, default=default)
    infer.add_argument("--json", action="store_true")
    infer.set_defaults(func=cmd_infer)

    ns = parser.parse_args()
    ns.func(ns)
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code as the exit status.
    sys.exit(main())
Xet Storage Details
- Size:
- 28.1 kB
- Xet hash:
- 9574511387e95489de56291e6cc89c122407047fdd067e161e41f9a90884c096
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.