# markov-5gram-500m / markov_lm.py
# Uploaded by OpenTransformer with huggingface_hub (commit 5f9cb66, verified).
#!/usr/bin/env python3
"""
markov_lm.py β€” Hybrid CPU+GPU N-gram (Markov Chain) Language Model
═══════════════════════════════════════════════════════════════════
Uses same datasets, tokenizer, and infrastructure as AGILLM nC.py.
Target: beat Infini-gram's 47% next-token accuracy on Pile validation.
Architecture:
TRAIN β†’ CPU: stream tokens, accumulate counts in Python dicts (memory-bound)
FREEZE β†’ Convert dicts to GPU hash tables (sorted tensors + searchsorted)
EVAL β†’ GPU: batch context lookup, parallel KN smoothing, vectorised accuracy
INFER β†’ GPU: parallel sampling from smoothed distributions
Key design:
- Modified Kneser-Ney smoothing (gold standard for n-gram LMs)
- GPU hash tables via sorted int64 keys + torch.searchsorted (batch parallel)
- FNV-1a hash: n-gram tuple β†’ int64 for O(log N) GPU lookup
- All orders 1..max_order stored and queried simultaneously
- Prunes n-grams below min_count threshold to control memory
- Checkpoint: save/load as .pkl (CPU dicts)
Usage:
python markov_lm.py train --max_order 5 --tokens 500_000_000 --save markov_5gram
python markov_lm.py eval --model markov_5gram --eval_tokens 1_000_000
python markov_lm.py generate --model markov_5gram --prompt "The meaning of" --max_new 200
python markov_lm.py status --model markov_5gram
"""
from __future__ import annotations
import argparse, gc, math, os, pickle, sys, time, json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from datetime import datetime, timezone, timedelta
import torch
import torch.nn.functional as F
# ───────────────────── Device ─────────────────────
# Prefer CUDA when available; frozen hash tables and eval batches live on DEV.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CPU = torch.device("cpu")
print(f"[markov_lm] Device: {DEV}")
if DEV.type == "cuda":
    print(f"[markov_lm] GPU: {torch.cuda.get_device_name(0)}")
    props = torch.cuda.get_device_properties(0)
    # Attribute name varies across torch versions ('total_memory' vs 'total_mem').
    vram = getattr(props, 'total_memory', None) or getattr(props, 'total_mem', 0)
    print(f"[markov_lm] VRAM: {vram / 1e9:.1f} GB")
# ───────────────────── Tokenizer (same as nC.py) ─────────────────────
# Tokenizer id is overridable via env so experiments can swap vocabularies.
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "gpt2")
def _load_tokenizer():
    """Load the HF fast tokenizer, silencing transformers logging.

    Adds a pad token when the checkpoint ships none (gpt2 does not).
    """
    from transformers import AutoTokenizer, logging as hf_log
    hf_log.set_verbosity_error()
    t = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
    if t.pad_token is None:
        t.add_special_tokens({"pad_token": "<|pad|>"})
    return t
tok = _load_tokenizer()
# Vocab size = highest token id + 1 (ids can be sparse after added specials).
VOCAB = max(tok.get_vocab().values()) + 1
# Fall back to SEP for tokenizers that define no EOS (e.g. BERT-style).
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
print(f"[markov_lm] Vocab: {VOCAB:,} | EOS: {EOS}")
# ───────────────────── Dataset Sources (same as nC.py) ─────────────────────
# Comma-separated list consumed by token_stream(); sources are round-robined
# when one is exhausted or repeatedly errors.
DEFAULT_SOURCES = ",".join([
    "OpenTransformer/goddess-crawl",
    "OpenTransformer/agillm-crawl-data",
    "OpenTransformer/web-crawl-2026",
    "OpenTransformer/web-crawl-clean-v2",
    "OpenTransformer/scraped-web-data",
    "OpenTransformer/turbo-crawl",
    "OpenTransformer/sft-data-clean",
    "OpenTransformer/web-crawl-v1",
])
# Held-out corpus for perplexity/accuracy evaluation.
EVAL_SOURCE = "monology/pile-uncopyrighted"
# ───────────────────── UK Time ─────────────────────
def get_uk_time() -> str:
    """Return the current UK wall-clock time as 'YYYY-MM-DD HH:MM:SS BST|GMT'.

    UK civil time is GMT (UTC+0) in winter and BST (UTC+1) in summer.
    BST runs from 01:00 UTC on the last Sunday of March until 01:00 UTC
    on the last Sunday of October.
    """
    def _last_sunday_1am_utc(year: int, month: int) -> datetime:
        # Both transition months have 31 days; step back from the 31st to the
        # most recent Sunday (weekday() == 6) without looping.
        d = datetime(year, month, 31, 1, 0, tzinfo=timezone.utc)
        return d - timedelta(days=(d.weekday() + 1) % 7)

    utc = datetime.now(timezone.utc)
    bst_start = _last_sunday_1am_utc(utc.year, 3)
    bst_end = _last_sunday_1am_utc(utc.year, 10)
    if bst_start <= utc < bst_end:
        return (utc + timedelta(hours=1)).strftime('%Y-%m-%d %H:%M:%S BST')
    return utc.strftime('%Y-%m-%d %H:%M:%S GMT')
