| |
| """ |
markov_lm.py — Hybrid CPU+GPU N-gram (Markov Chain) Language Model
──────────────────────────────────────────────────────────────────
| Uses same datasets, tokenizer, and infrastructure as AGILLM nC.py. |
| Target: beat Infini-gram's 47% next-token accuracy on Pile validation. |
| |
| Architecture: |
  TRAIN  → CPU: stream tokens, accumulate counts in Python dicts (memory-bound)
  FREEZE → Convert dicts to GPU hash tables (sorted tensors + searchsorted)
  EVAL   → GPU: batch context lookup, parallel KN smoothing, vectorised accuracy
  INFER  → GPU: parallel sampling from smoothed distributions
| |
| Key design: |
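- Modified Kneser-Ney discounts (D1, D2, D3+), interpolated across all orders
  (the standard choice for n-gram LMs); lower orders use raw counts rather than
  KN continuation counts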
| - GPU hash tables via sorted int64 keys + torch.searchsorted (batch parallel) |
- FNV-1a hash: n-gram tuple → int64 key, looked up on the GPU in O(log N)
| - All orders 1..max_order stored and queried simultaneously |
| - Prunes n-grams below min_count threshold to control memory |
- Checkpoint: save/load as <name>.cpu.pkl (pickled CPU count dicts)
| |
| Usage: |
| python markov_lm.py train --max_order 5 --tokens 500_000_000 --save markov_5gram |
| python markov_lm.py eval --model markov_5gram --eval_tokens 1_000_000 |
| python markov_lm.py generate --model markov_5gram --prompt "The meaning of" --max_new 200 |
| python markov_lm.py status --model markov_5gram |
| """ |
|
|
| from __future__ import annotations |
| import argparse, gc, math, os, pickle, sys, time, json |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| from datetime import datetime, timezone, timedelta |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| |
# Compute device for all GPU tables; CPU is kept as an explicit handle.
DEV = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CPU = torch.device("cpu")
print(f"[markov_lm] Device: {DEV}")
if DEV.type == "cuda":
    props = torch.cuda.get_device_properties(0)
    # The attribute name differs between torch builds; fall back to 0 bytes.
    vram = getattr(props, 'total_memory', None) or getattr(props, 'total_mem', 0)
    print(f"[markov_lm] GPU: {torch.cuda.get_device_name(0)}")
    print(f"[markov_lm] VRAM: {vram / 1e9:.1f} GB")
|
|
| |
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "gpt2")


def _load_tokenizer():
    """Load the fast Hugging Face tokenizer named by TOKENIZER_ID.

    Adds a "<|pad|>" pad token when the tokenizer has none.
    """
    from transformers import AutoTokenizer, logging as hf_log
    hf_log.set_verbosity_error()
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    return tokenizer


tok = _load_tokenizer()
# Highest assigned id + 1, so ids added above (e.g. the pad token) are in range.
VOCAB = 1 + max(tok.get_vocab().values())
# End-of-document marker; falls back to the separator token when there is no EOS.
EOS = tok.sep_token_id if tok.eos_token_id is None else tok.eos_token_id
print(f"[markov_lm] Vocab: {VOCAB:,} | EOS: {EOS}")
|
|
| |
# Comma-separated training datasets, read in order by token_stream.
DEFAULT_SOURCES = ",".join(
    f"OpenTransformer/{name}"
    for name in (
        "goddess-crawl",
        "agillm-crawl-data",
        "web-crawl-2026",
        "web-crawl-clean-v2",
        "scraped-web-data",
        "turbo-crawl",
        "sft-data-clean",
        "web-crawl-v1",
    )
)

# Held-out evaluation dataset.
# NOTE(review): _open_stream always loads split="train", so evaluation reads the
# Pile *train* split, not its validation split — confirm this is intended.
EVAL_SOURCE = "monology/pile-uncopyrighted"
|
|
| |
def get_uk_time(now: Optional[datetime] = None) -> str:
    """Format an instant in UK civil time (GMT, or BST during summer time).

    BST runs from 01:00 UTC on the last Sunday of March until 01:00 UTC on
    the last Sunday of October.

    Args:
        now: instant to format; defaults to the current time. A naive
            datetime is interpreted as UTC.

    Returns:
        'YYYY-MM-DD HH:MM:SS BST' or 'YYYY-MM-DD HH:MM:SS GMT'.
    """
    if now is None:
        utc = datetime.now(timezone.utc)
    elif now.tzinfo is None:
        utc = now.replace(tzinfo=timezone.utc)
    else:
        utc = now.astimezone(timezone.utc)

    def last_sunday_1am_utc(month: int) -> datetime:
        # March and October both have 31 days. weekday() is Monday=0 .. Sunday=6,
        # so stepping back (weekday + 1) % 7 days lands on the last Sunday.
        last_day = datetime(utc.year, month, 31, 1, 0, tzinfo=timezone.utc)
        return last_day - timedelta(days=(last_day.weekday() + 1) % 7)

    if last_sunday_1am_utc(3) <= utc < last_sunday_1am_utc(10):
        return (utc + timedelta(hours=1)).strftime('%Y-%m-%d %H:%M:%S BST')
    return utc.strftime('%Y-%m-%d %H:%M:%S GMT')
|
|
| |
class SafeProgress:
    """Minimal print-based progress reporter (no terminal control codes).

    A status line is printed whenever at least `interval` units have
    accumulated since the previous line, and once more by close().

    Args:
        total: expected number of units (percentage is 0 when total <= 0).
        initial: units already done before this reporter was created.
        unit: label used for counts and for the rate.
        interval: units between status lines.
    """

    def __init__(self, total, initial=0, unit="tok", interval=5_000_000):
        self.total, self.n, self.unit = total, initial, unit
        self.initial = initial
        self.last_print, self.postfix = initial, {}
        self.start_time = time.time()
        self.interval = interval

    def update(self, n=1):
        """Advance by `n` units, printing if an interval boundary was crossed."""
        self.n += n
        if self.n - self.last_print >= self.interval:
            self._print()
            self.last_print = self.n

    def set_postfix(self, **kw):
        """Replace the key=value pairs appended to each status line."""
        self.postfix = kw

    def _print(self):
        el = time.time() - self.start_time
        # Only units processed since construction count toward the rate.
        rate = (self.n - self.initial) / el if el > 0 else 0
        pct = 100 * self.n / self.total if self.total > 0 else 0
        pf = ' '.join(f"{k}={v}" for k, v in self.postfix.items())
        print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:,.0f} {self.unit}/s | {pf}",
              flush=True)

    def close(self):
        """Print a final status line followed by "Done."."""
        self._print()
        print("Done.")
|
|
| |
def _open_stream(ds_name: str, seed: int = 42):
    """Open a Hugging Face dataset as a shuffled streaming iterator.

    `ds_name` is "repo" or "repo:config"; the "train" split is always used.
    Shuffling uses a 1000-example buffer seeded with `seed`.
    """
    from datasets import load_dataset, DownloadConfig
    download_cfg = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
    repo, _, config = ds_name.partition(":")
    load_kwargs = dict(split="train", streaming=True, download_config=download_cfg)
    if config:
        dataset = load_dataset(repo, config, **load_kwargs)
    else:
        dataset = load_dataset(repo, **load_kwargs)
    return iter(dataset.shuffle(buffer_size=1000, seed=seed))
|
|
def token_stream(ds_names: str, target: int, seed: int = 42, field: str = "text"):
    """Yield up to `target` token ids from comma-separated streaming datasets.

    Sources are read one after another and wrapped around after the last one.
    Each document's text (`field`, else "text") is encoded with the module
    tokenizer and terminated with EOS. Stream errors trigger an exponential
    back-off (at most 60 s) and a reopen of the current source; every 5th
    consecutive failure moves on to the next source. Documents that fail to
    encode are skipped.

    NOTE(review): a reopened stream starts again from its first example, so
    documents read before an error are emitted a second time.

    Raises:
        RuntimeError: if every source in turn is exhausted without yielding a
            single token (e.g. no example has `field` or "text"), instead of
            cycling through the sources forever.
    """
    sources = [s.strip() for s in ds_names.split(",") if s.strip()]
    if not sources:
        return
    src_idx, emitted, it, attempts = 0, 0, None, 0
    emitted_at_open = 0  # value of `emitted` when the current source was opened
    dry_sources = 0      # consecutive sources exhausted without yielding a token

    while emitted < target:
        try:
            if it is None:
                it = _open_stream(sources[src_idx], seed)
                emitted_at_open = emitted
            ex = next(it)
        except StopIteration:
            dry_sources = dry_sources + 1 if emitted == emitted_at_open else 0
            if dry_sources >= len(sources):
                raise RuntimeError(
                    f"[stream] all {len(sources)} source(s) exhausted without yielding "
                    f"any text (field={field!r})")
            it = None
            src_idx = (src_idx + 1) % len(sources)
            continue
        except Exception as e:
            attempts += 1
            delay = min(60.0, 2.0 ** min(attempts, 6))
            print(f"[stream-retry] {sources[src_idx]}: {type(e).__name__}, sleep {delay:.1f}s")
            time.sleep(delay)
            it = None
            if attempts % 5 == 0 and len(sources) > 1:
                src_idx = (src_idx + 1) % len(sources)
            continue

        # Per-document failures are not stream failures: skip the document
        # rather than reopening the source (which would re-read the same data).
        try:
            text = ex.get(field) or ex.get("text")
            if not isinstance(text, str):
                continue
            enc = tok.encode(text)
        except Exception as e:
            print(f"[stream-skip] {sources[src_idx]}: {type(e).__name__} while encoding a document")
            continue

        if EOS is not None and (not enc or enc[-1] != EOS):
            enc.append(EOS)
        for t in enc:
            yield t
            emitted += 1
            if emitted >= target:
                return
        attempts = 0
|
|
def token_stream_chunked(ds_names: str, target: int, chunk_size: int = 16384, **kw):
    """Group token_stream output into lists of `chunk_size` ids.

    The final list may be shorter. Extra keyword arguments are forwarded to
    token_stream.
    """
    pending: List[int] = []
    for token_id in token_stream(ds_names, target, **kw):
        pending.append(token_id)
        if len(pending) < chunk_size:
            continue
        yield pending
        pending = []
    if pending:
        yield pending
|
|
|
|
| |
| |
| |
| |
|
|
| FNV_OFFSET = 14695981039346656037 |
| FNV_PRIME = 1099511628211 |
| MASK64 = (1 << 64) - 1 |
| INT64_MAX = (1 << 63) - 1 |
| INT64_WRAP = 1 << 64 |
| FNV_OFFSET_S = FNV_OFFSET - INT64_WRAP |
|
|
| def _hash_ngram_py(ngram: tuple) -> int: |
| """CPU hash for building tables.""" |
| h = FNV_OFFSET |
| for t in ngram: |
| h ^= (t & MASK64) |
| h = (h * FNV_PRIME) & MASK64 |
| return h if h <= INT64_MAX else h - INT64_WRAP |
|
|
| def _hash_ngram_batch_gpu(contexts: torch.Tensor) -> torch.Tensor: |
| """ |
| GPU batch hash. contexts: (batch, n) int64 tensor β (batch,) int64 hashes. |
| Same FNV-1a algorithm vectorised in torch ops. |
| """ |
| B, N = contexts.shape |
| h = torch.full((B,), FNV_OFFSET_S, dtype=torch.int64, device=contexts.device) |
| for i in range(N): |
| h = h ^ contexts[:, i] |
| h = h * FNV_PRIME |
| return h |
|
|
|
|
class GPUHashTable:
    """
    Immutable hash table for n-gram lookups, stored as tensors on one device.
    Built once from a CPU dict; every lookup is a batched tensor op.

    Storage: `hashes` is a sorted int64 key array; `counts` (and, for context
    tables, `continuation_counts`) are value arrays aligned with it.
    Lookup: torch.searchsorted on `hashes` — O(log N) per query, batched.
    Keys are 64-bit FNV-1a hashes of token tuples; a hash collision would
    silently alias two n-grams (a lookup returns the first matching slot).
    """

    def __init__(self):
        self.hashes: Optional[torch.Tensor] = None               # sorted int64 keys
        self.counts: Optional[torch.Tensor] = None               # int64 values aligned with hashes
        self.continuation_counts: Optional[torch.Tensor] = None  # context tables only
        self.total: int = 0                                      # sum of all counts
        self.size: int = 0                                       # number of keys

    @staticmethod
    def _sorted_tensors(keys: List[int], *columns: List[int], device=None) -> List[torch.Tensor]:
        """Stable-sort `keys` ascending, apply the same permutation to each
        aligned `columns` list, and return all of them as int64 tensors on
        `device` (DEV when None). Empty inputs give empty tensors."""
        device = DEV if device is None else device
        key_t, order = torch.sort(torch.tensor(keys, dtype=torch.int64), stable=True)
        out = [key_t.to(device)]
        for col in columns:
            out.append(torch.tensor(col, dtype=torch.int64)[order].to(device))
        return out

    def build_from_dict(self, count_dict: Dict[tuple, Dict[int, int]], device=None):
        """Index every full n-gram: key = hash(context + (next_tok,)), value = count.

        Args:
            count_dict: {context_tuple: {next_token: count}}.
            device: target device; DEV when None.
        Returns:
            self, for chaining.
        """
        keys: List[int] = []
        counts: List[int] = []
        for ctx, nexts in count_dict.items():
            for next_tok, cnt in nexts.items():
                keys.append(_hash_ngram_py(ctx + (next_tok,)))
                counts.append(cnt)
        self.hashes, self.counts = self._sorted_tensors(keys, counts, device=device)
        self.continuation_counts = None  # n-gram tables carry no per-key uniques
        self.total = sum(counts)
        self.size = len(keys)
        return self

    def build_context_table(self, count_dict: Dict[tuple, Dict[int, int]], device=None):
        """Index every context: key = hash(context), values = total count and
        number of distinct next tokens.

        Args:
            count_dict: {context_tuple: {next_token: count}}.
            device: target device; DEV when None.
        Returns:
            self, for chaining.
        """
        keys: List[int] = []
        totals: List[int] = []
        uniques: List[int] = []
        for ctx, nexts in count_dict.items():
            keys.append(_hash_ngram_py(ctx))
            totals.append(sum(nexts.values()))
            uniques.append(len(nexts))
        self.hashes, self.counts, self.continuation_counts = self._sorted_tensors(
            keys, totals, uniques, device=device)
        self.total = sum(totals)
        self.size = len(keys)
        return self

    def _lookup(self, values: Optional[torch.Tensor], hashes: torch.Tensor) -> torch.Tensor:
        """Return values[i] for each query equal to hashes[i], else 0."""
        if self.size == 0 or values is None:
            return torch.zeros_like(hashes)
        idx = torch.searchsorted(self.hashes, hashes).clamp(0, self.size - 1)
        found = self.hashes[idx] == hashes
        return torch.where(found, values[idx], torch.zeros_like(hashes))

    def batch_lookup(self, hashes: torch.Tensor) -> torch.Tensor:
        """Batched lookup. hashes: (B,) int64 → (B,) int64 counts (0 on a miss)."""
        return self._lookup(self.counts, hashes)

    def batch_lookup_continuations(self, hashes: torch.Tensor) -> torch.Tensor:
        """Number of distinct next tokens per context hash (0 on a miss or for
        tables built without continuation counts)."""
        return self._lookup(self.continuation_counts, hashes)

    def memory_bytes(self) -> int:
        """Bytes held by the key/value tensors."""
        return sum(t.nelement() * t.element_size()
                   for t in (self.hashes, self.counts, self.continuation_counts)
                   if t is not None)
|
|
|
|
| |
| |
| |
|
|
class MarkovLM:
    """
    Classical n-gram LM with interpolated modified-Kneser-Ney discounting.

    Training: CPU (dict-based counting, memory-bound).
    Inference/Eval: GPU (hash table batch lookup, vectorised smoothing).

    cpu_counts[order] = {context_tuple: {next_token: count}}
      order 0 = unigram    (context = ())
      order 1 = bigram     (context = (t,))
      order k = (k+1)-gram (context = (t1, ..., tk))

    batch_log_probs interpolates every order down to a uniform distribution
    over VOCAB, using per-order discounts (D1, D2, D3+) estimated by
    _compute_kn_discounts.

    NOTE(review): lower orders use raw counts, not Kneser-Ney continuation
    counts (number of distinct left contexts). This is therefore modified
    absolute discounting rather than full Kneser-Ney.
    """

    def __init__(self, max_order: int = 5):
        self.max_order = max_order
        # Training counts: one {context: {next_token: count}} table per order.
        self.cpu_counts: List[Dict[tuple, Dict[int, int]]] = [
            defaultdict(lambda: defaultdict(int)) for _ in range(max_order)
        ]
        self.total_tokens = 0
        self.tokens_trained = 0

        # GPU lookup tables, built by freeze().
        self.gpu_ngram_tables: List[Optional[GPUHashTable]] = [None] * max_order
        self.gpu_context_tables: List[Optional[GPUHashTable]] = [None] * max_order
        # Per order: (sorted context hashes, D1*N1 + D2*N2 + D3*N3+ per context),
        # i.e. the discounted mass that becomes the back-off weight gamma(h).
        self.gpu_gamma_tables: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None] * max_order
        self.frozen = False

        # Per-order (D1, D2, D3+) discounts; re-estimated by freeze().
        self.discounts: List[Tuple[float, float, float]] = [(0.5, 1.0, 1.5)] * max_order

        # Normalised unigram distribution (set by freeze()); used to pick
        # candidate tokens for accuracy and generation.
        self.gpu_unigram_probs: Optional[torch.Tensor] = None

    # ── Training (CPU) ──────────────────────────────────────────────────

    def train_on_tokens(self, token_iter, total_tokens: int, save_path: Optional[str] = None,
                        save_every: int = 100_000_000, min_count_prune: int = 2,
                        prune_every: int = 200_000_000):
        """
        Count n-grams of every order 1..max_order over a stream of token chunks.

        Args:
            token_iter: iterable of chunks, each an iterable of token ids.
            total_tokens: expected number of tokens (progress output only).
            save_path: checkpoint path for _save_cpu, used every `save_every`
                tokens and at the end; None disables checkpointing.
            save_every: tokens between checkpoints.
            min_count_prune: threshold for periodic _prune_counts; <= 1 disables it.
            prune_every: tokens between prune passes.
        """
        window: List[int] = []
        pbar = SafeProgress(total_tokens, unit="tok", interval=2_000_000)
        # Measure save/prune intervals from the current position, so a model
        # resumed from a checkpoint does not save and prune on its first token.
        last_save = self.tokens_trained
        last_prune = self.tokens_trained

        print(f"[train] max_order={self.max_order}, target={total_tokens:,}")
        print(f"[train] {get_uk_time()}")

        for chunk in token_iter:
            for t in chunk:
                window.append(t)
                self.total_tokens += 1
                self.tokens_trained += 1

                # For every order whose context fits in the window, count
                # (the `order` tokens before t) -> t.
                n = len(window)
                for order in range(min(self.max_order, n)):
                    ctx = tuple(window[n - 1 - order:n - 1])
                    self.cpu_counts[order][ctx][t] += 1

                # Trim occasionally rather than on every token; only the last
                # max_order tokens are ever read.
                if n > self.max_order + 10:
                    window = window[-(self.max_order + 1):]

                pbar.update(1)

                if min_count_prune > 1 and self.tokens_trained - last_prune >= prune_every:
                    self._prune_counts(min_count_prune)
                    last_prune = self.tokens_trained

                if save_path and self.tokens_trained - last_save >= save_every:
                    self._save_cpu(save_path)
                    last_save = self.tokens_trained

        pbar.close()
        self._print_stats()
        if save_path:
            self._save_cpu(save_path)

    def _prune_counts(self, min_count: int):
        """Drop rare n-grams in place, pruning harder at higher orders.

        Order index 0-2 (uni/bi/tri-grams): never pruned (small, high value).
        Order index 3-4: drop entries with count < min_count.
        Order index 5+:  drop entries with count < min_count + (order - 4).
        Contexts left with no continuations are removed.
        """
        pruned = 0
        for order in range(3, self.max_order):
            threshold = min_count + max(0, order - 4)
            table = self.cpu_counts[order]
            empty_ctxs = []
            for ctx, nexts in table.items():
                for t in [t for t, c in nexts.items() if c < threshold]:
                    del nexts[t]
                    pruned += 1
                if not nexts:
                    empty_ctxs.append(ctx)
            for ctx in empty_ctxs:
                del table[ctx]
        if pruned:
            gc.collect()
            print(f" [prune] Removed {pruned:,} entries (per-order thresholds)")

    def _print_stats(self):
        """Print context/entry counts and a rough memory estimate per order."""
        print(f"\n{'='*60}")
        print(f" N-GRAM MODEL β {self.max_order}-gram | {self.tokens_trained:,} tokens trained")
        print(f"{'='*60}")
        total_entries = 0
        for order, table in enumerate(self.cpu_counts):
            n_ctx = len(table)
            n_ent = sum(len(v) for v in table.values())
            total_entries += n_ent
            # Rough estimate: (order + 2) eight-byte words per entry.
            bytes_est = n_ent * (8 * (order + 2))
            print(f" {order+1}-gram: {n_ctx:>12,} contexts | {n_ent:>12,} entries | ~{bytes_est/1e9:.2f} GB")
        print(f" TOTAL: {total_entries:>39,} entries")
        print(f"{'='*60}")

    # ── Freeze (CPU → GPU) ──────────────────────────────────────────────

    def freeze(self, device=DEV, prune_threshold: int = 1):
        """
        Build the GPU lookup tables from the CPU counts and mark the model frozen.

        Discounts are estimated from count-of-counts BEFORE the optional prune.
        Pruning (order index >= 3) deletes exactly the low-count n-grams those
        estimates depend on; estimating afterwards would force the fixed
        fallback discounts at high orders.

        Args:
            device: device that receives all tables.
            prune_threshold: passed to _prune_counts when > 1.
        """
        print(f"\n[freeze] Building GPU hash tables on {device}...")
        t0 = time.time()

        self._compute_kn_discounts()
        if prune_threshold > 1:
            self._prune_counts(prune_threshold)

        total_gpu_bytes = 0
        for order in range(self.max_order):
            counts = self.cpu_counts[order]
            if not counts:
                continue

            gt = GPUHashTable().build_from_dict(counts, device=device)
            ct = GPUHashTable().build_context_table(counts, device=device)
            gamma = self._build_gamma_table(order, counts, device)
            self.gpu_ngram_tables[order] = gt
            self.gpu_context_tables[order] = ct
            self.gpu_gamma_tables[order] = gamma

            mem = (gt.memory_bytes() + ct.memory_bytes()
                   + sum(x.nelement() * x.element_size() for x in gamma))
            total_gpu_bytes += mem
            print(f" {order+1}-gram: {gt.size:,} entries, {ct.size:,} contexts ({mem/1e6:.1f} MB GPU)")

        if () in self.cpu_counts[0]:
            # Fill the unigram distribution on the CPU in one vectorised write.
            unigram = self.cpu_counts[0][()]
            ids = torch.tensor(list(unigram.keys()), dtype=torch.int64)
            cnts = torch.tensor(list(unigram.values()), dtype=torch.float32)
            in_range = (ids >= 0) & (ids < VOCAB)
            probs = torch.zeros(VOCAB, dtype=torch.float32)
            probs[ids[in_range]] = cnts[in_range]
            probs = probs.to(device)
            self.gpu_unigram_probs = probs / probs.sum().clamp(min=1)
            total_gpu_bytes += probs.nelement() * 4

        print(f"[freeze] Done in {time.time()-t0:.1f}s. GPU: {total_gpu_bytes/1e6:.1f} MB")
        self.frozen = True

    def _build_gamma_table(self, order: int, counts: Dict[tuple, Dict[int, int]], device):
        """
        Return (sorted context hashes, D1*N1 + D2*N2 + D3*N3+ per context) for one order.

        N1/N2/N3+ are the number of continuations seen exactly once, exactly
        twice and at least three times. Dividing by the context total gives
        the back-off weight gamma(h) equal to the mass removed by discounting.
        """
        D1, D2, D3 = self.discounts[order]
        keys: List[int] = []
        mass: List[float] = []
        for ctx, nexts in counts.items():
            n1 = n2 = n3 = 0
            for c in nexts.values():
                if c == 1:
                    n1 += 1
                elif c == 2:
                    n2 += 1
                elif c >= 3:
                    n3 += 1
            keys.append(_hash_ngram_py(ctx))
            mass.append(D1 * n1 + D2 * n2 + D3 * n3)
        key_t, perm = torch.sort(torch.tensor(keys, dtype=torch.int64), stable=True)
        mass_t = torch.tensor(mass, dtype=torch.float32)[perm]
        return key_t.to(device), mass_t.to(device)

    def free_cpu_counts(self):
        """Replace the CPU count tables with empty dicts to release RAM."""
        for i in range(self.max_order):
            self.cpu_counts[i] = {}
        gc.collect()
        print("[free] CPU dicts released.")

    def _compute_kn_discounts(self):
        """
        Estimate modified-KN discounts per order from count-of-counts n1..n4
        (Chen & Goodman):
            Y  = n1 / (n1 + 2*n2)
            Dk = k - (k+1) * Y * n_{k+1} / n_k      for k = 1, 2, 3
        Results are clamped so that 0.01 <= D1 <= 0.99, D1 <= D2 <= 1.99 and
        D2 <= D3 <= 2.99. Each discount therefore stays below its count.
        Orders with n1 == 0 or n2 == 0 fall back to (0.5, 1.0, 1.5).
        """
        for order, table in enumerate(self.cpu_counts):
            n1 = n2 = n3 = n4 = 0
            for nexts in table.values():
                for cnt in nexts.values():
                    if cnt == 1: n1 += 1
                    elif cnt == 2: n2 += 1
                    elif cnt == 3: n3 += 1
                    elif cnt == 4: n4 += 1

            if n1 == 0 or n2 == 0:
                self.discounts[order] = (0.5, 1.0, 1.5)
                continue

            Y = n1 / (n1 + 2 * n2)
            D1 = max(0.01, min(1 - 2 * Y * (n2 / max(n1, 1)), 0.99))
            D2 = max(D1, min(2 - 3 * Y * (n3 / max(n2, 1)), 1.99))
            D3 = max(D2, min(3 - 4 * Y * (n4 / max(n3, 1)), 2.99))

            self.discounts[order] = (D1, D2, D3)
            print(f" KN {order+1}-gram: D1={D1:.3f} D2={D2:.3f} D3+={D3:.3f}")

    # ── Scoring (GPU) ───────────────────────────────────────────────────

    @staticmethod
    def _lookup_sorted(keys: torch.Tensor, values: torch.Tensor, queries: torch.Tensor) -> torch.Tensor:
        """Return values[i] where keys[i] == query, else 0 (`keys` sorted ascending)."""
        if keys.numel() == 0:
            return torch.zeros(queries.shape, dtype=values.dtype, device=queries.device)
        idx = torch.searchsorted(keys, queries).clamp(0, keys.numel() - 1)
        return values[idx].masked_fill(keys[idx] != queries, 0)

    @torch.no_grad()
    def batch_log_probs(self, contexts: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute log P(target | context) by interpolating all orders on the GPU.

        Start from p = 1/VOCAB. For each order k (0..max_order-1) whose context
        h (the last k context tokens) was seen in training, p is replaced by
            max(c(h,w) - D(c(h,w)), 0) / c(h) + gamma(h) * p
            gamma(h) = (D1*N1(h) + D2*N2(h) + D3*N3+(h)) / c(h)
        gamma equals the mass removed by discounting, so every level stays
        normalised. Rows whose context contains padding (-1) or was never seen
        keep the lower-order estimate.

        contexts: (B, C) int64, oldest token first, left-padded with -1
        targets:  (B,) int64
        Returns:  (B,) float32 log probabilities
        """
        B = targets.shape[0]
        n_cols = contexts.shape[1]
        probs = torch.full((B,), 1.0 / VOCAB, dtype=torch.float32, device=targets.device)

        for order in range(self.max_order):
            gt = self.gpu_ngram_tables[order]
            ct = self.gpu_context_tables[order]
            gamma_table = self.gpu_gamma_tables[order]
            if gt is None or ct is None or gamma_table is None or gt.size == 0:
                continue
            ctx_len = order
            if ctx_len > n_cols:
                continue

            # Most recent ctx_len tokens; shape (B, 0) at the unigram level.
            ctx = contexts[:, n_cols - ctx_len:]
            has_ctx = (ctx >= 0).all(dim=1)
            if not has_ctx.any():
                continue

            full_ngram = torch.cat([ctx, targets.unsqueeze(1)], dim=1)
            ngram_counts = gt.batch_lookup(_hash_ngram_batch_gpu(full_ngram)).float()
            ctx_hashes = _hash_ngram_batch_gpu(ctx)
            ctx_totals = ct.batch_lookup(ctx_hashes).float()
            discount_mass = self._lookup_sorted(gamma_table[0], gamma_table[1], ctx_hashes)

            D1, D2, D3 = self.discounts[order]
            discount = torch.where(ngram_counts >= 3, D3,
                                   torch.where(ngram_counts >= 2, D2,
                                               torch.where(ngram_counts >= 1, D1, 0.0)))

            denominator = ctx_totals.clamp(min=1)
            p_here = (ngram_counts - discount).clamp(min=0) / denominator
            gamma = (discount_mass / denominator).clamp(0, 1)
            p_combined = p_here + gamma * probs

            valid = has_ctx & (ctx_totals > 0)
            probs = torch.where(valid, p_combined, probs)

        return torch.log(probs.clamp(min=1e-30))

    # ── Evaluation ──────────────────────────────────────────────────────

    @staticmethod
    def _collect_eval_tokens(eval_source: str, n_tokens: int) -> torch.Tensor:
        """Stream up to `n_tokens` token ids of `eval_source` into a CPU int64 tensor."""
        buf: List[int] = []
        for chunk in token_stream_chunked(eval_source, n_tokens, chunk_size=8192):
            buf.extend(chunk)
            if len(buf) >= n_tokens:
                break
        return torch.tensor(buf[:n_tokens], dtype=torch.int64)

    @staticmethod
    def _batch_windows(tokens_t: torch.Tensor, start: int, end: int, ctx_len: int):
        """Return (contexts (B, ctx_len), targets (B,)) on DEV for positions start..end-1.

        Position i predicts tokens_t[i + ctx_len] from tokens_t[i : i + ctx_len].
        """
        idx = torch.arange(start, end)
        ctx_idx = idx.unsqueeze(1) + torch.arange(ctx_len).unsqueeze(0)
        return tokens_t[ctx_idx].to(DEV), tokens_t[idx + ctx_len].to(DEV)

    @torch.no_grad()
    def evaluate(self, eval_source: str, eval_tokens: int, batch_size: int = 1024):
        """
        Evaluate perplexity on held-out data using GPU batch processing.

        Tokens from `eval_source` are treated as one stream. Each of the
        `eval_tokens` positions is scored with the preceding max_order-1
        tokens as context.

        Returns:
            {"perplexity", "avg_log_prob", "tokens"}.
        Raises:
            RuntimeError: if no evaluation position could be formed.
        """
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1

        print(f"\n[eval] Source: {eval_source}")
        print(f"[eval] {eval_tokens:,} tokens, batch={batch_size}")
        print(f"[eval] {get_uk_time()}")

        print("[eval] Collecting tokens...")
        tokens_t = self._collect_eval_tokens(eval_source, eval_tokens + ctx_len)
        n_eval = max(0, tokens_t.numel() - ctx_len)
        print(f"[eval] {n_eval:,} evaluation positions")
        if n_eval == 0:
            raise RuntimeError(f"[eval] no evaluation tokens collected from {eval_source}")

        total_log_prob = 0.0
        total = 0
        pbar = SafeProgress(n_eval, unit="tok", interval=50_000)

        for start in range(0, n_eval, batch_size):
            end = min(start + batch_size, n_eval)
            contexts, targets = self._batch_windows(tokens_t, start, end, ctx_len)
            total_log_prob += self.batch_log_probs(contexts, targets).sum().item()
            total += end - start
            pbar.update(end - start)

        pbar.close()

        avg_lp = total_log_prob / total
        ppl = math.exp(-avg_lp)

        print(f"\n{'='*55}")
        print(f" PERPLEXITY EVALUATION")
        print(f"{'='*55}")
        print(f" Tokens: {total:,}")
        print(f" Avg log P: {avg_lp:.4f}")
        print(f" Perplexity: {ppl:.2f}")
        print(f" {get_uk_time()}")
        print(f"{'='*55}")
        return {"perplexity": ppl, "avg_log_prob": avg_lp, "tokens": total}

    @torch.no_grad()
    def evaluate_accuracy(self, eval_source: str, eval_tokens: int,
                          batch_size: int = 256, top_k_check: int = 1000):
        """
        Approximate next-token accuracy.

        A position counts as correct when the model's log-probability of the
        true token is >= its best log-probability over a fixed candidate set:
        the `top_k_check` most frequent unigrams, or the first ids when no
        unigram table exists. Tokens outside the set are never competitors and
        ties count as correct, so this is an optimistic estimate of exact
        argmax accuracy.

        Returns:
            {"accuracy", "correct", "total"}.
        """
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1

        print(f"\n[accuracy] top_k_check={top_k_check}, {eval_tokens:,} tokens")

        n_check = min(top_k_check, VOCAB)
        if self.gpu_unigram_probs is not None:
            _, topk = self.gpu_unigram_probs.topk(n_check)
        else:
            topk = torch.arange(n_check, device=DEV)
        K = topk.shape[0]

        tokens_t = self._collect_eval_tokens(eval_source, eval_tokens + ctx_len)
        n_eval = max(0, tokens_t.numel() - ctx_len)

        correct = 0
        total = 0
        pbar = SafeProgress(n_eval, unit="tok", interval=20_000)

        for start in range(0, n_eval, batch_size):
            end = min(start + batch_size, n_eval)
            B = end - start
            contexts, targets = self._batch_windows(tokens_t, start, end, ctx_len)

            target_lp = self.batch_log_probs(contexts, targets)

            # Score all K candidates for all B positions as one (B*K) batch.
            ctx_exp = contexts.unsqueeze(1).expand(B, K, ctx_len).reshape(B * K, ctx_len)
            cand_exp = topk.unsqueeze(0).expand(B, K).reshape(B * K)
            best_cand_lp = self.batch_log_probs(ctx_exp, cand_exp).reshape(B, K).max(dim=1).values

            correct += (target_lp >= best_cand_lp).sum().item()
            total += B

            if total % 10000 < batch_size:
                pbar.set_postfix(acc=f"{100*correct/max(total,1):.2f}%")
            pbar.update(B)

        pbar.close()
        acc = correct / max(total, 1)

        print(f"\n{'='*55}")
        print(f" ACCURACY (top-{top_k_check} check)")
        print(f"{'='*55}")
        print(f" Correct: {correct:,} / {total:,}")
        print(f" Accuracy: {100*acc:.2f}%")
        print(f" Target: 47% (Infini-gram on Pile)")
        print(f" {get_uk_time()}")
        print(f"{'='*55}")
        return {"accuracy": acc, "correct": correct, "total": total}

    # ── Generation ──────────────────────────────────────────────────────

    @torch.no_grad()
    def generate(self, prompt: str, max_new: int = 200, temperature: float = 0.8,
                 top_k: int = 50, top_p: float = 0.9):
        """
        Sample up to `max_new` tokens after `prompt`; stops early at EOS.

        Each step scores a fixed candidate set. This is the top_k*10 most
        frequent unigrams, or the whole vocabulary when top_k <= 0. The
        temperature softmax is then filtered by top-k, renormalised, and
        filtered by nucleus (top-p) sampling before drawing a token.

        Returns:
            The decoded prompt + continuation.
        """
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1
        ids = tok.encode(prompt)
        prompt_len = len(ids)
        # top_k <= 0 disables top-k filtering, so score the whole vocabulary.
        n_cands = min(top_k * 10, VOCAB) if top_k > 0 else VOCAB

        print(f"\n[gen] Prompt: '{prompt}' ({len(ids)} tokens)")
        t0 = time.time()

        if self.gpu_unigram_probs is not None:
            _, candidates = self.gpu_unigram_probs.topk(n_cands)
        else:
            candidates = torch.arange(n_cands, device=DEV)

        for _ in range(max_new):
            # Last ctx_len tokens, left-padded with -1 (also correct for ctx_len == 0).
            ctx = ([-1] * ctx_len + ids)[len(ids):]
            ctx_t = torch.tensor([ctx], dtype=torch.int64, device=DEV).expand(n_cands, ctx_len)
            log_probs = self.batch_log_probs(ctx_t, candidates)

            probs = (log_probs / max(temperature, 1e-8)).softmax(0)

            if 0 < top_k < n_cands:
                vals, idx = probs.topk(top_k)
                probs = torch.zeros_like(probs).scatter_(0, idx, vals)
                # Renormalise so the top-p threshold applies to a true distribution.
                probs = probs / probs.sum().clamp(min=1e-30)

            if top_p < 1.0:
                sp, si = probs.sort(descending=True)
                # Nucleus = smallest prefix whose mass reaches top_p: drop a token
                # only if the mass strictly before it already reaches top_p.
                drop = (sp.cumsum(0) - sp) >= top_p
                drop[0] = False
                sp = sp.masked_fill(drop, 0.0)
                probs = torch.zeros_like(probs).scatter_(0, si, sp)

            mass = probs.sum()
            if mass <= 0:
                next_tok = candidates[log_probs.argmax()].item()
            else:
                next_tok = candidates[(probs / mass).multinomial(1).item()].item()

            ids.append(next_tok)
            if next_tok == EOS:
                break

        elapsed = time.time() - t0
        gen_n = len(ids) - prompt_len
        text = tok.decode(ids, skip_special_tokens=True)
        print(f"\n{text}")
        print(f"\n[{elapsed:.2f}s | {gen_n} tokens | {gen_n/max(elapsed,0.01):.1f} tok/s]")
        return text

    # ── Persistence ─────────────────────────────────────────────────────

    def _save_cpu(self, path: str):
        """
        Pickle the CPU counts and metadata to `<path>.cpu.pkl` via a temp file.

        The data is written to a `.tmp` sibling and renamed over the target,
        so a failed save does not clobber the previous checkpoint. The temp
        file is removed if writing or renaming fails.
        """
        p = Path(path).with_suffix('.cpu.pkl')
        p.parent.mkdir(parents=True, exist_ok=True)
        data = {
            'max_order': self.max_order,
            # Shallow-copy only the outer tables. Their lambda default factory
            # cannot be pickled, but the inner defaultdict(int)s can, and
            # copying every inner dict would transiently double RAM usage.
            'cpu_counts': [dict(d) for d in self.cpu_counts],
            'total_tokens': self.total_tokens,
            'tokens_trained': self.tokens_trained,
            'discounts': self.discounts,
        }
        tmp = p.with_suffix('.tmp')
        try:
            with open(tmp, 'wb') as f:
                pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
            tmp.replace(p)
        finally:
            # Only still present if dump or rename failed.
            if tmp.exists():
                tmp.unlink()
        print(f"[save] {p} ({p.stat().st_size/1e6:.1f} MB)")

    def save(self, path: str):
        """Save a checkpoint (same format as periodic training checkpoints)."""
        self._save_cpu(path)

    @classmethod
    def load(cls, path: str) -> 'MarkovLM':
        """
        Load a checkpoint written by _save_cpu, then freeze it on the default device.

        Tries `<path>.cpu.pkl`, then `<path>.pkl` (via Path.with_suffix), and
        finally `path` itself; open() raises FileNotFoundError if none exists.

        SECURITY: unpickling can execute arbitrary code — only load trusted files.
        """
        p = Path(path)
        for suffix in ['.cpu.pkl', '.pkl']:
            candidate = p.with_suffix(suffix)
            if candidate.exists():
                p = candidate
                break

        print(f"[load] {p}...")
        with open(p, 'rb') as f:
            data = pickle.load(f)

        model = cls(max_order=data['max_order'])
        model.total_tokens = data['total_tokens']
        model.tokens_trained = data['tokens_trained']
        model.discounts = data.get('discounts', model.discounts)

        raw_orders = data.pop('cpu_counts')
        for order in range(model.max_order):
            raw = raw_orders[order]
            raw_orders[order] = None
            table = defaultdict(lambda: defaultdict(int))
            # Move entries one at a time so peak memory stays near a single
            # copy of the counts.
            while raw:
                ctx, nexts = raw.popitem()
                if not (isinstance(nexts, defaultdict) and nexts.default_factory is int):
                    nexts = defaultdict(int, nexts)
                table[ctx] = nexts
            model.cpu_counts[order] = table
        del raw_orders

        model._print_stats()
        print("[load] Freezing to GPU...")
        model.freeze()
        return model

    def status(self):
        """Print count statistics and, when frozen, GPU memory held by the lookup tables."""
        self._print_stats()
        if self.frozen:
            gpu_mem = sum(
                (gt.memory_bytes() if gt else 0) + (ct.memory_bytes() if ct else 0)
                for gt, ct in zip(self.gpu_ngram_tables, self.gpu_context_tables)
            )
            gpu_mem += sum(k.nelement() * k.element_size() + v.nelement() * v.element_size()
                           for k, v in filter(None, self.gpu_gamma_tables))
            print(f" GPU: {gpu_mem/1e6:.1f} MB")
        print(f" Frozen: {self.frozen} | {get_uk_time()}")
|
|
|
|
| |
| |
| |
|
|
# JSON progress snapshot shared by write_status and cmd_status; override with
# the MARKOV_STATUS_FILE environment variable.
STATUS_FILE = os.environ.get("MARKOV_STATUS_FILE", "/workspace/markov_status.json")


def write_status(tokens_trained, phase, tok_per_sec=0, extra=None, path: Optional[str] = None):
    """
    Best-effort: write a JSON progress snapshot (read back by cmd_status).

    The JSON is written to a `.tmp` sibling and renamed into place, so a
    reader does not see a partially written file. I/O and serialisation
    failures are reported on stdout rather than raised, so a caller's
    training loop is not interrupted.

    Args:
        tokens_trained: tokens consumed so far.
        phase: free-form phase label.
        tok_per_sec: throughput to record.
        extra: optional mapping merged into the snapshot.
        path: output file; defaults to STATUS_FILE.
    """
    target = path or STATUS_FILE
    tmp = f"{target}.tmp"
    try:
        data = {"tokens_trained": tokens_trained, "phase": phase,
                "tok_per_sec": tok_per_sec, "updated": time.time(), "uk_time": get_uk_time()}
        if extra:
            data.update(extra)
        with open(tmp, 'w') as f:
            json.dump(data, f)
        os.replace(tmp, target)
    except (OSError, TypeError, ValueError) as e:
        print(f"[status] could not write {target}: {type(e).__name__}: {e}")
|
|
|
|
| |
| |
| |
|
|
def _skip_tokens(chunks, n_skip: int):
    """Yield token chunks (lists of ids) with the first `n_skip` tokens removed."""
    for chunk in chunks:
        if n_skip > 0:
            dropped = min(n_skip, len(chunk))
            chunk = chunk[dropped:]
            n_skip -= dropped
        if chunk:
            yield chunk


def cmd_train(args):
    """
    Train (or resume) an n-gram model, freeze it to the GPU and save it.

    When resuming, the token stream is replayed from the start (same sources,
    same seed). The prefix already counted by the checkpoint is skipped, so
    those tokens are not counted twice.
    NOTE(review): this assumes the stream is reproducible; stream retries
    reopen sources from scratch and can shift alignment.
    """
    print(f"{'='*60}")
    print(f" MARKOV CHAIN LM β HYBRID CPU+GPU")
    print(f"{'='*60}")
    print(f" Order: {args.max_order}-gram")
    print(f" Tokens: {args.tokens:,}")
    print(f" Source: {args.source[:80]}...")
    print(f" Save: {args.save}")
    print(f" Prune: min_count={args.min_count}")
    print(f" Device: CPU (counting) β {DEV} (inference)")
    print(f" {get_uk_time()}")
    print(f"{'='*60}")

    already_trained = 0
    if args.resume and Path(args.save).with_suffix('.cpu.pkl').exists():
        print("[resume] Loading existing...")
        model = MarkovLM.load(args.save)
        remaining = args.tokens - model.tokens_trained
        if remaining <= 0:
            print(f"[resume] Already at {model.tokens_trained:,} >= {args.tokens:,}")
            return model
        print(f"[resume] {remaining:,} tokens remaining")
        if model.max_order != args.max_order:
            print(f"[resume] checkpoint is a {model.max_order}-gram; ignoring --max_order {args.max_order}")
        already_trained = model.tokens_trained
    else:
        model = MarkovLM(max_order=args.max_order)

    stream = token_stream_chunked(args.source, args.tokens, chunk_size=16384, seed=42, field=args.field)
    if already_trained:
        stream = _skip_tokens(stream, already_trained)
    model.train_on_tokens(stream, total_tokens=args.tokens - already_trained, save_path=args.save,
                          save_every=args.save_every, min_count_prune=args.min_count,
                          prune_every=args.prune_every)

    print("\n[train] Freezing to GPU...")
    model.freeze(prune_threshold=args.freeze_prune)
    model.save(args.save)
    return model
|
|
|
|
def cmd_eval(args):
    """Load a model and report perplexity and, with --accuracy, next-token accuracy.

    Returns the merged result dict.
    """
    model = MarkovLM.load(args.model)
    results = model.evaluate(
        eval_source=args.eval_source,
        eval_tokens=args.eval_tokens,
        batch_size=args.batch_size,
    )
    if not args.accuracy:
        return results
    results.update(model.evaluate_accuracy(
        eval_source=args.eval_source,
        eval_tokens=min(args.eval_tokens, args.accuracy_tokens),
        batch_size=args.acc_batch_size,
        top_k_check=args.top_k_check,
    ))
    return results
|
|
|
|
def cmd_generate(args):
    """Load a model and print a sampled continuation of --prompt."""
    sampling = dict(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)
    MarkovLM.load(args.model).generate(prompt=args.prompt, max_new=args.max_new, **sampling)
|
|
|
|
def cmd_status(args):
    """
    Print statistics for a saved model, or the last training status snapshot.

    If `<model>.cpu.pkl` or `<model>.pkl` exists, the model is loaded (which
    also freezes it to the GPU) and its status is printed. Otherwise the JSON
    at STATUS_FILE is summarised; a missing or unreadable file prints
    "No status file.".
    """
    p = Path(args.model)
    if any(p.with_suffix(suffix).exists() for suffix in ('.cpu.pkl', '.pkl')):
        model = MarkovLM.load(args.model)
        model.status()
        return
    print(f"No model at {args.model}")
    try:
        with open(STATUS_FILE) as f:
            snapshot = json.load(f)
        print(f"Phase: {snapshot.get('phase')} | Tokens: {snapshot.get('tokens_trained',0):,} | {snapshot.get('uk_time','?')}")
    except (OSError, ValueError, TypeError, AttributeError):
        # Missing file, bad JSON, or unexpected field types. Unlike a bare
        # except, this does not swallow KeyboardInterrupt/SystemExit.
        print("No status file.")
|
|
|
|
def main():
    """Parse the command line and dispatch to the matching cmd_* handler."""
    parser = argparse.ArgumentParser(description="Hybrid CPU+GPU Markov Chain LM")
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    p_train = subcommands.add_parser("train")
    p_train.add_argument("--max_order", type=int, default=10)
    p_train.add_argument("--tokens", type=int, default=2_000_000_000)
    p_train.add_argument("--source", default=DEFAULT_SOURCES)
    p_train.add_argument("--field", default="text")
    p_train.add_argument("--save", default="markov_model")
    p_train.add_argument("--save_every", type=int, default=100_000_000)
    p_train.add_argument("--min_count", type=int, default=2)
    p_train.add_argument("--prune_every", type=int, default=300_000_000)
    p_train.add_argument("--freeze_prune", type=int, default=2)
    p_train.add_argument("--resume", action="store_true")

    p_eval = subcommands.add_parser("eval")
    p_eval.add_argument("--model", required=True)
    p_eval.add_argument("--eval_source", default=EVAL_SOURCE)
    p_eval.add_argument("--eval_tokens", type=int, default=1_000_000)
    p_eval.add_argument("--batch_size", type=int, default=1024)
    p_eval.add_argument("--accuracy", action="store_true")
    p_eval.add_argument("--accuracy_tokens", type=int, default=200_000)
    p_eval.add_argument("--acc_batch_size", type=int, default=128)
    p_eval.add_argument("--top_k_check", type=int, default=1000)

    p_gen = subcommands.add_parser("generate")
    p_gen.add_argument("--model", required=True)
    p_gen.add_argument("--prompt", required=True)
    p_gen.add_argument("--max_new", type=int, default=200)
    p_gen.add_argument("--temperature", type=float, default=0.8)
    p_gen.add_argument("--top_k", type=int, default=50)
    p_gen.add_argument("--top_p", type=float, default=0.9)

    p_status = subcommands.add_parser("status")
    p_status.add_argument("--model", default="markov_model")

    handlers = {
        "train": cmd_train,
        "eval": cmd_eval,
        "generate": cmd_generate,
        "status": cmd_status,
    }
    args = parser.parse_args()
    handlers[args.cmd](args)


if __name__ == "__main__":
    main()
|
|