# markov-lm-demo / app.py  (Hugging Face Space by OpenTransformer)
# Last commit 8e22592 (verified):
#   "Fix OOM: drop 5-gram dict immediately after pickle load, prune 3/4-gram, conservative memory"
#!/usr/bin/env python3
"""
Markov Chain Language Model β€” Interactive Demo
OpenTransformers Ltd | Part of AGILLM Research
Classical n-gram LM with Modified Kneser-Ney smoothing.
GPU hash tables (sorted int64 + searchsorted) for parallel inference.
Runs on CPU for HF Spaces compatibility.
"""
import os, sys, math, time, pickle, gc, ctypes
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
# ─── Force CPU for HF Spaces ───
# All tensors live on CPU; the "GPU hash table" technique works on any device.
DEV = torch.device("cpu")
# ─── Memory management ───
# HF Spaces free tier has 16GB. After Python + torch + gradio + pickle load,
# roughly 10-11GB is available. We budget conservatively.
# Overridable via the MEM_BUDGET_GB env var; value is in gigabytes.
TOTAL_MEM_BUDGET_GB = float(os.environ.get("MEM_BUDGET_GB", "10.0"))
def _malloc_trim():
"""Force glibc to return freed memory to OS."""
try:
ctypes.CDLL("libc.so.6").malloc_trim(0)
except Exception:
pass
def _force_free():
    """Aggressive memory reclaim: two GC passes, then malloc_trim.

    The second collect() pass picks up objects released by finalizers
    that ran during the first pass.
    """
    for _ in range(2):
        gc.collect()
    _malloc_trim()
# ─── Tokenizer ───
# Tokenizer repo id; overridable via the TOKENIZER_ID env var (default: gpt2).
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "gpt2")
def _load_tokenizer():
    """Load the fast tokenizer named by TOKENIZER_ID, ensuring a pad token exists."""
    # Import locally so transformers' startup cost/logging stays contained.
    from transformers import AutoTokenizer, logging as hf_log
    hf_log.set_verbosity_error()
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    return tokenizer
# Load the tokenizer once at import time; VOCAB is sized from the maximum
# token id so index-based lookups can clamp safely.
tok = _load_tokenizer()
VOCAB = max(tok.get_vocab().values()) + 1
# End-of-sequence id; fall back to SEP for tokenizers that lack an EOS token.
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
# ─── FNV-1a Hashing ───
# Standard 64-bit FNV-1a constants. Hashes are folded into the signed int64
# range so they can be stored in torch int64 tensors.
FNV_OFFSET = 14695981039346656037
FNV_PRIME = 1099511628211
MASK64 = (1 << 64) - 1
INT64_MAX = (1 << 63) - 1
INT64_WRAP = 1 << 64
# FNV offset basis pre-converted to its signed (two's-complement) int64 value.
FNV_OFFSET_S = FNV_OFFSET - INT64_WRAP
def _hash_ngram_batch_gpu(contexts: torch.Tensor) -> torch.Tensor:
    """Vectorized FNV-1a over each row of `contexts`.

    int64 multiplication wraps in two's complement, matching the masked
    unsigned arithmetic of the pure-Python hash.
    """
    batch, width = contexts.shape
    hashed = torch.full((batch,), FNV_OFFSET_S, dtype=torch.int64, device=contexts.device)
    for col in range(width):
        hashed = (hashed ^ contexts[:, col]) * FNV_PRIME
    return hashed