# ───────────────────── SafeProgress (no tqdm) ─────────────────────
class SafeProgress:
    """Throttled, log-friendly progress reporter (one plain line per interval)."""

    def __init__(self, total, initial=0, unit="tok", interval=5_000_000):
        self.total = total
        self.n = initial
        self.unit = unit
        self.last_print = initial
        self.postfix = {}
        self.start_time = time.time()
        self.interval = interval

    def update(self, n=1):
        """Advance by `n`; print once `interval` units have accumulated."""
        self.n += n
        if self.n - self.last_print < self.interval:
            return
        self._print()
        self.last_print = self.n

    def set_postfix(self, **kw):
        """Replace the trailing key=value annotations shown on each line."""
        self.postfix = kw

    def _print(self):
        elapsed = time.time() - self.start_time
        rate = self.n / elapsed if elapsed > 0 else 0
        pct = 100 * self.n / self.total if self.total > 0 else 0
        pf = ' '.join(f"{k}={v}" for k, v in self.postfix.items())
        print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:,.0f} tok/s | {pf}")
        sys.stdout.flush()

    def close(self):
        """Emit one final progress line and a terminator."""
        self._print()
        print("Done.")
# ───────────────────── Token Stream (from nC.py) ─────────────────────
def _open_stream(ds_name: str, seed: int = 42):
    """Open a shuffled streaming iterator over the 'train' split of `ds_name`.

    `ds_name` may be "repo" or "repo:config"; an optional dataset config is
    split off at the first ':'. Returns an iterator of example dicts.
    """
    from datasets import load_dataset, DownloadConfig
    dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
    # Split "repo:config" into (repo, config); config is None when absent.
    base, config = (ds_name.split(":", 1) + [None])[:2]
    if config:
        ds = load_dataset(base, config, split="train", streaming=True, download_config=dc)
    else:
        ds = load_dataset(base, split="train", streaming=True, download_config=dc)
    # Small shuffle buffer: cheap approximate shuffling of the stream.
    return iter(ds.shuffle(buffer_size=1000, seed=seed))
def token_stream(ds_names: str, target: int, seed: int = 42, field: str = "text"):
    """Yield up to `target` token ids from a comma-separated list of datasets.

    Rotates to the next source when one is exhausted, retries transient
    errors with exponential backoff (capped at 60s), and appends EOS to any
    document that does not already end with it. Uses the module-level `tok`
    and `EOS` globals.
    """
    sources = [s.strip() for s in ds_names.split(",") if s.strip()]
    if not sources: return
    src_idx, emitted, it, attempts = 0, 0, None, 0
    while emitted < target:
        try:
            # Lazily (re)open the current source after exhaustion or error.
            if it is None: it = _open_stream(sources[src_idx], seed)
            ex = next(it)
            text = ex.get(field) or ex.get("text")  # fall back to "text" field
            if not isinstance(text, str): continue
            enc = tok.encode(text)
            if EOS is not None and (not enc or enc[-1] != EOS):
                enc.append(EOS)
            for t in enc:
                yield t; emitted += 1
                if emitted >= target: return
            # A fully processed document resets the retry counter.
            attempts = 0
        except StopIteration:
            # Source exhausted: move to the next one round-robin.
            it = None; src_idx = (src_idx + 1) % len(sources)
        except Exception as e:
            # Transient failure: exponential backoff, capped at 60s.
            attempts += 1
            s = min(60.0, 2.0 ** min(attempts, 6))
            print(f"[stream-retry] {sources[src_idx]}: {type(e).__name__}, sleep {s:.1f}s")
            time.sleep(s); it = None
            # After every 5 consecutive failures, try a different source.
            if attempts % 5 == 0 and len(sources) > 1:
                src_idx = (src_idx + 1) % len(sources)
def token_stream_chunked(ds_names: str, target: int, chunk_size: int = 16384, **kw):
    """Group `token_stream` output into lists of at most `chunk_size` ids."""
    pending = []
    for token_id in token_stream(ds_names, target, **kw):
        pending.append(token_id)
        if len(pending) < chunk_size:
            continue
        yield pending
        pending = []
    if pending:
        yield pending
