Buckets:
| #!/usr/bin/env python3 | |
| """Batch planner and package exporter for AGILLM4.1. | |
| Loads the checkpoint once, plans assignments for all workers (respecting stickiness | |
| and updating reservation state), and writes all packages and a single shared_frozen.pt. | |
| Supports persistent daemon mode to keep checkpoint loaded in memory. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import importlib.util | |
| import json | |
| import os | |
| import re | |
| import subprocess | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| try: | |
| import fcntl | |
| except Exception: | |
| fcntl = None | |
# Model geometry, overridable through the environment.
TOTAL_LAYERS = int(os.environ.get("AGILLM41_TOTAL_LAYERS", "28"))
# Clamped to >= 1: a zero override would otherwise raise ZeroDivisionError on
# the next line, and a negative one would produce nonsensical layer windows.
LAYERS_PER_BLOCK = max(1, int(os.environ.get("AGILLM41_LAYERS_PER_BLOCK", "7")))
NUM_BLOCKS = max(1, TOTAL_LAYERS // LAYERS_PER_BLOCK)
# Master training log whose tail is parsed for [dblock] / async-merge events.
MASTER_LOG = Path(os.environ.get("AGILLM41_MASTER_LOG", "/workspace/agillm41_master_train.log"))
# Per-worker lease records (failure notes, timestamps, throughput).
LEASE_STATE = Path(os.environ.get("AGILLM41_LEASE_STATE", "/workspace/agillm41_lease_state.json"))
# Planner state (reservations and history for the current round).
PLAN_STATE = Path(os.environ.get("AGILLM41_LEASE_PLAN_STATE", "/workspace/agillm41_lease_plan_state.json"))
# Helper script run as a subprocess to decide a worker's capacity.
DECIDE = Path(os.environ.get("AGILLM41_LEASE_DECIDE", "/workspace/agillm41_lease_decide.py"))
# Default number of trailing bytes of the master log that read_tail() returns.
TAIL_BYTES = int(os.environ.get("AGILLM41_PLAN_TAIL_BYTES", str(4 * 1024 * 1024)))
# Canonical tokenizer metadata keys. These are the only keys copied from a
# checkpoint or sidecar file, and the keys hashed by tokenizer_cache_tag().
TOKENIZER_META_KEYS = (
    "tokenizer_payload_schema",
    "tokenizer_id",
    "tokenizer_json",
    "tokenizer_bundle",
    "tokenizer_special",
    "transformers_version",
    "tokenizers_version",
)
# Alternate spellings (alias -> canonical key) accepted on input and folded
# onto the canonical names by normalize_tokenizer_metadata().
TOKENIZER_META_ALIASES = {
    "payload_schema": "tokenizer_payload_schema",
    "schema": "tokenizer_payload_schema",
    "schema_version": "tokenizer_payload_schema",
    "special_tokens": "tokenizer_special",
}
def normalize_tokenizer_metadata(meta: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *meta* with alias keys folded onto canonical keys.

    An alias value is copied to its canonical key only when the canonical key
    is absent and the alias value is not ``None``; the alias key itself is
    kept. When a ``tokenizer_json`` payload is present but no payload schema
    is set, the schema defaults to ``2``.
    """
    normalized = dict(meta)
    for alias, canonical in TOKENIZER_META_ALIASES.items():
        if canonical in normalized:
            continue
        if alias in normalized and normalized[alias] is not None:
            normalized[canonical] = normalized[alias]
    has_payload = bool(normalized.get("tokenizer_json"))
    if has_payload and not normalized.get("tokenizer_payload_schema"):
        normalized["tokenizer_payload_schema"] = 2
    return normalized