class GPUHashTable:
    """Immutable hash table: sorted int64 keys queried via binary search.

    `hashes` is kept sorted so torch.searchsorted can locate keys; the
    value tensors (`counts`, `continuation_counts`) share that ordering.
    Missing keys look up as zero.
    """
    def __init__(self):
        self.hashes: Optional[torch.Tensor] = None
        self.counts: Optional[torch.Tensor] = None
        self.continuation_counts: Optional[torch.Tensor] = None
        self.total: int = 0
        self.size: int = 0

    def _locate(self, hashes: torch.Tensor):
        """Return (positions, hit-mask) for a batch of query hashes."""
        pos = torch.searchsorted(self.hashes, hashes).clamp(0, self.size - 1)
        return pos, self.hashes[pos] == hashes

    def batch_lookup(self, hashes: torch.Tensor) -> torch.Tensor:
        """Count for each query hash; zero where the hash is absent."""
        if self.size == 0:
            return torch.zeros_like(hashes)
        pos, hit = self._locate(hashes)
        return torch.where(hit, self.counts[pos], torch.zeros_like(hashes))

    def batch_lookup_continuations(self, hashes: torch.Tensor) -> torch.Tensor:
        """Continuation count for each query hash; zero where absent."""
        if self.size == 0 or self.continuation_counts is None:
            return torch.zeros_like(hashes)
        pos, hit = self._locate(hashes)
        return torch.where(hit, self.continuation_counts[pos], torch.zeros_like(hashes))

    def memory_bytes(self) -> int:
        """Total bytes consumed by the backing tensors."""
        return sum(
            t.nelement() * t.element_size()
            for t in (self.hashes, self.counts, self.continuation_counts)
            if t is not None
        )
def _hash_ngram_py(ngram: tuple) -> int:
    """FNV-1a hash of a token tuple, folded into the signed int64 range.

    Must produce the same value as _hash_ngram_batch_gpu for one row.
    """
    h = FNV_OFFSET
    for tok_id in ngram:
        h = ((h ^ (tok_id & MASK64)) * FNV_PRIME) & MASK64
    # Reinterpret the unsigned 64-bit result as signed to match torch int64.
    return h - INT64_WRAP if h > INT64_MAX else h
# ─── Memory-aware freezing ───
# Per-entry tensor costs used to predict freeze-time memory pressure.
BYTES_PER_NGRAM = 16 # int64 hash + int64 count
BYTES_PER_CTX = 24 # int64 hash + int64 count + int64 continuation
SORT_OVERHEAD = 3 # argsort needs ~3x the tensor size transiently
def _count_entries_at_threshold(d, threshold):
"""Count how many n-gram entries have count >= threshold."""
n = 0
for nexts in d.values():
for cnt in nexts.values():
if cnt >= threshold:
n += 1
return n
def _estimate_freeze_peak(n_entries, n_contexts):
    """Estimate peak memory (bytes) during freeze for an order.

    Peak = raw tensors being built + argsort workspace + context tensors;
    the SORT_OVERHEAD factor covers argsort's transient copies.
    """
    return SORT_OVERHEAD * (n_entries * BYTES_PER_NGRAM + n_contexts * BYTES_PER_CTX)
def _estimate_python_dict_bytes(d):
"""Rough estimate of Python dict memory for one order's raw data.
Each outer dict entry: ~232 bytes (key tuple + value dict overhead).
Each inner entry: ~128 bytes (int key + int value + dict entry).
"""
n_contexts = len(d)
n_entries = sum(len(nexts) for nexts in d.values())
return n_contexts * 232 + n_entries * 128
def _pick_threshold_for_budget(d, budget_bytes):
    """Find the lowest prune threshold whose freeze peak fits in budget.

    Returns (threshold, surviving_entry_count). Falls back to the harshest
    candidate threshold when nothing fits.
    """
    n_contexts = len(d)
    candidates = (1, 2, 3, 5, 8, 10, 15, 25, 50)
    for thresh in candidates:
        surviving = _count_entries_at_threshold(d, thresh)
        if _estimate_freeze_peak(surviving, n_contexts) <= budget_bytes:
            return thresh, surviving
    # Nothing fits — report the harshest threshold with its entry count.
    surviving = _count_entries_at_threshold(d, candidates[-1])
    return candidates[-1], surviving