# ═══════════════════════════════════════════════════════════════════
# GPU HASH TABLE
# Sorted int64 keys + torch.searchsorted for O(log N) batch lookup
# ═══════════════════════════════════════════════════════════════════
# Standard 64-bit FNV-1a parameters.
FNV_OFFSET = 14695981039346656037
FNV_PRIME = 1099511628211
MASK64 = (1 << 64) - 1
INT64_MAX = (1 << 63) - 1
INT64_WRAP = 1 << 64
# torch has no uint64, so hashes are stored as two's-complement int64;
# this is FNV_OFFSET re-interpreted as a (negative) signed 64-bit value.
FNV_OFFSET_S = FNV_OFFSET - INT64_WRAP # signed repr for torch int64
def _hash_ngram_py(ngram: tuple) -> int:
"""CPU hash for building tables."""
h = FNV_OFFSET
for t in ngram:
h ^= (t & MASK64)
h = (h * FNV_PRIME) & MASK64
return h if h <= INT64_MAX else h - INT64_WRAP
def _hash_ngram_batch_gpu(contexts: torch.Tensor) -> torch.Tensor:
"""
GPU batch hash. contexts: (batch, n) int64 tensor β†’ (batch,) int64 hashes.
Same FNV-1a algorithm vectorised in torch ops.
"""
B, N = contexts.shape
h = torch.full((B,), FNV_OFFSET_S, dtype=torch.int64, device=contexts.device)
for i in range(N):
h = h ^ contexts[:, i]
h = h * FNV_PRIME # int64 wraps naturally
return h
class GPUHashTable:
    """
    Immutable GPU hash table for n-gram lookups.
    Built once from CPU dict, all lookups are GPU batch ops.
    Storage: sorted array of (hash, count) pairs.
    Lookup: searchsorted on hash array → O(log N) per query, fully parallel.
    Collision rate with FNV-1a on int64: ~1/2^64 per pair — negligible.

    Invariant: `hashes` is sorted ascending; `counts` (and optionally
    `continuation_counts`) are aligned with it. Call exactly one of
    `build_from_dict` / `build_context_table` before any lookup.
    """
    def __init__(self):
        # Sorted hash keys; parallel value arrays. None until built.
        self.hashes: Optional[torch.Tensor] = None
        self.counts: Optional[torch.Tensor] = None
        # Only populated by build_context_table (unique-continuation counts).
        self.continuation_counts: Optional[torch.Tensor] = None
        self.total: int = 0   # sum of all stored counts
        self.size: int = 0    # number of entries
    def build_from_dict(self, count_dict: Dict[tuple, Dict[int, int]], device=DEV):
        """Convert {context: {next_tok: count}} to GPU sorted hash table.

        Each (context + next_tok) full n-gram is hashed to one int64 key.
        Returns self for chaining.
        """
        entries = []
        total = 0
        for ctx, nexts in count_dict.items():
            for next_tok, cnt in nexts.items():
                key = ctx + (next_tok,)
                h = _hash_ngram_py(key)
                entries.append((h, cnt))
                total += cnt
        if not entries:
            # Empty table: keep zero-length tensors so lookups short-circuit.
            self.hashes = torch.empty(0, dtype=torch.int64, device=device)
            self.counts = torch.empty(0, dtype=torch.int64, device=device)
            self.total = 0; self.size = 0
            return self
        # Sort by hash only — required by torch.searchsorted in batch_lookup.
        entries.sort(key=lambda x: x[0])
        self.hashes = torch.tensor([e[0] for e in entries], dtype=torch.int64, device=device)
        self.counts = torch.tensor([e[1] for e in entries], dtype=torch.int64, device=device)
        self.total = total
        self.size = len(entries)
        return self
    def build_context_table(self, count_dict: Dict[tuple, Dict[int, int]], device=DEV):
        """Build context-level table: hash(context) → total count + unique continuations."""
        entries = []
        for ctx, nexts in count_dict.items():
            h = _hash_ngram_py(ctx)
            total = sum(nexts.values())
            n_unique = len(nexts)
            entries.append((h, total, n_unique))
        if not entries:
            self.hashes = torch.empty(0, dtype=torch.int64, device=device)
            self.counts = torch.empty(0, dtype=torch.int64, device=device)
            self.continuation_counts = torch.empty(0, dtype=torch.int64, device=device)
            self.total = 0; self.size = 0
            return self
        entries.sort(key=lambda x: x[0])
        self.hashes = torch.tensor([e[0] for e in entries], dtype=torch.int64, device=device)
        self.counts = torch.tensor([e[1] for e in entries], dtype=torch.int64, device=device)
        self.continuation_counts = torch.tensor([e[2] for e in entries], dtype=torch.int64, device=device)
        self.total = sum(e[1] for e in entries)
        self.size = len(entries)
        return self
    def batch_lookup(self, hashes: torch.Tensor) -> torch.Tensor:
        """GPU batch lookup. hashes: (B,) int64 → (B,) int64 counts (0 if miss)."""
        if self.size == 0:
            return torch.zeros_like(hashes)
        # searchsorted gives the insertion index; clamp handles keys larger
        # than every stored hash (index == size), then verify exact match.
        idx = torch.searchsorted(self.hashes, hashes).clamp(0, self.size - 1)
        found = (self.hashes[idx] == hashes)
        return torch.where(found, self.counts[idx], torch.zeros_like(hashes))
    def batch_lookup_continuations(self, hashes: torch.Tensor) -> torch.Tensor:
        """Lookup unique continuation count for context hashes (0 if miss)."""
        if self.size == 0 or self.continuation_counts is None:
            return torch.zeros_like(hashes)
        idx = torch.searchsorted(self.hashes, hashes).clamp(0, self.size - 1)
        found = (self.hashes[idx] == hashes)
        return torch.where(found, self.continuation_counts[idx], torch.zeros_like(hashes))
    def memory_bytes(self) -> int:
        """Total bytes held by this table's tensors (device memory)."""
        total = 0
        for t in [self.hashes, self.counts, self.continuation_counts]:
            if t is not None: total += t.nelement() * t.element_size()
        return total