def read_tokenizer_sidecar(ckpt: Path) -> dict[str, Any]:
    """Load tokenizer metadata from a JSON sidecar next to the checkpoint.

    Two candidates are tried in order: ``<ckpt>.tokenizer.json`` and
    ``latest.tokenizer.json`` in the checkpoint's directory. The first one
    that exists and parses to a JSON object wins. Only canonical keys (and
    aliases, mapped to their canonical key) with non-``None`` values are kept.

    A canonical key takes precedence over any of its aliases, and among
    aliases the first one listed wins, matching the precedence applied by
    ``normalize_tokenizer_metadata``.

    Returns:
        The normalized metadata, or ``{}`` when no usable sidecar is found.
    """
    candidates = (
        Path(str(ckpt) + ".tokenizer.json"),
        ckpt.parent / "latest.tokenizer.json",
    )
    for candidate in candidates:
        try:
            if not candidate.exists():
                continue
            data = json.loads(candidate.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Best effort: an unreadable or malformed sidecar is skipped.
            # ValueError covers both JSON and UTF-8 decoding errors.
            continue
        if not isinstance(data, dict):
            continue
        picked = {key: data[key] for key in TOKENIZER_META_KEYS if data.get(key) is not None}
        for alias, canonical in TOKENIZER_META_ALIASES.items():
            if canonical not in picked and data.get(alias) is not None:
                picked[canonical] = data[alias]
        return normalize_tokenizer_metadata(picked)
    return {}
def checkpoint_tokenizer_metadata(ck: dict[str, Any], ckpt: Path) -> dict[str, Any]:
    """Merge sidecar tokenizer metadata with values stored in the checkpoint.

    Sidecar values are the base; any canonical key present in *ck* with a
    non-``None`` value overrides it. The merged result is normalized.
    """
    merged = read_tokenizer_sidecar(ckpt)
    from_checkpoint = {
        name: ck[name]
        for name in TOKENIZER_META_KEYS
        if name in ck and ck[name] is not None
    }
    merged.update(from_checkpoint)
    return normalize_tokenizer_metadata(merged)
def tokenizer_summary(meta: dict[str, Any]) -> dict[str, Any]:
    """Return a compact description of *meta* without the large payloads.

    The JSON and bundle payloads are reported only as presence flags.
    """
    lookup = meta.get
    summary: dict[str, Any] = {}
    summary["payload_schema"] = lookup("tokenizer_payload_schema")
    summary["id"] = lookup("tokenizer_id")
    summary["has_tokenizer_json"] = True if lookup("tokenizer_json") else False
    summary["has_tokenizer_bundle"] = True if lookup("tokenizer_bundle") else False
    summary["special"] = lookup("tokenizer_special")
    return summary
def tokenizer_cache_tag(meta: dict[str, Any]) -> str:
    """Return a 12-hex-character fingerprint of the tokenizer metadata.

    String values, and dict/list values (serialized as compact, key-sorted
    JSON), are replaced by their length and SHA-256 so large payloads do not
    need to be embedded; other values are used as-is. Only canonical keys
    present in *meta* contribute.
    """

    def fingerprint(text: str) -> dict[str, Any]:
        return {"len": len(text), "sha256": hashlib.sha256(text.encode("utf-8")).hexdigest()}

    payload: dict[str, Any] = {}
    for key in TOKENIZER_META_KEYS:
        if key in meta:
            value = meta[key]
            if isinstance(value, str):
                payload[key] = fingerprint(value)
            elif isinstance(value, (dict, list)):
                payload[key] = fingerprint(
                    json.dumps(value, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
                )
            else:
                payload[key] = value
    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]
# Matches a "[dblock] step=<n> block=<n> ... counts=[...] ema=[...]" fragment
# of the master log. The bracketed lists are captured as raw text and split
# later by parse_nums().
DBLOCK_RE = re.compile(
    r"\[dblock\]\s+step=(?P<step>\d+)\s+block=(?P<block>\d+)\s+.*?"
    r"counts=\[(?P<counts>[^\]]+)\]\s+ema=\[(?P<ema>[^\]]+)\]"
)
def read_json(path: Path, default: Any) -> Any:
    """Parse *path* as JSON, returning *default* if reading or parsing fails."""
    try:
        raw = path.read_text()
        parsed = json.loads(raw)
    except Exception:
        return default
    return parsed
def write_json_atomic(path: Path, data: Any) -> None:
    """Write *data* as indented, key-sorted JSON to *path* via a temp file.

    The payload is written to ``<path>.tmp`` and then renamed over *path*, so
    readers never see a partially written file.
    """
    staging = path.with_suffix(f"{path.suffix}.tmp")
    serialized = json.dumps(data, indent=2, sort_keys=True)
    staging.write_text(serialized)
    staging.replace(path)
def read_tail(path: Path, limit: int = TAIL_BYTES) -> str:
    """Return up to the last *limit* bytes of *path*, decoded as UTF-8.

    Undecodable bytes are dropped. Any failure yields an empty string.
    """
    try:
        with path.open("rb") as handle:
            handle.seek(0, os.SEEK_END)
            end = handle.tell()
            begin = max(0, end - limit)
            handle.seek(begin, os.SEEK_SET)
            raw = handle.read()
            return raw.decode("utf-8", "ignore")
    except Exception:
        return ""
def parse_nums(text: str, as_float: bool = False) -> list[float] | list[int]:
    """Parse a comma-separated list of numbers, skipping unparsable items.

    With ``as_float`` the items are returned as floats; otherwise each item is
    parsed as a float and truncated to an int.
    """
    if as_float:
        convert = float
    else:
        def convert(token: str) -> int:
            return int(float(token))

    numbers: list[float] = []
    for token in (piece.strip() for piece in text.split(",")):
        if token:
            try:
                numbers.append(convert(token))
            except Exception:
                pass
    return numbers
def latest_dblock_stats(text: str) -> dict[str, Any]:
    """Extract the most recent ``[dblock]`` counts/EMA record from log text.

    The search is restricted to the text after the last
    ``[dblock] DiffusionBlocks mode`` marker when that segment contains a
    record; otherwise the whole text is searched. Both lists are truncated or
    padded to ``NUM_BLOCKS`` entries: counts are padded with 0, EMA values
    with the mean of the values so far (1.0 when empty). With no record at
    all, zero counts and unit EMA are returned with ``source="default"``.
    """
    marker = text.rfind("[dblock] DiffusionBlocks mode")
    scoped = text[marker:] if marker >= 0 else text

    latest = None
    for latest in DBLOCK_RE.finditer(scoped):
        pass
    if latest is None:
        for latest in DBLOCK_RE.finditer(text):
            pass
    if latest is None:
        return {
            "step": 0,
            "counts": [0] * NUM_BLOCKS,
            "ema": [1.0] * NUM_BLOCKS,
            "source": "default",
        }

    counts = list(parse_nums(latest.group("counts")))[:NUM_BLOCKS]
    ema = list(parse_nums(latest.group("ema"), as_float=True))[:NUM_BLOCKS]
    while len(counts) < NUM_BLOCKS:
        counts.append(0)
    while len(ema) < NUM_BLOCKS:
        ema.append(sum(ema) / len(ema) if ema else 1.0)
    return {"step": int(latest.group("step")), "counts": counts, "ema": ema, "source": "log"}
def async_coverage(text: str) -> dict[str, Any]:
    """Summarize ``async_side_update_applied`` events found in log text.

    Each matching line is expected to carry a JSON object starting at its
    first ``{``; lines whose JSON cannot be parsed are skipped. Returns
    per-layer merge counts and last-merge steps, per-block event counts, the
    highest step seen, and the last 80 events.
    """

    def as_int(raw: Any) -> int:
        # Missing, falsy or unparsable values count as 0.
        try:
            return int(raw or 0)
        except Exception:
            return 0

    layer_counts = [0] * TOTAL_LAYERS
    layer_last_step = [0] * TOTAL_LAYERS
    block_event_counts = [0] * NUM_BLOCKS
    events: list[dict[str, Any]] = []
    current_step = 0

    for line in text.splitlines():
        if not ("async_side_update_applied" in line and "{" in line):
            continue
        try:
            record = json.loads(line[line.index("{"):])
        except Exception:
            continue

        step = as_int(record.get("step"))
        if step > current_step:
            current_step = step

        block_id = as_int(record.get("block_id"))
        if 0 <= block_id < NUM_BLOCKS:
            block_event_counts[block_id] += 1

        clean_layers = []
        for raw_layer in record.get("layers") or []:
            try:
                layer_index = int(raw_layer)
            except Exception:
                continue
            if not 0 <= layer_index < TOTAL_LAYERS:
                continue
            layer_counts[layer_index] += 1
            layer_last_step[layer_index] = max(layer_last_step[layer_index], step)
            clean_layers.append(layer_index)

        event = {
            "step": step,
            "worker_id": record.get("worker_id"),
            "block_id": block_id,
            "layers": clean_layers,
            "tok_per_sec": record.get("tok_per_sec"),
        }
        events.append(event)

    return {
        "layer_counts": layer_counts,
        "layer_last_step": layer_last_step,
        "block_event_counts": block_event_counts,
        "current_step": current_step,
        "events": events[-80:],
    }
def norm_high(values: list[float] | list[int]) -> list[float]:
    """Min-max scale *values* to [0, 1], with the largest value mapping to 1.

    When all values are equal, every entry becomes 0.5; an empty input gives
    an empty list.
    """
    floats = [float(item) for item in values]
    if not floats:
        return []
    low, high = min(floats), max(floats)
    if high <= low:
        return [0.5] * len(floats)
    return [(item - low) / (high - low) for item in floats]
def norm_low(values: list[float] | list[int]) -> list[float]:
    """Min-max scale *values* to [0, 1], with the smallest value mapping to 1.

    When all values are equal, every entry becomes 0.5; an empty input gives
    an empty list.
    """
    floats = [float(item) for item in values]
    if not floats:
        return []
    low, high = min(floats), max(floats)
    if high <= low:
        return [0.5] * len(floats)
    return [(high - item) / (high - low) for item in floats]
def layer_window(block_id: int, offset: int, max_layers: int) -> list[int]:
    """Return the global layer indices of a window inside one block.

    The window starts *offset* layers into the block and wraps around within
    the block. It spans the whole block when ``max_layers`` is at least
    ``LAYERS_PER_BLOCK``, otherwise ``max(1, max_layers)`` layers.
    """
    base = block_id * LAYERS_PER_BLOCK
    if max_layers >= LAYERS_PER_BLOCK:
        span = LAYERS_PER_BLOCK
    else:
        span = max(1, max_layers)
    window = []
    for step in range(span):
        window.append(base + (offset + step) % LAYERS_PER_BLOCK)
    return window
def stable_jitter(*parts: Any) -> float:
    """Return a deterministic value in [0, 0.01] derived from *parts*.

    The parts are stringified, joined with ``|`` and hashed with a 4-byte
    BLAKE2b digest, so the same arguments always give the same jitter.
    """
    key = "|".join(map(str, parts)).encode("utf-8", "ignore")
    digest_bytes = hashlib.blake2b(key, digest_size=4).digest()
    as_integer = int.from_bytes(digest_bytes, "big")
    return as_integer / 0xFFFFFFFF * 0.01
def decide_capacity(worker: str) -> tuple[int, int, int]:
    """Ask the ``DECIDE`` script for a worker's capacity.

    The script is run with ``python3`` and a 45-second timeout; the first
    three whitespace-separated fields of its stdout are parsed as integers and
    each is clamped to at least 1. They are returned in output order.

    Returns:
        The three parsed values, or ``(1, 128, 1)`` when the script fails,
        times out, or prints fewer than three parsable fields.
    """
    fallback = (1, 128, 1)
    command = ["python3", str(DECIDE), worker]
    try:
        proc = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=45,
            check=False,
        )
        tokens = proc.stdout.strip().split()
        if len(tokens) < 3:
            return fallback
        first = max(1, int(tokens[0]))
        second = max(1, int(tokens[1]))
        third = max(1, int(tokens[2]))
        return first, second, third
    except Exception:
        return fallback