def _freeze_order(order, d, device, prune_threshold=1):
    """Freeze a single n-gram order into GPU hash tables.
    Writes directly into pre-allocated tensors. Catches OOM to prevent
    corrupted tables from poisoning higher-order lookups.

    Args:
        order: 0-based order index (used only in log messages).
        d: mapping of context tuple -> {next_token_id: count}.
        device: target torch device for the finished tables.
        prune_threshold: keep only n-grams with count >= this value
            (contexts themselves are never pruned).

    Returns:
        (ngram_table, context_table) GPUHashTable pair; both empty when
        `d` is empty or an out-of-memory error occurred mid-build.
    """
    if not d:
        return GPUHashTable(), GPUHashTable()
    threshold = max(prune_threshold, 1)
    # Count surviving entries first so tensors can be pre-sized exactly.
    n_ngrams = 0
    for nexts in d.values():
        for cnt in nexts.values():
            if cnt >= threshold:
                n_ngrams += 1
    n_contexts = len(d)
    gt = GPUHashTable()
    ct = GPUHashTable()
    try:
        # Build ngram table: one (hash, count) row per surviving n-gram.
        if n_ngrams > 0:
            hashes = torch.empty(n_ngrams, dtype=torch.int64)
            counts = torch.empty(n_ngrams, dtype=torch.int64)
            idx = 0
            for ctx, nexts in d.items():
                for next_tok, cnt in nexts.items():
                    if cnt >= threshold:
                        # Hash covers the full n-gram: context + next token.
                        hashes[idx] = _hash_ngram_py(ctx + (next_tok,))
                        counts[idx] = cnt
                        idx += 1
            # Sort by hash so lookups can use searchsorted (binary search).
            sort_idx = hashes.argsort()
            gt.hashes = hashes[sort_idx].to(device)
            gt.counts = counts[sort_idx].to(device)
            gt.total = counts.sum().item()
            gt.size = n_ngrams
            # Drop build-time tensors immediately to cap peak RSS.
            del hashes, counts, sort_idx
            _force_free()
        # Build context table: total count and unique-successor count per context.
        if n_contexts > 0:
            c_hashes = torch.empty(n_contexts, dtype=torch.int64)
            c_counts = torch.empty(n_contexts, dtype=torch.int64)
            c_uniques = torch.empty(n_contexts, dtype=torch.int64)
            idx = 0
            for ctx, nexts in d.items():
                c_hashes[idx] = _hash_ngram_py(ctx)
                # NOTE(review): context totals sum ALL successors (unpruned),
                # so pruned n-grams still contribute to the denominator.
                c_counts[idx] = sum(nexts.values())
                c_uniques[idx] = len(nexts)
                idx += 1
            sort_idx = c_hashes.argsort()
            ct.hashes = c_hashes[sort_idx].to(device)
            ct.counts = c_counts[sort_idx].to(device)
            ct.continuation_counts = c_uniques[sort_idx].to(device)
            ct.total = c_counts.sum().item()
            ct.size = n_contexts
            del c_hashes, c_counts, c_uniques, sort_idx
            _force_free()
    except (MemoryError, RuntimeError) as e:
        # A partially-built table would corrupt lookups; discard both halves.
        print(f" [OOM] Failed to freeze order {order+1}: {e}")
        print(f" [OOM] Skipping this order to prevent corruption")
        gt = GPUHashTable()
        ct = GPUHashTable()
        _force_free()
    return gt, ct