# ═══════════════════════════════════════════════════════════════════
# MARKOV CHAIN LANGUAGE MODEL
# ═══════════════════════════════════════════════════════════════════
class MarkovLM:
    """
    Classical n-gram LM with Modified Kneser-Ney smoothing.
    Training: CPU (dict-based counting, memory-bound).
    Inference/Eval: GPU (hash table batch lookup, vectorised smoothing).
    cpu_counts[order] = {context_tuple: {next_token: count}}
    order 0 = unigram (context = ())
    order 1 = bigram (context = (t,))
    order k = (k+1)-gram (context = (t1,...,tk))

    Lifecycle: train_on_tokens() (or load()) → freeze() → evaluate()/generate().
    """
    def __init__(self, max_order: int = 5):
        self.max_order = max_order
        # Nested default-dicts: cpu_counts[order][context][next_token] -> count.
        self.cpu_counts: List[Dict[tuple, Dict[int, int]]] = [
            defaultdict(lambda: defaultdict(int)) for _ in range(max_order)
        ]
        self.total_tokens = 0
        self.tokens_trained = 0
        # GPU tables (populated by freeze())
        self.gpu_ngram_tables: List[Optional[GPUHashTable]] = [None] * max_order
        self.gpu_context_tables: List[Optional[GPUHashTable]] = [None] * max_order
        self.frozen = False
        # KN discounts per order: (D1, D2, D3+) for counts of 1, 2, >=3.
        self.discounts: List[Tuple[float, float, float]] = [(0.5, 1.0, 1.5)] * max_order
        # Unigram distribution on GPU
        self.gpu_unigram_probs: Optional[torch.Tensor] = None
    # ─────────────── Training (CPU) ───────────────
    def train_on_tokens(self, token_iter, total_tokens: int, save_path: Optional[str] = None,
                        save_every: int = 100_000_000, min_count_prune: int = 2,
                        prune_every: int = 200_000_000):
        """Stream token chunks and accumulate n-gram counts for all orders.

        token_iter yields lists of token ids (see token_stream_chunked).
        Periodically prunes rare entries and checkpoints to `save_path`.
        """
        window: List[int] = []
        pbar = SafeProgress(total_tokens, unit="tok", interval=2_000_000)
        last_save = 0
        last_prune = 0
        print(f"[train] max_order={self.max_order}, target={total_tokens:,}")
        print(f"[train] {get_uk_time()}")
        for chunk in token_iter:
            for t in chunk:
                window.append(t)
                self.total_tokens += 1
                self.tokens_trained += 1
                # For every order, count (context of `order` tokens → newest token).
                for order in range(self.max_order):
                    ctx_len = order
                    if len(window) > ctx_len:
                        ctx = tuple(window[-(ctx_len + 1):-1]) if ctx_len > 0 else ()
                        next_tok = window[-1]
                        self.cpu_counts[order][ctx][next_tok] += 1
                # Keep the rolling window bounded (trim with hysteresis).
                if len(window) > self.max_order + 10:
                    window = window[-(self.max_order + 1):]
                pbar.update(1)
            # Prune and checkpoint at chunk boundaries, not per token.
            if self.tokens_trained - last_prune >= prune_every and min_count_prune > 1:
                self._prune_counts(min_count_prune)
                last_prune = self.tokens_trained
            if save_path and self.tokens_trained - last_save >= save_every:
                self._save_cpu(save_path)
                last_save = self.tokens_trained
        pbar.close()
        self._print_stats()
        if save_path:
            self._save_cpu(save_path)
    def _prune_counts(self, min_count: int):
        """Per-order pruning: keep more at low orders, prune harder at high orders.
        Orders 0-2 (uni/bi/tri): never prune (small, high value)
        Orders 3-4: prune at min_count
        Orders 5+: prune at min_count + (order - 4) to save memory"""
        pruned = 0
        # range starts at 1 but orders <= 2 are skipped below anyway.
        for order in range(1, self.max_order):
            if order <= 2:
                continue # never prune unigram/bigram/trigram
            threshold = min_count + max(0, order - 4)
            to_delete_ctx = []
            for ctx, nexts in self.cpu_counts[order].items():
                low = [t for t, c in nexts.items() if c < threshold]
                for t in low:
                    del nexts[t]; pruned += 1
                if not nexts:
                    # Context left empty after pruning — delete it after the loop.
                    to_delete_ctx.append(ctx)
            for ctx in to_delete_ctx:
                del self.cpu_counts[order][ctx]
        if pruned:
            gc.collect()
            print(f" [prune] Removed {pruned:,} entries (per-order thresholds)")
    def _print_stats(self):
        """Print per-order context/entry counts and a rough CPU memory estimate."""
        print(f"\n{'='*60}")
        print(f" N-GRAM MODEL — {self.max_order}-gram | {self.tokens_trained:,} tokens trained")
        print(f"{'='*60}")
        total_entries = 0
        for order in range(self.max_order):
            n_ctx = len(self.cpu_counts[order])
            n_ent = sum(len(v) for v in self.cpu_counts[order].values())
            total_entries += n_ent
            # Rough estimate: ~8 bytes per int times (ngram length + count).
            bytes_est = n_ent * (8 * (order + 2))
            print(f" {order+1}-gram: {n_ctx:>12,} contexts | {n_ent:>12,} entries | ~{bytes_est/1e9:.2f} GB")
        print(f" TOTAL: {total_entries:>39,} entries")
        print(f"{'='*60}")
    # ─────────────── Freeze: CPU → GPU ───────────────
    def freeze(self, device=DEV, prune_threshold: int = 1):
        """Convert CPU count dicts to GPU hash tables; required before eval/generate."""
        print(f"\n[freeze] Building GPU hash tables on {device}...")
        t0 = time.time()
        if prune_threshold > 1:
            self._prune_counts(prune_threshold)
        self._compute_kn_discounts()
        total_gpu_bytes = 0
        for order in range(self.max_order):
            counts = self.cpu_counts[order]
            if not counts: continue
            # Full-ngram counts and context-level totals/continuations.
            gt = GPUHashTable().build_from_dict(counts, device=device)
            self.gpu_ngram_tables[order] = gt
            ct = GPUHashTable().build_context_table(counts, device=device)
            self.gpu_context_tables[order] = ct
            mem = gt.memory_bytes() + ct.memory_bytes()
            total_gpu_bytes += mem
            print(f" {order+1}-gram: {gt.size:,} entries, {ct.size:,} contexts ({mem/1e6:.1f} MB GPU)")
        # Unigram distribution
        if self.cpu_counts[0] and () in self.cpu_counts[0]:
            unigram = self.cpu_counts[0][()]
            probs = torch.zeros(VOCAB, dtype=torch.float32, device=device)
            for tok_id, cnt in unigram.items():
                if tok_id < VOCAB:
                    probs[tok_id] = cnt
            # Normalise; clamp guards against an all-zero distribution.
            self.gpu_unigram_probs = probs / probs.sum().clamp(min=1)
            total_gpu_bytes += probs.nelement() * 4
        print(f"[freeze] Done in {time.time()-t0:.1f}s. GPU: {total_gpu_bytes/1e6:.1f} MB")
        self.frozen = True
    def free_cpu_counts(self):
        """Drop the CPU dicts after freeze() to reclaim RAM (model no longer saveable)."""
        for i in range(self.max_order):
            self.cpu_counts[i] = {}
        gc.collect()
        print("[free] CPU dicts released.")
    def _compute_kn_discounts(self):
        """Estimate Modified Kneser-Ney discounts (D1, D2, D3+) per order
        from the counts-of-counts n1..n4 (Chen & Goodman formulas)."""
        for order in range(self.max_order):
            n1 = n2 = n3 = n4 = 0
            for ctx, nexts in self.cpu_counts[order].items():
                for cnt in nexts.values():
                    if cnt == 1: n1 += 1
                    elif cnt == 2: n2 += 1
                    elif cnt == 3: n3 += 1
                    elif cnt == 4: n4 += 1
            if n1 == 0 or n2 == 0:
                # Too little data to estimate — fall back to fixed discounts.
                self.discounts[order] = (0.5, 1.0, 1.5)
                continue
            Y = n1 / (n1 + 2 * n2)
            # Clamp each discount into a sane, monotone range.
            D1 = max(0.01, min(1 - 2 * Y * (n2 / max(n1, 1)), 0.99))
            D2 = max(D1, min(2 - 3 * Y * (n3 / max(n2, 1)), 1.99))
            D3 = max(D2, min(3 - 4 * Y * (n4 / max(n3, 1)), 2.99))
            self.discounts[order] = (D1, D2, D3)
            print(f" KN {order+1}-gram: D1={D1:.3f} D2={D2:.3f} D3+={D3:.3f}")
    # ─────────────── GPU Batch Probability ───────────────
    @torch.no_grad()
    def batch_log_probs(self, contexts: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute log P(target | context) via interpolated Modified KN on GPU.
        contexts: (B, max_order-1) int64, left-padded with -1
        targets: (B,) int64
        Returns: (B,) float32 log probabilities

        Iterates orders low → high; each order overwrites positions where its
        context was observed, interpolating with the lower-order estimate.
        """
        B = targets.shape[0]
        device = targets.device
        # Start with uniform
        log_probs = torch.full((B,), math.log(1.0 / VOCAB), dtype=torch.float32, device=device)
        for order in range(self.max_order):
            gt = self.gpu_ngram_tables[order]
            ct = self.gpu_context_tables[order]
            if gt is None or ct is None or gt.size == 0:
                continue
            ctx_len = order
            if ctx_len == 0:
                # Unigram
                if self.gpu_unigram_probs is not None:
                    safe_t = targets.clamp(0, VOCAB - 1)
                    uni_p = self.gpu_unigram_probs[safe_t]
                    valid = uni_p > 0
                    log_probs = torch.where(valid, torch.log(uni_p + 1e-30), log_probs)
                continue
            if ctx_len > contexts.shape[1]:
                continue
            ctx = contexts[:, -ctx_len:]
            # -1 marks padding; the full context must be real tokens.
            has_ctx = (ctx >= 0).all(dim=1)
            if not has_ctx.any():
                continue
            # Hash lookups
            full_ngram = torch.cat([ctx, targets.unsqueeze(1)], dim=1)
            ngram_counts = gt.batch_lookup(_hash_ngram_batch_gpu(full_ngram)).float()
            ctx_hashes = _hash_ngram_batch_gpu(ctx)
            ctx_totals = ct.batch_lookup(ctx_hashes).float()
            ctx_uniques = ct.batch_lookup_continuations(ctx_hashes).float()
            # KN discount, chosen by the raw count bucket (1, 2, >=3).
            D1, D2, D3 = self.discounts[order]
            discount = torch.where(ngram_counts >= 3, D3,
                       torch.where(ngram_counts >= 2, D2,
                       torch.where(ngram_counts >= 1, D1, 0.0)))
            numerator = (ngram_counts - discount).clamp(min=0)
            denominator = ctx_totals.clamp(min=1)
            # Interpolation weight (back-off mass redistributed to lower order).
            gamma = (D3 * ctx_uniques) / denominator
            gamma = gamma.clamp(0, 1)
            # Interpolate: P_high = discounted/denom + gamma * P_lower
            p_lower = log_probs.exp()
            p_combined = numerator / denominator + gamma * p_lower
            valid = has_ctx & (ctx_totals > 0)
            log_probs = torch.where(valid, torch.log(p_combined.clamp(min=1e-30)), log_probs)
        return log_probs
    # ─────────────── Evaluation ───────────────
    @torch.no_grad()
    def evaluate(self, eval_source: str, eval_tokens: int, batch_size: int = 1024):
        """Evaluate perplexity on held-out data using GPU batch processing."""
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1
        print(f"\n[eval] Source: {eval_source}")
        print(f"[eval] {eval_tokens:,} tokens, batch={batch_size}")
        print(f"[eval] {get_uk_time()}")
        # Collect eval tokens (need ctx_len extra for the first position's context).
        print("[eval] Collecting tokens...")
        all_tokens = []
        for chunk in token_stream_chunked(eval_source, eval_tokens + ctx_len, chunk_size=8192):
            all_tokens.extend(chunk)
            if len(all_tokens) >= eval_tokens + ctx_len:
                break
        all_tokens = all_tokens[:eval_tokens + ctx_len]
        tokens_t = torch.tensor(all_tokens, dtype=torch.int64)
        n_eval = len(all_tokens) - ctx_len
        print(f"[eval] {n_eval:,} evaluation positions")
        total_log_prob = 0.0
        total = 0
        pbar = SafeProgress(n_eval, unit="tok", interval=50_000)
        for start in range(0, n_eval, batch_size):
            end = min(start + batch_size, n_eval)
            B = end - start
            # Build batched contexts and targets on GPU
            # contexts[i] = tokens[start+i : start+i+ctx_len]
            # targets[i] = tokens[start+i+ctx_len]
            indices = torch.arange(start, end)
            ctx_indices = indices.unsqueeze(1) + torch.arange(ctx_len).unsqueeze(0) # (B, ctx_len)
            tgt_indices = indices + ctx_len
            contexts = tokens_t[ctx_indices].to(DEV)
            targets = tokens_t[tgt_indices].to(DEV)
            log_probs = self.batch_log_probs(contexts, targets)
            total_log_prob += log_probs.sum().item()
            total += B
            pbar.update(B)
        pbar.close()
        avg_lp = total_log_prob / total
        ppl = math.exp(-avg_lp)
        print(f"\n{'='*55}")
        print(f" PERPLEXITY EVALUATION")
        print(f"{'='*55}")
        print(f" Tokens: {total:,}")
        print(f" Avg log P: {avg_lp:.4f}")
        print(f" Perplexity: {ppl:.2f}")
        print(f" {get_uk_time()}")
        print(f"{'='*55}")
        return {"perplexity": ppl, "avg_log_prob": avg_lp, "tokens": total}
    @torch.no_grad()
    def evaluate_accuracy(self, eval_source: str, eval_tokens: int,
                          batch_size: int = 256, top_k_check: int = 1000):
        """
        Next-token accuracy: for each position, check if the highest-probability
        token (among top-K unigram candidates + target) matches ground truth.
        GPU-parallel: each position checks top_k candidates simultaneously.
        """
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1
        print(f"\n[accuracy] top_k_check={top_k_check}, {eval_tokens:,} tokens")
        # Top-K unigram tokens as candidate pool
        if self.gpu_unigram_probs is not None:
            _, topk = self.gpu_unigram_probs.topk(min(top_k_check, VOCAB))
        else:
            topk = torch.arange(top_k_check, device=DEV)
        # Collect tokens
        all_tokens = []
        for chunk in token_stream_chunked(eval_source, eval_tokens + ctx_len, chunk_size=8192):
            all_tokens.extend(chunk)
            if len(all_tokens) >= eval_tokens + ctx_len:
                break
        all_tokens = all_tokens[:eval_tokens + ctx_len]
        tokens_t = torch.tensor(all_tokens, dtype=torch.int64)
        n_eval = len(all_tokens) - ctx_len
        correct = 0
        total = 0
        pbar = SafeProgress(n_eval, unit="tok", interval=20_000)
        for start in range(0, n_eval, batch_size):
            end = min(start + batch_size, n_eval)
            B = end - start
            indices = torch.arange(start, end)
            ctx_indices = indices.unsqueeze(1) + torch.arange(ctx_len).unsqueeze(0)
            tgt_indices = indices + ctx_len
            contexts = tokens_t[ctx_indices].to(DEV)
            targets = tokens_t[tgt_indices].to(DEV)
            # Score the target
            target_lp = self.batch_log_probs(contexts, targets)
            # Score top-K candidates: expand contexts × candidates
            K = topk.shape[0]
            ctx_exp = contexts.unsqueeze(1).expand(B, K, ctx_len).reshape(B * K, ctx_len)
            cand_exp = topk.unsqueeze(0).expand(B, K).reshape(B * K)
            cand_lp = self.batch_log_probs(ctx_exp, cand_exp).reshape(B, K)
            best_cand_lp, _ = cand_lp.max(dim=1)
            # Target wins if its log prob >= best candidate's
            correct += (target_lp >= best_cand_lp).sum().item()
            total += B
            # Refresh the running-accuracy postfix roughly every 10k positions.
            if total % 10000 < batch_size:
                pbar.set_postfix(acc=f"{100*correct/max(total,1):.2f}%")
            pbar.update(B)
        pbar.close()
        acc = correct / max(total, 1)
        print(f"\n{'='*55}")
        print(f" ACCURACY (top-{top_k_check} check)")
        print(f"{'='*55}")
        print(f" Correct: {correct:,} / {total:,}")
        print(f" Accuracy: {100*acc:.2f}%")
        print(f" Target: 47% (Infini-gram on Pile)")
        print(f" {get_uk_time()}")
        print(f"{'='*55}")
        return {"accuracy": acc, "correct": correct, "total": total}
    # ─────────────── Generation ───────────────
    @torch.no_grad()
    def generate(self, prompt: str, max_new: int = 200, temperature: float = 0.8,
                 top_k: int = 50, top_p: float = 0.9):
        """Sample up to `max_new` tokens after `prompt` with temperature,
        top-k and nucleus (top-p) filtering over a unigram candidate pool."""
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1
        ids = tok.encode(prompt)
        # Candidate pool: most frequent unigrams (10x top_k for headroom).
        n_cands = min(top_k * 10, VOCAB)
        print(f"\n[gen] Prompt: '{prompt}' ({len(ids)} tokens)")
        t0 = time.time()
        # Get candidate pool
        if self.gpu_unigram_probs is not None:
            _, candidates = self.gpu_unigram_probs.topk(n_cands)
        else:
            candidates = torch.arange(n_cands, device=DEV)
        for _ in range(max_new):
            if len(ids) >= ctx_len:
                ctx = ids[-ctx_len:]
            else:
                # Left-pad short prompts with -1 (treated as "no context").
                ctx = [-1] * (ctx_len - len(ids)) + ids
            ctx_t = torch.tensor([ctx], dtype=torch.int64, device=DEV).expand(n_cands, ctx_len)
            log_probs = self.batch_log_probs(ctx_t, candidates)
            # Temperature
            probs = (log_probs / max(temperature, 1e-8)).softmax(0)
            # Top-k
            if top_k > 0 and top_k < n_cands:
                vals, idx = probs.topk(top_k)
                mask = torch.zeros_like(probs)
                mask.scatter_(0, idx, vals)
                probs = mask
            # Top-p
            if top_p < 1.0:
                sp, si = probs.sort(descending=True)
                cum = sp.cumsum(0)
                cutoff = cum > top_p
                cutoff[0] = False  # always keep the single most likely token
                sp[cutoff] = 0
                probs = torch.zeros_like(probs).scatter_(0, si, sp)
            if probs.sum() == 0:
                # Everything filtered out — fall back to the top candidate.
                next_tok = candidates[0].item()
            else:
                probs = probs / probs.sum()
                next_tok = candidates[probs.multinomial(1).item()].item()
            ids.append(next_tok)
            if next_tok == EOS:
                break
        elapsed = time.time() - t0
        gen_n = len(ids) - len(tok.encode(prompt))
        text = tok.decode(ids, skip_special_tokens=True)
        print(f"\n{text}")
        print(f"\n[{elapsed:.2f}s | {gen_n} tokens | {gen_n/max(elapsed,0.01):.1f} tok/s]")
        return text
    # ─────────────── Save / Load ───────────────
    def _save_cpu(self, path: str):
        """Pickle the CPU-side state atomically (write .tmp, then rename)."""
        p = Path(path).with_suffix('.cpu.pkl')
        p.parent.mkdir(parents=True, exist_ok=True)
        data = {
            'max_order': self.max_order,
            # Plain dicts only: defaultdict lambdas cannot be pickled.
            'cpu_counts': [{k: dict(v) for k, v in d.items()} for d in self.cpu_counts],
            'total_tokens': self.total_tokens,
            'tokens_trained': self.tokens_trained,
            'discounts': self.discounts,
        }
        tmp = p.with_suffix('.tmp')
        with open(tmp, 'wb') as f:
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
        tmp.replace(p)
        print(f"[save] {p} ({p.stat().st_size/1e6:.1f} MB)")
    def save(self, path: str):
        """Public alias for the CPU checkpoint writer."""
        self._save_cpu(path)
    @classmethod
    def load(cls, path: str) -> 'MarkovLM':
        """Load a pickled checkpoint ('.cpu.pkl' preferred, '.pkl' fallback),
        rebuild the defaultdict structure, and freeze to GPU."""
        p = Path(path)
        for suffix in ['.cpu.pkl', '.pkl']:
            candidate = p.with_suffix(suffix)
            if candidate.exists():
                p = candidate; break
        print(f"[load] {p}...")
        with open(p, 'rb') as f:
            data = pickle.load(f)
        model = cls(max_order=data['max_order'])
        model.total_tokens = data['total_tokens']
        model.tokens_trained = data['tokens_trained']
        model.discounts = data.get('discounts', model.discounts)
        for order in range(model.max_order):
            raw = data['cpu_counts'][order]
            dd = defaultdict(lambda: defaultdict(int))
            for ctx, nexts in raw.items():
                dd[ctx] = defaultdict(int, nexts)
            model.cpu_counts[order] = dd
        model._print_stats()
        print("[load] Freezing to GPU...")
        model.freeze()
        return model
    def status(self):
        """Print counting stats plus GPU memory usage when frozen."""
        self._print_stats()
        if self.frozen:
            gpu_mem = sum(
                (gt.memory_bytes() if gt else 0) + (ct.memory_bytes() if ct else 0)
                for gt, ct in zip(self.gpu_ngram_tables, self.gpu_context_tables)
            )
            print(f" GPU: {gpu_mem/1e6:.1f} MB")
        print(f" Frozen: {self.frozen} | {get_uk_time()}")
# ═══════════════════════════════════════════════════════════════════
# STATUS FILE
# ═══════════════════════════════════════════════════════════════════
STATUS_FILE = "/workspace/markov_status.json"
def write_status(tokens_trained, phase, tok_per_sec=0, extra=None):
    """Best-effort dump of training progress to STATUS_FILE as JSON.

    Never raises: status reporting must not take down a long training run,
    so any failure (unwritable path, serialization error) is swallowed.
    """
    try:
        data = {"tokens_trained": tokens_trained, "phase": phase,
                "tok_per_sec": tok_per_sec, "updated": time.time(), "uk_time": get_uk_time()}
        if extra:
            data.update(extra)
        with open(STATUS_FILE, 'w') as f:
            json.dump(data, f)
    except Exception:
        # Was a bare `except:`, which would also trap KeyboardInterrupt and
        # SystemExit — those must propagate so the process can be stopped.
        pass
# ═══════════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════════
def cmd_train(args):
    """CLI: count n-grams over the training stream, then freeze and save."""
    sep = '=' * 60
    print(sep)
    print(" MARKOV CHAIN LM — HYBRID CPU+GPU")
    print(sep)
    print(f" Order: {args.max_order}-gram")
    print(f" Tokens: {args.tokens:,}")
    print(f" Source: {args.source[:80]}...")
    print(f" Save: {args.save}")
    print(f" Prune: min_count={args.min_count}")
    print(f" Device: CPU (counting) → {DEV} (inference)")
    print(f" {get_uk_time()}")
    print(sep)
    checkpoint_exists = Path(args.save).with_suffix('.cpu.pkl').exists()
    if not (args.resume and checkpoint_exists):
        model = MarkovLM(max_order=args.max_order)
    else:
        print("[resume] Loading existing...")
        model = MarkovLM.load(args.save)
        remaining = args.tokens - model.tokens_trained
        if remaining <= 0:
            print(f"[resume] Already at {model.tokens_trained:,} >= {args.tokens:,}")
            return model
        print(f"[resume] {remaining:,} tokens remaining")
    stream = token_stream_chunked(args.source, args.tokens, chunk_size=16384, seed=42, field=args.field)
    model.train_on_tokens(stream, total_tokens=args.tokens, save_path=args.save,
                          save_every=args.save_every, min_count_prune=args.min_count,
                          prune_every=args.prune_every)
    print("\n[train] Freezing to GPU...")
    model.freeze(prune_threshold=args.freeze_prune)
    model.save(args.save)
    return model
def cmd_eval(args):
    """CLI: load a frozen model and report perplexity (optionally accuracy)."""
    model = MarkovLM.load(args.model)
    results = model.evaluate(eval_source=args.eval_source, eval_tokens=args.eval_tokens,
                             batch_size=args.batch_size)
    if not args.accuracy:
        return results
    acc_budget = min(args.eval_tokens, args.accuracy_tokens)
    acc = model.evaluate_accuracy(eval_source=args.eval_source,
                                  eval_tokens=acc_budget,
                                  batch_size=args.acc_batch_size,
                                  top_k_check=args.top_k_check)
    results.update(acc)
    return results
def cmd_generate(args):
    """CLI: sample a continuation of `args.prompt` from the loaded model."""
    lm = MarkovLM.load(args.model)
    lm.generate(prompt=args.prompt, max_new=args.max_new,
                temperature=args.temperature, top_k=args.top_k, top_p=args.top_p)
def cmd_status(args):
    """CLI: print model stats if a checkpoint exists, else the latest
    status-file snapshot (best-effort)."""
    p = Path(args.model)
    for s in ['.cpu.pkl', '.pkl']:
        if p.with_suffix(s).exists():
            model = MarkovLM.load(args.model)
            model.status()
            return
    print(f"No model at {args.model}")
    try:
        with open(STATUS_FILE) as f:
            s = json.load(f)
            print(f"Phase: {s.get('phase')} | Tokens: {s.get('tokens_trained',0):,} | {s.get('uk_time','?')}")
    except Exception:
        # Was a bare `except:` — keep the best-effort fallback but stop
        # swallowing KeyboardInterrupt/SystemExit.
        print("No status file.")
def main():
    """Parse CLI arguments and dispatch to train / eval / generate / status."""
    ap = argparse.ArgumentParser(description="Hybrid CPU+GPU Markov Chain LM")
    sub = ap.add_subparsers(dest="cmd", required=True)
    # train: stream tokens, count n-grams, freeze, save
    tr = sub.add_parser("train")
    tr.add_argument("--max_order", type=int, default=10)
    tr.add_argument("--tokens", type=int, default=2_000_000_000)
    tr.add_argument("--source", default=DEFAULT_SOURCES)
    tr.add_argument("--field", default="text")
    tr.add_argument("--save", default="markov_model")
    tr.add_argument("--save_every", type=int, default=100_000_000)
    tr.add_argument("--min_count", type=int, default=2)
    tr.add_argument("--prune_every", type=int, default=300_000_000)
    tr.add_argument("--freeze_prune", type=int, default=2)
    tr.add_argument("--resume", action="store_true")
    # eval: perplexity (and optional next-token accuracy)
    ev = sub.add_parser("eval")
    ev.add_argument("--model", required=True)
    ev.add_argument("--eval_source", default=EVAL_SOURCE)
    ev.add_argument("--eval_tokens", type=int, default=1_000_000)
    ev.add_argument("--batch_size", type=int, default=1024)
    ev.add_argument("--accuracy", action="store_true")
    ev.add_argument("--accuracy_tokens", type=int, default=200_000)
    ev.add_argument("--acc_batch_size", type=int, default=128)
    ev.add_argument("--top_k_check", type=int, default=1000)
    # generate: sample a continuation of a prompt
    gen = sub.add_parser("generate")
    gen.add_argument("--model", required=True)
    gen.add_argument("--prompt", required=True)
    gen.add_argument("--max_new", type=int, default=200)
    gen.add_argument("--temperature", type=float, default=0.8)
    gen.add_argument("--top_k", type=int, default=50)
    gen.add_argument("--top_p", type=float, default=0.9)
    # status: report checkpoint stats or the status file
    st = sub.add_parser("status")
    st.add_argument("--model", default="markov_model")
    args = ap.parse_args()
    # Dispatch table keyed on the chosen subcommand.
    {"train": cmd_train, "eval": cmd_eval, "generate": cmd_generate, "status": cmd_status}[args.cmd](args)
if __name__ == "__main__":
    main()