def disabled_reason(worker: str) -> str | None:
    """Return why *worker* must be skipped, or ``None`` if it is usable.

    Reads the worker's record from ``LEASE_STATE`` and inspects its
    ``failure`` / ``failure_seen`` notes. A worker is skipped when the notes
    mention "disabled", or when its name ends with "igpu" and the notes
    mention "directml", "backward" or "runtimeerror". The reason is the notes
    text truncated to 200 characters.
    """
    state = read_json(LEASE_STATE, {})
    record = state.get(worker, {}) if isinstance(state, dict) else {}
    if not isinstance(record, dict):
        return None

    notes = " ".join(str(record.get(field) or "") for field in ("failure", "failure_seen"))
    lowered = notes.lower()
    excerpt = notes.strip()[:200]

    if "disabled" in lowered:
        return excerpt or "disabled"
    igpu_markers = ("directml", "backward", "runtimeerror")
    if worker.endswith("igpu") and any(marker in lowered for marker in igpu_markers):
        return excerpt or "igpu backward failure"
    return None
def build_scores(text: str) -> dict[str, Any]:
    """Compute per-block and per-layer need scores from master-log text.

    Block score = 0.45 * EMA (high is needy) + 0.25 * step count (low is
    needy) + 0.20 * merged side updates (low is needy) + 0.10 * staleness
    (high is needy), each min-max normalized across blocks.

    Layer score = 0.40 * merge count (low is needy) + 0.30 * staleness
    + 0.15 if the layer was never merged + 0.15 * its block's score.
    """
    dblock = latest_dblock_stats(text)
    coverage = async_coverage(text)
    counts = list(map(int, dblock["counts"]))
    ema = list(map(float, dblock["ema"]))

    def block_slice(per_layer: list[int], block: int) -> list[int]:
        return per_layer[block * LAYERS_PER_BLOCK : (block + 1) * LAYERS_PER_BLOCK]

    block_layer_merges = [
        sum(block_slice(coverage["layer_counts"], b)) for b in range(NUM_BLOCKS)
    ]
    block_last = [
        max(block_slice(coverage["layer_last_step"], b) or [0]) for b in range(NUM_BLOCKS)
    ]
    current_step = max(int(coverage["current_step"] or 0), int(dblock["step"] or 0))

    def staleness(last_steps: list[int]) -> list[int]:
        # Anything never updated (last step 0) is as stale as the whole run.
        return [max(0, current_step - seen) if seen else current_step for seen in last_steps]

    block_stale = staleness(block_last)
    ema_need = norm_high(ema)
    count_need = norm_low(counts)
    side_need = norm_low(block_layer_merges)
    stale_need = norm_high(block_stale)
    block_scores = [
        0.45 * ema_need[b] + 0.25 * count_need[b] + 0.20 * side_need[b] + 0.10 * stale_need[b]
        for b in range(NUM_BLOCKS)
    ]

    layer_count_need = norm_low(coverage["layer_counts"])
    layer_stale_need = norm_high(staleness(coverage["layer_last_step"]))
    layer_scores = []
    for layer in range(TOTAL_LAYERS):
        owner = min(NUM_BLOCKS - 1, layer // LAYERS_PER_BLOCK)
        never_merged = 0.0 if coverage["layer_counts"][layer] != 0 else 1.0
        layer_scores.append(
            0.40 * layer_count_need[layer]
            + 0.30 * layer_stale_need[layer]
            + 0.15 * never_merged
            + 0.15 * block_scores[owner]
        )

    return {
        "dblock": dblock,
        "coverage": coverage,
        "block_layer_merges": block_layer_merges,
        "block_last": block_last,
        "block_stale": block_stale,
        "block_scores": block_scores,
        "layer_scores": layer_scores,
        "current_step": current_step,
    }
def choose_assignment(
    worker: str,
    max_layers: int,
    scores: dict[str, Any],
    reservations: list[dict[str, Any]],
    round_id: str,
    history: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Pick the best (block, layer-window) assignment for *worker*.

    Every block is considered; when the worker cannot hold a whole block,
    every wrap-around offset inside the block is considered too. A candidate
    scores ``1.35 * block_score + mean(layer_scores over the window)
    - 2.75 * collisions + stickiness + jitter``, where collisions count
    window layers already present in *reservations*, stickiness favours the
    worker's most recent assignment in *history*, and the jitter is a
    deterministic tie-breaker derived from round, worker, block and offset.
    The first candidate with the highest score wins.

    Malformed reservation or history entries (non-dict items, a ``layers``
    field that is not a list/tuple) are ignored rather than raising.

    Raises:
        RuntimeError: if no candidate exists (``NUM_BLOCKS`` < 1).
    """
    reserved_layers: set[int] = set()
    for item in reservations:
        layers_field = item.get("layers") if isinstance(item, dict) else None
        if not isinstance(layers_field, (list, tuple)):
            continue
        for layer in layers_field:
            if isinstance(layer, int) or str(layer).isdigit():
                reserved_layers.add(int(layer))

    # Most recent assignment of this worker, used for stickiness.
    last_block_id = None
    last_layer_offset = None
    for item in reversed(history or []):
        if isinstance(item, dict) and item.get("worker") == worker:
            last_block_id = item.get("block_id")
            last_layer_offset = item.get("layer_offset")
            break

    block_scores = scores["block_scores"]
    layer_scores = scores["layer_scores"]
    # A worker that holds a full block has only one possible window per block.
    offsets = [0] if max_layers >= LAYERS_PER_BLOCK else list(range(LAYERS_PER_BLOCK))

    best: dict[str, Any] | None = None
    for block_id in range(NUM_BLOCKS):
        for offset in offsets:
            layers = layer_window(block_id, offset, max_layers)
            collisions = sum(1 for layer in layers if layer in reserved_layers)
            window_score = sum(layer_scores[layer] for layer in layers) / max(1, len(layers))
            boost = 0.0
            if last_block_id is not None and block_id == last_block_id:
                boost += 0.8
                # NOTE(review): the offset bonus is applied only when the block
                # also matches. The pasted source lost its indentation, so
                # confirm this nesting against the original file.
                if last_layer_offset is not None and offset == last_layer_offset:
                    boost += 0.2
            score = 1.35 * block_scores[block_id] + window_score - 2.75 * collisions + boost
            score += stable_jitter(round_id, worker, block_id, offset)
            if best is None or score > best["score"]:
                best = {
                    "worker": worker,
                    "block_id": block_id,
                    "layer_offset": offset,
                    "layers": layers,
                    "score": score,
                    "collisions": collisions,
                    "block_score": block_scores[block_id],
                    "window_score": window_score,
                }
    if best is None:
        # Explicit raise rather than assert: asserts vanish under ``python -O``.
        raise RuntimeError("no candidate assignment: NUM_BLOCKS must be at least 1")
    return best
def explain_reason(choice: dict[str, Any], scores: dict[str, Any]) -> dict[str, Any]:
    """Build a JSON-friendly explanation of why *choice* was selected.

    Combines the chosen block/layers/score with the dblock statistics and the
    merge counts and last-merge steps of the selected layers.
    """
    dblock_stats = scores["dblock"]
    coverage = scores["coverage"]

    report: dict[str, Any] = {
        "block_id": choice["block_id"],
        "layers": choice["layers"],
    }
    report["score"] = round(float(choice["score"]), 4)
    report["collisions"] = choice["collisions"]
    report["dblock_ema"] = dblock_stats["ema"]
    report["dblock_counts"] = dblock_stats["counts"]
    report["block_layer_merges"] = scores["block_layer_merges"]
    report["selected_layer_merge_counts"] = [
        coverage["layer_counts"][layer] for layer in choice["layers"]
    ]
    report["selected_layer_last_step"] = [
        coverage["layer_last_step"][layer] for layer in choice["layers"]
    ]
    report["current_step"] = scores["current_step"]
    return report
def dblock_layers(total_layers: int, blocks: int) -> list[list[int]]:
    """Split ``range(total_layers)`` into *blocks* contiguous groups.

    Each group holds ``max(1, total_layers // blocks)`` layers; the last group
    is then replaced so that it runs to ``total_layers`` and absorbs any
    remainder.
    """
    span = max(1, total_layers // blocks)
    groups: list[list[int]] = []
    for index in range(blocks):
        first = index * span
        groups.append(list(range(first, first + span)))
    groups[-1] = list(range((blocks - 1) * span, total_layers))
    return groups
def local_block_state(core_state: dict[str, Any], layers: list[int]) -> dict[str, Any]:
    """Extract the parameters of *layers*, renumbered from zero.

    Entries named ``blocks.<global>.<rest>`` are returned as
    ``blocks.<local>.<rest>``, where ``<local>`` is the position of the layer
    in *layers*. Values are detached and moved to the CPU. Non-string keys and
    keys of other layers are ignored.
    """
    remapped: dict[str, Any] = {}
    for local_index, global_index in enumerate(layers):
        source = f"blocks.{global_index}."
        target = f"blocks.{local_index}."
        for name in core_state:
            if not isinstance(name, str) or not name.startswith(source):
                continue
            remainder = name[len(source):]
            remapped[target + remainder] = core_state[name].detach().cpu()
    return remapped
def token_batches(vocab: int, steps: int, batch_size: int, block_size: int, seed: int) -> torch.Tensor:
    """Return reproducible random token ids of shape (steps, batch, block).

    Ids are drawn uniformly from ``[2, vocab)`` with a CPU generator seeded
    by *seed*, so ids 0 and 1 are never produced.
    """
    rng = torch.Generator(device="cpu")
    rng.manual_seed(int(seed))
    shape = (int(steps), int(batch_size), int(block_size))
    return torch.randint(2, int(vocab), shape, generator=rng, dtype=torch.long)
def load_runtime(path: str | Path):
    """Import the AGILLM4.1 runtime module from a file path.

    Side effects: sets ``TOKENIZER_ID`` in the environment if it is unset,
    prepends the file's directory to ``sys.path`` if it is missing, and
    registers the module in ``sys.modules`` as ``agillm41_export_runtime``
    before executing it.

    Raises:
        RuntimeError: if no import spec or loader can be created for the file.
    """
    module_name = "agillm41_export_runtime"
    path = Path(path).resolve()
    os.environ.setdefault("TOKENIZER_ID", "deepseek-ai/DeepSeek-V4-Pro")

    search_dir = str(path.parent)
    if search_dir not in sys.path:
        sys.path.insert(0, search_dir)

    spec = importlib.util.spec_from_file_location(module_name, path)
    loader = None if spec is None else spec.loader
    if loader is None:
        raise RuntimeError(f"cannot import AGILLM4.1 runtime from {path}")

    runtime_module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = runtime_module
    loader.exec_module(runtime_module)
    return runtime_module
def real_token_batches(runtime: Any, source: str, steps: int, batch_size: int, block_size: int, seed: int) -> torch.Tensor:
    """Draw real token ids from the runtime's stream, shaped (steps, batch, block).

    The sentinel source ``"__default__"`` is replaced by the runtime's
    ``DEFAULT_PRETRAIN_SOURCES``. Exactly ``steps * batch_size * block_size``
    tokens are pulled from ``runtime.token_stream``.
    """
    if source == "__default__":
        source = runtime.DEFAULT_PRETRAIN_SOURCES
    n_steps, n_batch, n_block = int(steps), int(batch_size), int(block_size)
    total = n_steps * n_batch * n_block
    stream = runtime.token_stream(source, total, seed=int(seed), streaming=True)
    tokens = [int(next(stream)) for _ in range(total)]
    flat = torch.tensor(tokens, dtype=torch.long)
    return flat.view(n_steps, n_batch, n_block)
class Logger:
    """Print messages to stdout and, optionally, keep them in memory."""

    def __init__(self, capture: bool = False):
        # When true, every logged message is also appended to ``buffer``.
        self.capture = capture
        self.buffer = []

    def log(self, msg: str):
        """Print *msg* (flushed immediately) and record it if capturing."""
        print(msg, flush=True)
        if not self.capture:
            return
        self.buffer.append(msg)
class CheckpointCache:
    """Keep the most recently loaded checkpoint in memory.

    A cached checkpoint is reused only when the requested path matches and the
    file's modification time and size are unchanged, so a checkpoint rewritten
    in place is reloaded instead of being served stale.
    """

    def __init__(self):
        self.ckpt_path = None  # path of the cached checkpoint, None when empty
        self.ck = None  # the loaded checkpoint object
        self._stamp = None  # (st_mtime_ns, st_size) of the file when loaded

    def get(self, path: Path) -> dict[str, Any]:
        """Return the checkpoint at *path*, loading it if it is not cached.

        Raises:
            OSError: if *path* cannot be stat'ed (for example it is missing).
            Exception: whatever ``torch.load`` raises for an unreadable file.
                After a failed load the cache is empty, so the next call
                reloads instead of returning ``None`` as a "hit".
        """
        info = os.stat(path)
        stamp = (info.st_mtime_ns, info.st_size)
        if self.ck is not None and self.ckpt_path == path and self._stamp == stamp:
            print(f"[CACHE] Checkpoint cache hit for {path}.", flush=True)
            return self.ck

        print(f"[CACHE] Loading checkpoint {path}...", flush=True)
        t0 = time.time()
        # Drop the old checkpoint before loading the new one so both need not
        # be resident at once. Clear the key as well: if the load below fails,
        # the cache must not claim to hold anything.
        self.ck = None
        self.ckpt_path = None
        self._stamp = None
        import gc
        gc.collect()
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted location.
        loaded = torch.load(path, map_location="cpu", weights_only=False)
        self.ck = loaded
        self.ckpt_path = path
        self._stamp = stamp
        print(f"[CACHE] Loaded checkpoint {path} in {time.time() - t0:.2f} seconds.", flush=True)
        return self.ck
def update_dynamic_blocks() -> None:
    """Re-tune ``LAYERS_PER_BLOCK`` / ``NUM_BLOCKS`` from active worker capacity.

    A worker is active when its lease record has a ``ts`` within the last
    1800 seconds and ``disabled_reason`` does not exclude it. The block size
    follows the weakest active worker's ``max_layers``:

    * all >= 14 -> 28 layers per block
    * all >= 7  -> 14
    * any <= 1  -> 2
    * any <= 2  -> 4
    * otherwise -> 7

    A size that does not divide ``TOTAL_LAYERS`` falls back to 7. The chosen
    block count is also written to ``/workspace/hot_config.json`` under
    ``dblock_blocks``. Nothing changes when the lease state is missing,
    unparsable, not a JSON object, or has no active worker. Malformed records
    are skipped, and an unparsable ``max_layers`` counts as 1.
    """
    global LAYERS_PER_BLOCK, NUM_BLOCKS

    def as_number(raw: Any, fallback: float) -> float:
        try:
            number = float(raw)
        except (TypeError, ValueError):
            return fallback
        return fallback if number != number else number  # NaN -> fallback

    state_data = read_json(LEASE_STATE, None)
    if not isinstance(state_data, dict):
        return

    now = time.time()
    capacities: list[float] = []
    for name, record in state_data.items():
        if not isinstance(record, dict):
            continue
        last_seen = as_number(record.get("ts", 0), 0.0)
        if now - last_seen >= 1800 or disabled_reason(name):
            continue
        capacities.append(as_number(record.get("max_layers", 1), 1.0))
    if not capacities:
        return

    weakest = min(capacities)
    if weakest >= 14:
        chosen_lpb = 28
    elif weakest >= 7:
        chosen_lpb = 14
    elif weakest <= 1:
        chosen_lpb = 2
    elif weakest <= 2:
        chosen_lpb = 4
    else:
        chosen_lpb = 7
    if TOTAL_LAYERS % chosen_lpb != 0:
        chosen_lpb = 7
    LAYERS_PER_BLOCK = chosen_lpb
    NUM_BLOCKS = max(1, TOTAL_LAYERS // LAYERS_PER_BLOCK)

    hot_config_path = Path("/workspace/hot_config.json")
    try:
        # Parse errors are deliberately not swallowed here: a corrupt hot
        # config is reported and left untouched rather than overwritten.
        cfg = json.loads(hot_config_path.read_text()) if hot_config_path.exists() else {}
        cfg["dblock_blocks"] = NUM_BLOCKS
        write_json_atomic(hot_config_path, cfg)
    except Exception as e:
        print(f"[dblock_autotuning] Error writing hot_config: {e}", file=sys.stderr)
    print(f"[dblock_autotuning] Dynamically selected LAYERS_PER_BLOCK = {LAYERS_PER_BLOCK}, NUM_BLOCKS = {NUM_BLOCKS} based on {len(capacities)} active workers.", file=sys.stderr)
def apply_explicit_dblock_blocks(value: Any) -> None:
    """Override the block layout with an explicitly requested block count.

    ``None``, ``""``, ``0`` and ``"0"`` mean "no override". Any other value is
    converted to an int (clamped to at least 1) and must divide
    ``TOTAL_LAYERS`` evenly.

    Raises:
        ValueError: if the block count does not divide ``TOTAL_LAYERS``.
    """
    global LAYERS_PER_BLOCK, NUM_BLOCKS
    if value in (None, "", 0, "0"):
        return
    blocks = max(1, int(value))
    layers_each, leftover = divmod(TOTAL_LAYERS, blocks)
    if leftover != 0:
        raise ValueError(f"--dblock-blocks={blocks} must evenly divide TOTAL_LAYERS={TOTAL_LAYERS}")
    NUM_BLOCKS = blocks
    LAYERS_PER_BLOCK = layers_each
def _estimate_tokps(worker: str, rec: Any) -> float:
    """Return the worker's tokens/sec: measured if usable, else a name-based guess."""
    measured = None
    if isinstance(rec, dict):
        measured = rec.get("tokps") or rec.get("decision_tokps")
    try:
        tokps = float(measured) if measured else 0.0
    except (TypeError, ValueError):
        tokps = 0.0
    if 0.0 < tokps < float("inf"):  # also rejects NaN
        return tokps
    # First matching name fragment wins; order is significant.
    fallbacks = (
        ("geth", 20.0),
        ("communist", 20.0),
        ("prime", 8.0),
        ("mcp", 8.0),
        ("laptop", 5.0),
    )
    for fragment, rate in fallbacks:
        if fragment in worker:
            return rate
    return 15.0


def _publish_shared_frozen(cached_path: Path, shared_path: Path, logger: Logger) -> None:
    """Make *shared_path* refer to the current cached shared payload.

    Nothing is done when *shared_path* already is the cached file (same
    inode). Otherwise the payload is hard-linked, or copied when linking
    fails, to a staging name and renamed over *shared_path*. A stale file left
    by an earlier export into the same directory is therefore replaced, and
    readers never see a missing or partial file.
    """
    try:
        if os.path.samefile(cached_path, shared_path):
            return
    except OSError:
        pass  # shared_path does not exist yet; publish it below

    staging = shared_path.with_name(f"{shared_path.name}.{os.getpid()}.tmp")
    try:
        staging.unlink()
    except FileNotFoundError:
        pass
    try:
        os.link(cached_path, staging)
        message = f"Linked shared_frozen.pt to {shared_path}"
    except Exception as e:
        import shutil
        shutil.copy(cached_path, staging)
        message = f"Copied shared_frozen.pt to {shared_path} (fallback due to {e})"
    staging.replace(shared_path)
    logger.log(message)


def run_planning_and_export(
    args_dict: dict[str, Any],
    ck_data: dict[str, Any] | None = None,
    logger: Logger | None = None,
) -> dict[str, Any]:
    """Plan one lease per worker and export the corresponding packages.

    Steps:
      1. Write (or reuse) the cached ``shared_frozen`` payload and publish it
         into the output directory as ``shared_frozen.pt``.
      2. Build block/layer scores from the tail of the master log.
      3. Under the plan-state lock, choose an assignment and a step count for
         every enabled worker and persist reservations and history.
      4. Save one package per planned worker, then ``manifest.json``.

    Args:
        args_dict: Request options. ``ckpt``, ``out_dir``, ``workers`` (comma
            separated) and ``round`` are required; ``runtime`` is required
            when ``source`` is set. Other keys tune the exported runtime args.
        ck_data: Already-loaded checkpoint; loaded from ``ckpt`` when omitted.
        logger: Message sink; a non-capturing ``Logger`` is used when omitted.

    Returns:
        ``{"status": "success", "wall_sec", "out_dir", "logs"}``.

    Raises:
        KeyError: if the checkpoint provides neither ``cfg`` nor a usable
            ``seed_meta`` preset, or a required option is missing.
        ValueError: if ``dblock_blocks`` does not divide ``TOTAL_LAYERS``.
    """
    update_dynamic_blocks()
    apply_explicit_dblock_blocks(args_dict.get("dblock_blocks"))
    if logger is None:
        logger = Logger()
    ckpt = Path(args_dict["ckpt"])
    out_dir = Path(args_dict["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    start = time.time()

    if ck_data is None:
        logger.log(f"Loading checkpoint {ckpt}...")
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted location.
        ck = torch.load(ckpt, map_location="cpu", weights_only=False)
    else:
        ck = ck_data

    if "cfg" in ck:
        cfg = dict(ck["cfg"])
    elif "seed_meta" in ck:
        cfg = dict(ck["seed_meta"].get("v4_preset") or ck["seed_meta"].get("v3_preset", {}))
        if not cfg:
            raise KeyError("Neither cfg nor seed_meta presets found in checkpoint")
    else:
        raise KeyError("Neither cfg nor seed_meta found in checkpoint")

    core = ck["core"]
    vocab = int(core["emb.weight"].shape[0])
    if os.environ.get("AGILLM_MOE_SHARED_EXPERTS"):
        cfg["moe_shared_experts"] = int(os.environ["AGILLM_MOE_SHARED_EXPERTS"])
    if os.environ.get("AGILLM_MOE_SHARED_MLP_MULT"):
        cfg["moe_shared_mlp_mult"] = int(os.environ["AGILLM_MOE_SHARED_MLP_MULT"])
    # Fail fast on a cfg without a usable layer count. The block layout that
    # used to be computed from it here was never read.
    _ = int(cfg["layers"])
    tie_weights = bool(ck.get("tie_weights", False))
    tok_meta = checkpoint_tokenizer_metadata(ck, ckpt)
    runtime = load_runtime(args_dict["runtime"]) if args_dict.get("source") else None

    # 1. Save shared_frozen.pt once (using local caching/hardlinking)
    # NOTE(review): the cache key is the checkpoint stem plus the tokenizer
    # tag only; a checkpoint rewritten under the same name reuses the old
    # cache file - confirm this is acceptable.
    tmp_cache_dir = Path("/tmp/agillm41_shared_cache")
    tmp_cache_dir.mkdir(parents=True, exist_ok=True)
    cached_shared_path = tmp_cache_dir / f"shared_frozen_{ckpt.stem}_{tokenizer_cache_tag(tok_meta)}.pt"
    if not cached_shared_path.exists():
        logger.log(f"Saving new shared_frozen cache to {cached_shared_path}...")
        shared = {
            "kind": "agillm4_bench_shared_v1",
            "cfg": cfg,
            "tie_weights": tie_weights,
            "tokenizer_id": tok_meta.get("tokenizer_id", ck.get("tokenizer_id")),
            **tok_meta,
            "vocab": vocab,
            "emb_weight": core["emb.weight"].detach().cpu().to(torch.float16),
            "ln_weight": core["ln.weight"].detach().cpu(),
            "ln_bias": core["ln.bias"].detach().cpu(),
        }
        if not tie_weights:
            shared["ar"] = {k: v.detach().cpu() for k, v in ck.get("ar", {}).items()}
            shared["sat"] = {k: v.detach().cpu() for k, v in ck.get("sat", {}).items()}
            shared["nat"] = {k: v.detach().cpu() for k, v in ck.get("nat", {}).items()}
        else:
            sat = ck.get("sat", {})
            if "gate.weight" in sat and "gate.bias" in sat:
                shared["sat_gate"] = {
                    "gate.weight": sat["gate.weight"].detach().cpu(),
                    "gate.bias": sat["gate.bias"].detach().cpu(),
                }
        tmp_shared = cached_shared_path.with_suffix(".pt.tmp")
        torch.save(shared, tmp_shared, _use_new_zipfile_serialization=False)
        tmp_shared.replace(cached_shared_path)
        logger.log(f"Saved cached shared_frozen.pt to {cached_shared_path}")

    # Now link or copy it to out_dir
    shared_path = out_dir / "shared_frozen.pt"
    _publish_shared_frozen(cached_shared_path, shared_path, logger)

    # 2. Build scores from master log once
    text = read_tail(MASTER_LOG)
    scores = build_scores(text)

    # 3. Plan leases for all workers under lock
    planned_workers = []
    worker_names = [w.strip() for w in args_dict["workers"].split(",") if w.strip()]
    target_duration = float(os.environ.get("AGILLM41_LEASE_TARGET_DURATION", "240"))
    PLAN_STATE.parent.mkdir(parents=True, exist_ok=True)
    lock_path = PLAN_STATE.with_suffix(PLAN_STATE.suffix + ".lock")
    with lock_path.open("a+") as lock_fh:
        if fcntl is not None:
            fcntl.flock(lock_fh, fcntl.LOCK_EX)
        try:
            state = read_json(PLAN_STATE, {})
            if not isinstance(state, dict) or state.get("round") != args_dict["round"]:
                state = {"round": args_dict["round"], "reservations": [], "history": []}
            reservations = state.get("reservations")
            if not isinstance(reservations, list):
                reservations = []
            history = state.get("history")
            if not isinstance(history, list):
                history = []
            state_data = read_json(LEASE_STATE, {})

            for worker in worker_names:
                reason = disabled_reason(worker)
                if reason:
                    logger.log(f"Worker {worker} skipped: {reason}")
                    continue
                batch, block_tokens, max_layers_decided = decide_capacity(worker)
                max_layers = max(1, min(LAYERS_PER_BLOCK, int(max_layers_decided)))
                choice = choose_assignment(worker, max_layers, scores, reservations, args_dict["round"], history)
                choice.update({"batch": batch, "block_tokens": block_tokens, "max_layers": max_layers, "at": time.time()})
                choice["reason"] = explain_reason(choice, scores)
                reservations.append(choice)
                history.append(choice)

                # Determine steps count so a lease lasts about target_duration
                # seconds at the worker's throughput, clamped to [5, 100].
                rec = state_data.get(worker, {}) if isinstance(state_data, dict) else {}
                tokps = _estimate_tokps(worker, rec)
                step_tokens = batch * block_tokens
                steps = max(5, min(100, int(round((tokps * target_duration) / step_tokens))))
                choice["steps"] = steps
                planned_workers.append(choice)

            state["reservations"] = reservations[-64:]
            state["history"] = history[-256:]
            state["last_scores"] = {
                "block_scores": [round(float(x), 4) for x in scores["block_scores"]],
                "dblock_ema": scores["dblock"]["ema"],
                "dblock_counts": scores["dblock"]["counts"],
                "block_layer_merges": scores["block_layer_merges"],
                "current_step": scores["current_step"],
            }
            write_json_atomic(PLAN_STATE, state)
        finally:
            # Release explicitly even when planning fails part-way.
            if fcntl is not None:
                fcntl.flock(lock_fh, fcntl.LOCK_UN)

    # 4. Generate and save packages for planned workers
    manifest = {
        "kind": "agillm4_dblock_bench_manifest_v1",
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "source_ckpt": str(ckpt),
        "source_step": int(ck.get("step", 0) or 0),
        "source_seen_tok": int(ck.get("seen_tok", 0) or 0),
        "cfg": cfg,
        "tie_weights": tie_weights,
        "tokenizer_id": tok_meta.get("tokenizer_id", ck.get("tokenizer_id")),
        "tokenizer": tokenizer_summary(tok_meta),
        "vocab": vocab,
        "dblock_blocks": NUM_BLOCKS,
        "steps": 1,
        "batch_size": 1,
        "block_size": 128,
        "shared": str(shared_path),
        "packages": [],
    }
    optional_keys = [
        "amp", "grad_checkpoint", "dblock_checkpoint_stride",
        "dblock_checkpoint_skip_tail", "dblock_activation_offload",
        "dblock_activation_offload_min_mb"
    ]
    for idx, choice in enumerate(planned_workers):
        worker_id = choice["worker"]
        block_id = choice["block_id"]
        layers = choice["layers"]
        steps = choice["steps"]
        batch = choice["batch"]
        block_size = choice["block_tokens"]

        # Larger default loss windows for workers whose name marks them as a GPU.
        if "v100" in worker_id or "gpu" in worker_id or any(x in worker_id.lower() for x in ("5090", "4090", "3090")):
            ar_loss_tokens = 256
            nat_loss_tokens = 256
        else:
            ar_loss_tokens = 64
            nat_loss_tokens = 64
        if args_dict.get("ar_loss_tokens") is not None:
            ar_loss_tokens = max(1, int(args_dict.get("ar_loss_tokens")))
        if args_dict.get("nat_loss_tokens") is not None:
            nat_loss_tokens = max(1, int(args_dict.get("nat_loss_tokens")))
        sat_loss_tokens = int(args_dict.get("sat_loss_tokens", 0) or 0)
        if sat_loss_tokens <= 0:
            sat_loss_tokens = min(int(ar_loss_tokens), int(nat_loss_tokens))

        runtime_args = {
            "attn_backend": args_dict.get("attn_backend", "manual"),
            "sublinear_window": int(args_dict.get("sublinear_window", 128)),
            "sublinear_stride": int(args_dict.get("sublinear_stride", 128)),
            "sublinear_max_anchors": int(args_dict.get("sublinear_max_anchors", 128)),
            "sublinear_chunk": int(args_dict.get("sublinear_chunk", 128)),
            "sublinear_sinks": int(args_dict.get("sublinear_sinks", 4)),
            "sublinear_recent_anchors": int(args_dict.get("sublinear_recent_anchors", 64)),
            "sublinear_pooled_landmarks": bool(args_dict.get("sublinear_pooled_landmarks", False)),
            "dblock_objective_mode": args_dict.get("objective_mode", "stochastic"),
            "dblock_ar_prob": float(args_dict.get("ar_prob", 0.45)),
            "dblock_sat_prob": float(args_dict.get("sat_prob", 0.40)),
            "dblock_nat_prob": float(args_dict.get("nat_prob", 0.15)),
            "dblock_ar_loss_tokens": int(ar_loss_tokens),
            "dblock_sat_loss_tokens": int(sat_loss_tokens),
            "dblock_nat_loss_tokens": int(nat_loss_tokens),
            "nat_mask_ratio": float(args_dict.get("nat_mask_ratio", 0.5)),
            "nat_max_tokens": int(block_size),
        }
        for ok in optional_keys:
            if args_dict.get(ok) is not None:
                runtime_args[ok] = args_dict[ok]

        # A distinct, reproducible seed per package.
        batch_seed = int(args_dict.get("seed", 20260602)) + idx * 1009
        if runtime is not None:
            ids = real_token_batches(runtime, args_dict["source"], steps, batch, block_size, batch_seed)
            data_mode = "real"
        else:
            ids = token_batches(vocab, steps, batch, block_size, batch_seed)
            data_mode = "synthetic"

        pkg = {
            "kind": "agillm4_dblock_bench_package_v1",
            "worker_id": worker_id,
            "block_id": int(block_id),
            "layers": layers,
            "cfg": cfg,
            "tie_weights": tie_weights,
            "tokenizer_id": tok_meta.get("tokenizer_id", ck.get("tokenizer_id")),
            "tokenizer": tokenizer_summary(tok_meta),
            "vocab": vocab,
            "dblock_blocks": NUM_BLOCKS,
            "steps": int(steps),
            "batch_size": int(batch),
            "block_size": int(block_size),
            "data_mode": data_mode,
            "source": args_dict.get("source", ""),
            "ids_batches": ids,
            "block_state": local_block_state(core, layers),
            "runtime_args": runtime_args,
        }
        out = out_dir / f"lease_{worker_id}_block{block_id}_agillm4bench.pt"
        tmp = out.with_suffix(".pt.tmp")
        torch.save(pkg, tmp, _use_new_zipfile_serialization=False)
        tmp.replace(out)
        package_bytes = out.stat().st_size
        manifest["packages"].append(
            {
                "worker_id": worker_id,
                "block_id": int(block_id),
                "layers": layers,
                "path": str(out),
                "bytes": package_bytes,
            }
        )
        logger.log(json.dumps({"event": "save_package", "worker_id": worker_id, "block_id": block_id, "layers": layers, "path": str(out), "bytes": package_bytes}))

    manifest["wall_sec"] = round(time.time() - start, 3)
    # Write via a temp file so readers never see a truncated manifest.
    manifest_path = out_dir / "manifest.json"
    manifest_tmp = out_dir / "manifest.json.tmp"
    manifest_tmp.write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    manifest_tmp.replace(manifest_path)
    logger.log(json.dumps({"event": "done", "out_dir": str(out_dir), "wall_sec": manifest["wall_sec"]}, indent=2))
    return {"status": "success", "wall_sec": manifest["wall_sec"], "out_dir": str(out_dir), "logs": logger.buffer}
def send_request_to_daemon(port: int, args_dict: dict[str, Any]) -> dict[str, Any]:
    """POST *args_dict* as JSON to the local planner daemon and return its reply.

    ``Path`` values are converted to strings before serialization. The request
    goes to ``http://127.0.0.1:<port>/plan`` with a 120-second timeout, and
    the response body is decoded as JSON.
    """
    import urllib.request
    import urllib.error

    payload = {
        key: (str(value) if isinstance(value, Path) else value)
        for key, value in args_dict.items()
    }
    body = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        f"http://127.0.0.1:{port}/plan",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=120) as response:
        raw = response.read()
        return json.loads(raw.decode("utf-8"))
def run_daemon(port: int):
    """Serve planning requests over HTTP on 127.0.0.1 until interrupted.

    A single ``CheckpointCache`` instance is shared by all requests handled by
    this process. Only ``POST /plan`` is served: the JSON request body is the
    planning argument dict (it must contain a ``"ckpt"`` path), and the JSON
    reply is whatever ``run_planning_and_export`` returns. Any failure while
    handling a request is printed as a traceback and reported to the client as
    HTTP 500 with ``{"status": "error", "error": <message>}``.

    Args:
        port: TCP port to bind on the loopback interface.
    """
    import traceback
    from http.server import BaseHTTPRequestHandler, HTTPServer

    cache = CheckpointCache()

    class DaemonHandler(BaseHTTPRequestHandler):
        def log_message(self, format, *args):
            # Suppress the default per-request access log.
            pass

        def _send_json(self, status: int, body: bytes) -> None:
            """Write a complete JSON response with the given status code."""
            self.send_response(status)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            self.wfile.write(body)

        def do_POST(self):
            if self.path != "/plan":
                self.send_response(404)
                self.end_headers()
                return
            try:
                content_length = int(self.headers.get("Content-Length", 0))
                post_data = self.rfile.read(content_length)
                args_dict = json.loads(post_data.decode("utf-8"))
                ckpt_path = Path(args_dict["ckpt"])
                ck_data = cache.get(ckpt_path)
                logger = Logger(capture=True)
                res = run_planning_and_export(args_dict, ck_data, logger)
                # Serialize before any header is written: a failure here must
                # still be reportable as a clean 500 instead of being appended
                # to an already-started 200 response.
                body = json.dumps(res).encode("utf-8")
            except Exception as e:
                traceback.print_exc()
                error_body = json.dumps({"status": "error", "error": str(e)}).encode("utf-8")
                self._send_json(500, error_body)
                return
            self._send_json(200, body)

    server = HTTPServer(("127.0.0.1", port), DaemonHandler)
    print(f"Planner daemon started on http://127.0.0.1:{port}/", flush=True)
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("Planner daemon stopping...", flush=True)
    finally:
        # Release the listening socket on every exit path.
        server.server_close()
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by daemon, client and local modes."""
    parser = argparse.ArgumentParser(description="Batch planner and package exporter")
    add = parser.add_argument

    # Daemon / client control.
    add("--daemon", action="store_true", help="Run as a persistent daemon")
    add("--port", type=int, default=18888, help="Daemon port")
    add("--no-daemon", action="store_true", help="Force local slow path")

    # Planning inputs; checked by main() because they are optional with --daemon.
    add("--ckpt", required=False)
    add("--out-dir", required=False)
    add("--round", required=False)
    add("--workers", required=False, help="comma-separated list of worker names to plan and export")
    add("--dblock-blocks", type=int, default=None)
    add("--seed", type=int, default=20260602)
    add("--runtime", default="agillm41.py")
    add("--source", default="")

    # Attention backend options.
    add("--attn-backend", choices=["manual", "sdpa", "sublinear"], default="manual")
    add("--sublinear-window", type=int, default=128)
    add("--sublinear-stride", type=int, default=128)
    add("--sublinear-max-anchors", type=int, default=128)
    add("--sublinear-chunk", type=int, default=128)
    add("--sublinear-sinks", type=int, default=4)
    add("--sublinear-recent-anchors", type=int, default=64)
    add("--sublinear-pooled-landmarks", action="store_true")

    # Training objective options.
    add("--objective-mode", choices=["stochastic", "periodic"], default="stochastic")
    add("--ar-prob", type=float, default=0.45)
    add("--sat-prob", type=float, default=0.40)
    add("--nat-prob", type=float, default=0.15)
    add("--ar-loss-tokens", type=int, default=None)
    add("--sat-loss-tokens", type=int, default=128)
    add("--nat-loss-tokens", type=int, default=None)
    add("--nat-mask-ratio", type=float, default=0.5)

    # Runtime overrides; None means "not specified on the command line".
    add("--amp", action=argparse.BooleanOptionalAction, default=None)
    add("--grad-checkpoint", action=argparse.BooleanOptionalAction, default=None)
    add("--dblock-checkpoint-stride", type=int, default=None)
    add("--dblock-checkpoint-skip-tail", type=int, default=None)
    add("--dblock-activation-offload", action=argparse.BooleanOptionalAction, default=None)
    add("--dblock-activation-offload-min-mb", type=float, default=None)
    return parser


def _plan_via_daemon(args: argparse.Namespace) -> bool:
    """Try to have the daemon handle the request; return True only on success.

    Any failure (unreachable daemon, error reply, exception while printing the
    returned logs) is reported on stdout and signalled with False so the caller
    can fall back to local planning.
    """
    try:
        payload = dict(vars(args))
        # The daemon may run in another working directory: send absolute paths.
        for path_key in ("ckpt", "out_dir", "runtime"):
            if payload.get(path_key):
                payload[path_key] = str(Path(payload[path_key]).resolve())
        # Client-side control flags are meaningless to the daemon.
        for control_key in ("daemon", "port", "no_daemon"):
            payload.pop(control_key, None)
        reply = send_request_to_daemon(args.port, payload)
        if reply.get("status") != "success":
            print(f"[CLIENT] Daemon returned error: {reply.get('error')}. Falling back to local planning...", flush=True)
            return False
        for line in reply.get("logs", []):
            print(line, flush=True)
        return True
    except Exception as e:
        print(f"[CLIENT] Connection to daemon failed: {e}. Falling back to local planning...", flush=True)
        return False


def main() -> int:
    """Command-line entry point.

    With ``--daemon`` the process serves requests until interrupted and returns
    0. Otherwise ``--ckpt``, ``--out-dir``, ``--round`` and ``--workers`` are
    mandatory; the request is first offered to the daemon (unless
    ``--no-daemon`` is given) and planned locally when the daemon did not
    succeed.

    Returns:
        0 on success, 1 when local planning does not report success.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.daemon:
        run_daemon(args.port)
        return 0

    if not all((args.ckpt, args.out_dir, args.round, args.workers)):
        parser.error("the following arguments are required: --ckpt, --out-dir, --round, --workers (unless running with --daemon)")

    # Client mode: prefer the daemon, which already holds the checkpoint.
    if not args.no_daemon and _plan_via_daemon(args):
        return 0

    # Local planning fallback path.
    outcome = run_planning_and_export(vars(args))
    return 0 if outcome.get("status") == "success" else 1
if __name__ == "__main__":
    # Script entry: run main(), map the outcome to an exit status, then leave
    # via os._exit(). os._exit() does not flush stdio or run cleanup handlers,
    # so both standard streams are flushed explicitly first.
    _exit_code = 1  # status for any failure that is not an integer SystemExit
    try:
        _exit_code = int(main() or 0)
    except SystemExit as exc:
        # Keep integer exit codes (argparse uses 0 and 2); anything else is 1.
        if isinstance(exc.code, int):
            _exit_code = exc.code
    except BaseException:
        import traceback

        traceback.print_exc()
    finally:
        try:
            sys.stdout.flush()
            sys.stderr.flush()
        finally:
            # Reached even when flushing fails, so the process always exits.
            os._exit(_exit_code)
Xet Storage Details
- Size: 39.8 kB
- Xet hash: 44a0b0144d94684f99875eb4220fe0c4e7ae3b973d50774d0079187006cffe4e
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.