class MarkovLM:
    """N-gram LM with Modified Kneser-Ney smoothing.

    Counts are frozen into immutable GPUHashTable pairs (one n-gram table
    and one context table per order). Inference interpolates from the
    unigram level up to max_order-grams in batch_log_probs.
    """
    def __init__(self, max_order: int = 5):
        self.max_order = max_order
        # Raw Python count dicts; left empty here (populated by training code).
        self.cpu_counts: List[Dict] = []
        self.total_tokens = 0
        self.tokens_trained = 0
        # Per-order frozen tables; index k holds the (k+1)-gram tables.
        self.gpu_ngram_tables: List[Optional[GPUHashTable]] = [None] * max_order
        self.gpu_context_tables: List[Optional[GPUHashTable]] = [None] * max_order
        self.frozen = False
        # (D1, D2, D3+) discounts per order; defaults are overwritten by
        # _estimate_discounts when count-of-count statistics allow.
        self.discounts: List[Tuple[float, float, float]] = [(0.5, 1.0, 1.5)] * max_order
        self.gpu_unigram_probs: Optional[torch.Tensor] = None
    def _estimate_discounts(self):
        """Estimate per-order (D1, D2, D3+) discounts from count-of-counts
        (closed-form modified Kneser-Ney estimates)."""
        for order in range(self.max_order):
            gt = self.gpu_ngram_tables[order]
            if gt is None or gt.size == 0:
                continue
            counts = gt.counts
            # n_k = number of n-grams observed exactly k times.
            n1 = (counts == 1).sum().item()
            n2 = (counts == 2).sum().item()
            n3 = (counts == 3).sum().item()
            n4 = (counts == 4).sum().item()
            if n1 == 0 or n2 == 0:
                # Pruned orders may lack singletons/doubletons; keep defaults.
                continue
            Y = n1 / (n1 + 2 * n2)
            # Clamp each discount into a sane range and keep D1 <= D2 <= D3.
            D1 = max(0.01, min(1 - 2 * Y * (n2 / max(n1, 1)), 0.99))
            D2 = max(D1, min(2 - 3 * Y * (n3 / max(n2, 1)), 1.99))
            D3 = max(D2, min(3 - 4 * Y * (n4 / max(n3, 1)), 2.99))
            self.discounts[order] = (D1, D2, D3)
    @torch.no_grad()
    def batch_log_probs(self, contexts: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Log P(target | context) for a batch, interpolating low -> high order.

        Args:
            contexts: (B, ctx_len) int64 token ids; positions < 0 denote
                missing (left-padded) history.
            targets: (B,) int64 candidate next-token ids.

        Returns:
            (B,) float32 log-probabilities.
        """
        B = targets.shape[0]
        device = targets.device
        # Start from a uniform prior; each observed order refines it.
        log_probs = torch.full((B,), math.log(1.0 / VOCAB), dtype=torch.float32, device=device)
        for order in range(self.max_order):
            gt = self.gpu_ngram_tables[order]
            ct = self.gpu_context_tables[order]
            if gt is None or ct is None or gt.size == 0:
                continue
            ctx_len = order
            if ctx_len == 0:
                # Unigram level: read straight from the precomputed vector.
                if self.gpu_unigram_probs is not None:
                    safe_t = targets.clamp(0, VOCAB - 1)
                    uni_p = self.gpu_unigram_probs[safe_t]
                    valid = uni_p > 0
                    log_probs = torch.where(valid, torch.log(uni_p + 1e-30), log_probs)
                continue
            if ctx_len > contexts.shape[1]:
                continue
            ctx = contexts[:, -ctx_len:]
            # Rows containing any negative id lack enough history for this order.
            has_ctx = (ctx >= 0).all(dim=1)
            if not has_ctx.any():
                continue
            full_ngram = torch.cat([ctx, targets.unsqueeze(1)], dim=1)
            ngram_counts = gt.batch_lookup(_hash_ngram_batch_gpu(full_ngram)).float()
            ctx_hashes = _hash_ngram_batch_gpu(ctx)
            ctx_totals = ct.batch_lookup(ctx_hashes).float()
            ctx_uniques = ct.batch_lookup_continuations(ctx_hashes).float()
            D1, D2, D3 = self.discounts[order]
            # Count-dependent discount: D1 / D2 / D3+ per modified KN.
            discount = torch.where(ngram_counts >= 3, D3,
                torch.where(ngram_counts >= 2, D2,
                torch.where(ngram_counts >= 1, D1, 0.0)))
            numerator = (ngram_counts - discount).clamp(min=0)
            denominator = ctx_totals.clamp(min=1)
            # Back-off weight. NOTE(review): uses D3 for every count bucket
            # (an approximation of the exact per-bucket gamma); clamped so
            # the interpolation stays a valid mixture.
            gamma = (D3 * ctx_uniques) / denominator
            gamma = gamma.clamp(0, 1)
            p_lower = log_probs.exp()
            p_combined = numerator / denominator + gamma * p_lower
            # Only overwrite rows whose context was actually observed.
            valid = has_ctx & (ctx_totals > 0)
            log_probs = torch.where(valid, torch.log(p_combined.clamp(min=1e-30)), log_probs)
        return log_probs
    @torch.no_grad()
    def generate(self, prompt: str, max_new: int = 200, temperature: float = 0.8,
                 top_k: int = 50, top_p: float = 0.9):
        """Sample up to max_new tokens continuing `prompt`; return decoded text."""
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1
        ids = tok.encode(prompt)
        # Score only the most frequent tokens (by unigram prob) for speed.
        n_cands = min(top_k * 10, VOCAB)
        if self.gpu_unigram_probs is not None:
            _, candidates = self.gpu_unigram_probs.topk(n_cands)
        else:
            candidates = torch.arange(n_cands, device=DEV)
        for _ in range(max_new):
            if len(ids) >= ctx_len:
                ctx = ids[-ctx_len:]
            else:
                # Left-pad short history with -1 ("no context" marker).
                ctx = [-1] * (ctx_len - len(ids)) + ids
            ctx_t = torch.tensor([ctx], dtype=torch.int64, device=DEV).expand(n_cands, ctx_len)
            log_probs = self.batch_log_probs(ctx_t, candidates)
            # Temperature-scaled softmax renormalizes over the candidate set.
            probs = (log_probs / max(temperature, 1e-8)).softmax(0)
            if top_k > 0 and top_k < n_cands:
                # Keep only the top_k candidate probabilities.
                vals, idx = probs.topk(top_k)
                mask = torch.zeros_like(probs)
                mask.scatter_(0, idx, vals)
                probs = mask
            if top_p < 1.0:
                # Nucleus filtering: zero the tail beyond cumulative top_p.
                sp, si = probs.sort(descending=True)
                cum = sp.cumsum(0)
                cutoff = cum > top_p
                cutoff[0] = False
                sp[cutoff] = 0
                probs = torch.zeros_like(probs).scatter_(0, si, sp)
            if probs.sum() == 0:
                # Degenerate distribution: fall back to the top candidate.
                next_tok = candidates[0].item()
            else:
                probs = probs / probs.sum()
                next_tok = candidates[probs.multinomial(1).item()].item()
            ids.append(next_tok)
            if next_tok == EOS:
                break
        return tok.decode(ids, skip_special_tokens=True)
    @classmethod
    def load(cls, path: str, mem_budget_gb: float = TOTAL_MEM_BUDGET_GB) -> 'MarkovLM':
        """Load with data-driven adaptive memory management.

        Strategy:
        1. Load pickle into Python dicts
        2. Estimate total Python dict memory for ALL orders
        3. For each order (low->high), compute how much budget remains
           after accounting for still-loaded raw dicts of future orders
        4. Pick the lowest prune threshold whose freeze peak fits
        5. Freeze, free raw dict, malloc_trim to reclaim OS memory
        6. On OOM during freeze: skip order (prevents corruption)
        Lower orders get priority β€” they're smaller AND more impactful
        for quality (every token hits unigram/bigram; few hit 5-gram).

        NOTE(review): mem_budget_gb is currently unused in this body; the
        order cap comes from the MAX_ORDERS env var instead.
        """
        p = Path(path)
        # Prefer a CPU-specific pickle if one sits alongside the given path.
        for suffix in ['.cpu.pkl', '.pkl']:
            candidate = p.with_suffix(suffix)
            if candidate.exists():
                p = candidate; break
        print(f"[load] {p} ({p.stat().st_size / 1e6:.1f} MB)...")
        with open(p, 'rb') as f:
            data = pickle.load(f)
        max_order = data['max_order']
        model = cls(max_order=max_order)
        model.total_tokens = data['total_tokens']
        model.tokens_trained = data['tokens_trained']
        model.discounts = data.get('discounts', model.discounts)
        raw_counts = data['cpu_counts']
        # Drop the outer pickle container before any heavy work.
        del data
        _force_free()
        # On HF Spaces (16GB), the pickle's Python dicts can use 6-8GB.
        # We must free each order's raw dict BEFORE freezing the next.
        # Higher orders (4-gram, 5-gram) have massive dicts that may not
        # fit in memory alongside lower orders' tensors. Strategy:
        # - Always load orders 0-2 (unigram, bigram, trigram) β€” essential
        # - Try order 3 (4-gram) with pruning
        # - Skip order 4 (5-gram) if memory is tight
        # This is conservative but prevents OOM kills.
        # First: drop higher orders' raw dicts immediately to free memory
        # before we even start freezing. This is the key insight β€” the
        # pickle loads ALL orders into memory, but we can delete the ones
        # we'll process later (or skip entirely) to free GB of RAM.
        MAX_ORDERS_TO_LOAD = int(os.environ.get("MAX_ORDERS", "4")) # default: up to 4-gram
        for order in range(MAX_ORDERS_TO_LOAD, max_order):
            print(f" [mem] Dropping {order+1}-gram raw dict to save memory")
            raw_counts[order] = None
        _force_free()
        total_entries = 0
        print(f"[load] {max_order}-gram | {model.tokens_trained:,} tokens")
        print(f"[load] Loading up to {MAX_ORDERS_TO_LOAD}-gram (of {max_order})")
        print("[load] Freezing (one order at a time)...")
        t0 = time.time()
        tensor_bytes_used = 0
        for order in range(min(max_order, MAX_ORDERS_TO_LOAD)):
            raw = raw_counts[order]
            raw_counts[order] = None # unhook from list
            _force_free()
            if not raw:
                continue
            # Prune higher orders more aggressively
            if order <= 1:
                prune_t = 1 # unigrams/bigrams: keep all
            elif order == 2:
                prune_t = 2 # trigrams: keep count >= 2
            else:
                prune_t = 3 # 4-gram+: keep count >= 3
            n_entries = sum(len(nexts) for nexts in raw.values())
            n_ctx = len(raw)
            print(f" [scan] {order+1}-gram: {n_ctx:,} contexts, {n_entries:,} raw entries")
            gt, ct = _freeze_order(order, raw, device=DEV, prune_threshold=prune_t)
            model.gpu_ngram_tables[order] = gt
            model.gpu_context_tables[order] = ct
            this_tensor_bytes = gt.memory_bytes() + ct.memory_bytes()
            tensor_bytes_used += this_tensor_bytes
            # Raw dict for this order is no longer needed; reclaim it now.
            del raw
            _force_free()
            n_ent = gt.size
            print(f" {order+1}-gram: {n_ent:,} entries | {this_tensor_bytes/1e6:.1f} MB (prune>={prune_t})")
            total_entries += n_ent
        del raw_counts
        _force_free()
        print(f"[load] Total: {total_entries:,} entries | {tensor_bytes_used/1e6:.1f} MB tensors")
        # Estimate KN discounts
        model._estimate_discounts()
        # Unigram probs
        gt0 = model.gpu_ngram_tables[0]
        if gt0 and gt0.size > 0:
            probs = torch.zeros(VOCAB, dtype=torch.float32, device=DEV)
            total = gt0.total
            if total > 0 and gt0.counts is not None:
                # Look up every vocab id's unigram count through the hash table
                # and normalize into a dense probability vector.
                all_toks = torch.arange(VOCAB, dtype=torch.int64, device=DEV).unsqueeze(1)
                all_hashes = _hash_ngram_batch_gpu(all_toks)
                counts = gt0.batch_lookup(all_hashes).float()
                probs = counts / counts.sum().clamp(min=1)
            model.gpu_unigram_probs = probs
        model.frozen = True
        elapsed = time.time() - t0
        print(f"[load] Done in {elapsed:.1f}s")
        # Sanity check: generate a few tokens to verify output quality
        try:
            test = model.generate("The meaning of life is", max_new=20, temperature=0.7, top_k=50, top_p=0.9)
            # NOTE(review): the prompt is 22 chars, so test[20:] still includes
            # part of the prompt β€” this check can pass on prompt text alone.
            has_spaces = " " in test[20:] # check generated part has spaces
            print(f"[sanity] Test generation: {repr(test[:80])}...")
            if has_spaces:
                print("[sanity] PASSED β€” output contains spaces and looks reasonable")
            else:
                print("[sanity] WARNING β€” output may be degraded (no spaces in generated text)")
        except Exception as e:
            print(f"[sanity] Could not run test generation: {e}")
        return model
# ─── Load model at startup ───
print("Downloading model from HuggingFace...")
# Fetch (or reuse a cached copy of) the pickled count tables.
model_path = hf_hub_download(
    repo_id="OpenTransformer/markov-5gram-500m",
    filename="markov_5gram.cpu.pkl",
    cache_dir="/tmp/markov_cache"
)
print(f"Loading model from {model_path}...")
# Single module-level model instance shared by all Gradio requests.
MODEL = MarkovLM.load(model_path)
print("Model ready!")
# ─── Gradio Interface ───
def generate_text(prompt, max_tokens, temperature, top_k, top_p):
    """Gradio callback: run MODEL.generate and append a throughput footer."""
    if not prompt.strip():
        return "Please enter a prompt."
    start = time.time()
    result = MODEL.generate(
        prompt=prompt,
        max_new=int(max_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
    )
    elapsed = time.time() - start
    # Token throughput, measured by re-encoding output minus the prompt.
    gen_tokens = len(tok.encode(result)) - len(tok.encode(prompt))
    rate = gen_tokens / max(elapsed, 0.01)
    stats = f"\n\n---\n*Generated {gen_tokens} tokens in {elapsed:.2f}s ({rate:.0f} tok/s)*"
    return result + stats
def get_model_info():
    """Build the markdown model-statistics panel shown in the UI accordion."""
    # Tally entry count and tensor memory across all frozen orders.
    total_entries = 0
    mem_bytes = 0
    for gt, ct in zip(MODEL.gpu_ngram_tables, MODEL.gpu_context_tables):
        if gt:
            total_entries += gt.size
            mem_bytes += gt.memory_bytes()
        if ct:
            mem_bytes += ct.memory_bytes()
    mem = mem_bytes / 1e6
    info = f"""## Model Statistics
- **Architecture**: {MODEL.max_order}-gram with Modified Kneser-Ney smoothing
- **Tokens trained**: {MODEL.tokens_trained:,}
- **Total n-gram entries**: {total_entries:,}
- **Memory usage**: {mem:.1f} MB
- **Tokenizer**: GPT-2 ({VOCAB:,} vocab)
- **Inference**: CPU (searchsorted batch lookup)
### How it works
This is a classical n-gram language model β€” no neural network, no parameters to learn via gradient descent.
Instead, it counts how often sequences of tokens appear in the training data and uses those counts
to predict the next token. Kneser-Ney smoothing interpolates between different context lengths
(unigram through {MODEL.max_order}-gram) so that even unseen contexts get reasonable predictions.
The n-gram counts are stored in sorted hash tables and looked up via binary search (`torch.searchsorted`),
making inference parallel and efficient even on CPU.
### Per-order breakdown"""
    # One bullet per populated order, with its estimated KN discounts.
    for order, gt in enumerate(MODEL.gpu_ngram_tables):
        if gt and gt.size > 0:
            D1, D2, D3 = MODEL.discounts[order]
            info += f"\n- **{order+1}-gram**: {gt.size:,} entries (D1={D1:.3f}, D2={D2:.3f}, D3+={D3:.3f})"
    info += f"\n\n*Trained by [OpenTransformers Ltd](https://huggingface.co/OpenTransformer). Part of AGILLM research.*"
    return info
# Assemble the Gradio UI; `demo` is launched under the main guard below.
with gr.Blocks(
    title="Markov Chain LM β€” OpenTransformers",
    theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
) as demo:
    gr.Markdown("""# Markov Chain Language Model
### Classical N-gram LM with Modified Kneser-Ney Smoothing
*No neural network β€” pure statistical language modelling. [OpenTransformers Ltd](https://huggingface.co/OpenTransformer)*
""")
    with gr.Row():
        with gr.Column(scale=3):
            # Main prompt/output column.
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter text to continue...",
                lines=3,
                value="The meaning of life is"
            )
            output = gr.Markdown(label="Generated Text")
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Sampling controls; defaults match generate()'s signature.
            max_tokens = gr.Slider(10, 500, value=200, step=10, label="Max tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
            top_k = gr.Slider(1, 200, value=50, step=1, label="Top-K")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, max_tokens, temperature, top_k, top_p],
        outputs=output,
    )
    with gr.Accordion("Model Information", open=False):
        # Rendered once at startup; stats are static for the process lifetime.
        gr.Markdown(get_model_info())
    gr.Examples(
        examples=[
            ["The meaning of life is"],
            ["In the beginning, there was"],
            ["The president of the United States"],
            ["Machine learning is a field of"],
            ["Once upon a time in a land far away"],
            ["The quick brown fox"],
        ],
        inputs=prompt,
    )
if __name__ == "__main__":
    demo.launch()