Spaces:
Runtime error
Runtime error
Fix OOM: drop 5-gram dict immediately after pickle load, prune 3/4-gram, conservative memory
8e22592 verified | #!/usr/bin/env python3 | |
| """ | |
| Markov Chain Language Model β Interactive Demo | |
| OpenTransformers Ltd | Part of AGILLM Research | |
| Classical n-gram LM with Modified Kneser-Ney smoothing. | |
| GPU hash tables (sorted int64 + searchsorted) for parallel inference. | |
| Runs on CPU for HF Spaces compatibility. | |
| """ | |
| import os, sys, math, time, pickle, gc, ctypes | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
# ─── Force CPU for HF Spaces ───
# The free tier exposes no GPU, so every tensor lives on the CPU device.
DEV = torch.device("cpu")
# ─── Memory management ───
# HF Spaces free tier has 16GB. After Python + torch + gradio + pickle load,
# roughly 10-11GB is available. We budget conservatively; overridable via env.
TOTAL_MEM_BUDGET_GB = float(os.environ.get("MEM_BUDGET_GB", "10.0"))
| def _malloc_trim(): | |
| """Force glibc to return freed memory to OS.""" | |
| try: | |
| ctypes.CDLL("libc.so.6").malloc_trim(0) | |
| except Exception: | |
| pass | |
def _force_free():
    """Aggressive memory reclaim: run the GC twice, then trim the heap."""
    for _ in range(2):
        gc.collect()
    _malloc_trim()
# ─── Tokenizer ───
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "gpt2")
def _load_tokenizer():
    """Fetch the HF tokenizer, silencing transformers logging and
    guaranteeing that a pad token exists."""
    from transformers import AutoTokenizer, logging as hf_log
    hf_log.set_verbosity_error()
    tk = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
    if tk.pad_token is None:
        tk.add_special_tokens({"pad_token": "<|pad|>"})
    return tk
tok = _load_tokenizer()
# Vocab size = highest token id + 1 (ids may be non-contiguous).
VOCAB = max(tok.get_vocab().values()) + 1
# GPT-2 has an EOS token; fall back to SEP for BERT-style tokenizers.
EOS = tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
| # βββ FNV-1a Hashing βββ | |
| FNV_OFFSET = 14695981039346656037 | |
| FNV_PRIME = 1099511628211 | |
| MASK64 = (1 << 64) - 1 | |
| INT64_MAX = (1 << 63) - 1 | |
| INT64_WRAP = 1 << 64 | |
| FNV_OFFSET_S = FNV_OFFSET - INT64_WRAP | |
| def _hash_ngram_batch_gpu(contexts: torch.Tensor) -> torch.Tensor: | |
| B, N = contexts.shape | |
| h = torch.full((B,), FNV_OFFSET_S, dtype=torch.int64, device=contexts.device) | |
| for i in range(N): | |
| h = h ^ contexts[:, i] | |
| h = h * FNV_PRIME | |
| return h | |
class GPUHashTable:
    """Read-only map from int64 hash -> count, backed by sorted tensors.

    Lookup is a binary search (`torch.searchsorted`) followed by an
    equality check, resolving a whole batch of keys in one call.
    Misses return 0.
    """
    def __init__(self):
        self.hashes: Optional[torch.Tensor] = None              # sorted keys
        self.counts: Optional[torch.Tensor] = None              # per-key counts
        self.continuation_counts: Optional[torch.Tensor] = None # unique continuations
        self.total: int = 0  # sum of all counts
        self.size: int = 0   # number of stored entries
    def _probe(self, keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Return values at matching keys, zero where a key is absent."""
        pos = torch.searchsorted(self.hashes, keys).clamp(0, self.size - 1)
        hit = self.hashes[pos] == keys
        return torch.where(hit, values[pos], torch.zeros_like(keys))
    def batch_lookup(self, hashes: torch.Tensor) -> torch.Tensor:
        if self.size == 0:
            return torch.zeros_like(hashes)
        return self._probe(hashes, self.counts)
    def batch_lookup_continuations(self, hashes: torch.Tensor) -> torch.Tensor:
        if self.size == 0 or self.continuation_counts is None:
            return torch.zeros_like(hashes)
        return self._probe(hashes, self.continuation_counts)
    def memory_bytes(self) -> int:
        tensors = (self.hashes, self.counts, self.continuation_counts)
        return sum(t.nelement() * t.element_size() for t in tensors if t is not None)
def _hash_ngram_py(ngram: tuple) -> int:
    """64-bit FNV-1a of a token tuple, folded into signed int64 range.

    Must stay bit-compatible with `_hash_ngram_batch_gpu` so that hashes
    built here at freeze time match the batched lookups at inference.
    """
    h = FNV_OFFSET
    for token_id in ngram:
        h = ((h ^ (token_id & MASK64)) * FNV_PRIME) & MASK64
    return h - INT64_WRAP if h > INT64_MAX else h
# ─── Memory-aware freezing ───
# Per-entry tensor footprints used to predict freeze-time memory needs.
BYTES_PER_NGRAM = 16 # int64 hash + int64 count
BYTES_PER_CTX = 24 # int64 hash + int64 count + int64 continuation
SORT_OVERHEAD = 3 # argsort needs ~3x the tensor size transiently
| def _count_entries_at_threshold(d, threshold): | |
| """Count how many n-gram entries have count >= threshold.""" | |
| n = 0 | |
| for nexts in d.values(): | |
| for cnt in nexts.values(): | |
| if cnt >= threshold: | |
| n += 1 | |
| return n | |
def _estimate_freeze_peak(n_entries, n_contexts):
    """Estimated peak bytes while freezing one n-gram order.

    The peak is the raw tensors under construction plus argsort's
    transient workspace, for both the n-gram and the context tables.
    """
    return SORT_OVERHEAD * (
        n_entries * BYTES_PER_NGRAM + n_contexts * BYTES_PER_CTX
    )
| def _estimate_python_dict_bytes(d): | |
| """Rough estimate of Python dict memory for one order's raw data. | |
| Each outer dict entry: ~232 bytes (key tuple + value dict overhead). | |
| Each inner entry: ~128 bytes (int key + int value + dict entry). | |
| """ | |
| n_contexts = len(d) | |
| n_entries = sum(len(nexts) for nexts in d.values()) | |
| return n_contexts * 232 + n_entries * 128 | |
def _pick_threshold_for_budget(d, budget_bytes):
    """Smallest prune threshold whose estimated freeze peak fits the budget.

    Returns (threshold, surviving_entry_count). When even the most
    aggressive threshold does not fit, returns it anyway (50).
    """
    n_contexts = len(d)
    candidates = (1, 2, 3, 5, 8, 10, 15, 25, 50)
    for thresh in candidates:
        surviving = _count_entries_at_threshold(d, thresh)
        if _estimate_freeze_peak(surviving, n_contexts) <= budget_bytes:
            return thresh, surviving
    # Nothing fits — report the harshest threshold with its entry count.
    last = candidates[-1]
    return last, _count_entries_at_threshold(d, last)
def _freeze_order(order, d, device, prune_threshold=1):
    """Freeze a single n-gram order into GPU hash tables.

    Writes directly into pre-allocated tensors. Catches OOM to prevent
    corrupted tables from poisoning higher-order lookups.

    Args:
        order: zero-based order index (0 = unigram); used only for logging.
        d: dict mapping context tuple -> {next_token_id: count}.
        device: torch device the finished tables are moved to.
        prune_threshold: drop n-gram entries with count below this (min 1).

    Returns:
        (ngram_table, context_table) GPUHashTable pair; both empty when
        `d` is empty or when freezing ran out of memory.
    """
    if not d:
        return GPUHashTable(), GPUHashTable()
    threshold = max(prune_threshold, 1)
    # First pass: count surviving entries so tensors can be pre-sized.
    n_ngrams = 0
    for nexts in d.values():
        for cnt in nexts.values():
            if cnt >= threshold:
                n_ngrams += 1
    n_contexts = len(d)
    gt = GPUHashTable()
    ct = GPUHashTable()
    try:
        # Build ngram table: hash of (context + next_token) -> count.
        if n_ngrams > 0:
            hashes = torch.empty(n_ngrams, dtype=torch.int64)
            counts = torch.empty(n_ngrams, dtype=torch.int64)
            idx = 0
            for ctx, nexts in d.items():
                for next_tok, cnt in nexts.items():
                    if cnt >= threshold:
                        hashes[idx] = _hash_ngram_py(ctx + (next_tok,))
                        counts[idx] = cnt
                        idx += 1
            # Sort once so lookups can binary-search (searchsorted).
            sort_idx = hashes.argsort()
            gt.hashes = hashes[sort_idx].to(device)
            gt.counts = counts[sort_idx].to(device)
            gt.total = counts.sum().item()
            gt.size = n_ngrams
            # Drop the unsorted temporaries before the next allocation
            # to keep peak RSS down (argsort workspace is the peak).
            del hashes, counts, sort_idx
            _force_free()
        # Build context table: hash of context -> (total count, unique
        # continuation count). Contexts are NOT pruned — they back the
        # Kneser-Ney denominator and gamma terms.
        if n_contexts > 0:
            c_hashes = torch.empty(n_contexts, dtype=torch.int64)
            c_counts = torch.empty(n_contexts, dtype=torch.int64)
            c_uniques = torch.empty(n_contexts, dtype=torch.int64)
            idx = 0
            for ctx, nexts in d.items():
                c_hashes[idx] = _hash_ngram_py(ctx)
                c_counts[idx] = sum(nexts.values())
                c_uniques[idx] = len(nexts)
                idx += 1
            sort_idx = c_hashes.argsort()
            ct.hashes = c_hashes[sort_idx].to(device)
            ct.counts = c_counts[sort_idx].to(device)
            ct.continuation_counts = c_uniques[sort_idx].to(device)
            ct.total = c_counts.sum().item()
            ct.size = n_contexts
            del c_hashes, c_counts, c_uniques, sort_idx
            _force_free()
    except (MemoryError, RuntimeError) as e:
        # torch raises RuntimeError (not MemoryError) on allocation failure.
        # Reset both tables: a half-built table would silently return wrong
        # counts for this order and poison the interpolation.
        print(f" [OOM] Failed to freeze order {order+1}: {e}")
        print(f" [OOM] Skipping this order to prevent corruption")
        gt = GPUHashTable()
        ct = GPUHashTable()
        _force_free()
    return gt, ct
class MarkovLM:
    """N-gram LM with Modified Kneser-Ney smoothing.

    Frozen counts live in per-order GPUHashTable pairs; probabilities are
    computed by interpolating orders from unigram up to `max_order`-gram.
    Construct via `MarkovLM.load(path)` for inference.
    """
    def __init__(self, max_order: int = 5):
        self.max_order = max_order
        self.cpu_counts: List[Dict] = []  # raw count dicts (training side only)
        self.total_tokens = 0
        self.tokens_trained = 0
        # Index k holds the frozen (k+1)-gram tables, or None if not loaded.
        self.gpu_ngram_tables: List[Optional["GPUHashTable"]] = [None] * max_order
        self.gpu_context_tables: List[Optional["GPUHashTable"]] = [None] * max_order
        self.frozen = False
        # (D1, D2, D3+) Kneser-Ney discounts per order; refined after freezing.
        self.discounts: List[Tuple[float, float, float]] = [(0.5, 1.0, 1.5)] * max_order
        self.gpu_unigram_probs: Optional[torch.Tensor] = None
    def _estimate_discounts(self):
        """Estimate modified KN discounts (D1, D2, D3+) per order from the
        count-of-counts n1..n4 (Chen & Goodman closed-form estimates)."""
        for order in range(self.max_order):
            gt = self.gpu_ngram_tables[order]
            if gt is None or gt.size == 0:
                continue
            counts = gt.counts
            n1 = (counts == 1).sum().item()
            n2 = (counts == 2).sum().item()
            n3 = (counts == 3).sum().item()
            n4 = (counts == 4).sum().item()
            if n1 == 0 or n2 == 0:
                continue  # too sparse (or heavily pruned) — keep defaults
            Y = n1 / (n1 + 2 * n2)
            # Clamp into valid ranges and keep D1 <= D2 <= D3.
            D1 = max(0.01, min(1 - 2 * Y * (n2 / max(n1, 1)), 0.99))
            D2 = max(D1, min(2 - 3 * Y * (n3 / max(n2, 1)), 1.99))
            D3 = max(D2, min(3 - 4 * Y * (n4 / max(n3, 1)), 2.99))
            self.discounts[order] = (D1, D2, D3)
    def batch_log_probs(self, contexts: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Log P(target | context) for a batch, via KN interpolation.

        Args:
            contexts: (B, C) int64; -1 marks padding positions.
            targets: (B,) int64 candidate next tokens.
        Returns:
            (B,) float32 log-probabilities.
        """
        B = targets.shape[0]
        device = targets.device
        # Start from the uniform prior; each order refines it in turn.
        log_probs = torch.full((B,), math.log(1.0 / VOCAB), dtype=torch.float32, device=device)
        for order in range(self.max_order):
            gt = self.gpu_ngram_tables[order]
            ct = self.gpu_context_tables[order]
            if gt is None or ct is None or gt.size == 0:
                continue
            ctx_len = order
            if ctx_len == 0:
                # Unigram level: read the precomputed distribution directly.
                if self.gpu_unigram_probs is not None:
                    safe_t = targets.clamp(0, VOCAB - 1)
                    uni_p = self.gpu_unigram_probs[safe_t]
                    valid = uni_p > 0
                    log_probs = torch.where(valid, torch.log(uni_p + 1e-30), log_probs)
                continue
            if ctx_len > contexts.shape[1]:
                continue
            ctx = contexts[:, -ctx_len:]
            has_ctx = (ctx >= 0).all(dim=1)  # rows without padding
            if not has_ctx.any():
                continue
            full_ngram = torch.cat([ctx, targets.unsqueeze(1)], dim=1)
            ngram_counts = gt.batch_lookup(_hash_ngram_batch_gpu(full_ngram)).float()
            ctx_hashes = _hash_ngram_batch_gpu(ctx)
            ctx_totals = ct.batch_lookup(ctx_hashes).float()
            ctx_uniques = ct.batch_lookup_continuations(ctx_hashes).float()
            D1, D2, D3 = self.discounts[order]
            # Count-dependent discount: D1 for singletons, D2 for pairs,
            # D3 for counts >= 3, zero for unseen n-grams.
            discount = torch.where(ngram_counts >= 3, D3,
                                   torch.where(ngram_counts >= 2, D2,
                                               torch.where(ngram_counts >= 1, D1, 0.0)))
            numerator = (ngram_counts - discount).clamp(min=0)
            denominator = ctx_totals.clamp(min=1)
            # Back-off mass redistributed to the lower order. NOTE(review):
            # full modified KN weights gamma by D1*N1 + D2*N2 + D3*N3+;
            # using D3 * unique-continuations is a deliberate approximation
            # given only aggregate continuation counts are stored.
            gamma = (D3 * ctx_uniques) / denominator
            gamma = gamma.clamp(0, 1)
            p_lower = log_probs.exp()
            p_combined = numerator / denominator + gamma * p_lower
            valid = has_ctx & (ctx_totals > 0)
            log_probs = torch.where(valid, torch.log(p_combined.clamp(min=1e-30)), log_probs)
        return log_probs
    def generate(self, prompt: str, max_new: int = 200, temperature: float = 0.8,
                 top_k: int = 50, top_p: float = 0.9):
        """Sample up to `max_new` tokens continuing `prompt`.

        Scores only the `top_k * 10` most frequent unigrams as candidates
        each step (full-vocab scoring is too slow on CPU), then applies
        temperature, top-k and nucleus (top-p) filtering.
        """
        assert self.frozen, "Call freeze() first"
        ctx_len = self.max_order - 1
        ids = tok.encode(prompt)
        n_cands = min(top_k * 10, VOCAB)
        if self.gpu_unigram_probs is not None:
            _, candidates = self.gpu_unigram_probs.topk(n_cands)
        else:
            candidates = torch.arange(n_cands, device=DEV)
        for _ in range(max_new):
            # Left-pad short histories with -1 (treated as "no context").
            if len(ids) >= ctx_len:
                ctx = ids[-ctx_len:]
            else:
                ctx = [-1] * (ctx_len - len(ids)) + ids
            ctx_t = torch.tensor([ctx], dtype=torch.int64, device=DEV).expand(n_cands, ctx_len)
            log_probs = self.batch_log_probs(ctx_t, candidates)
            probs = (log_probs / max(temperature, 1e-8)).softmax(0)
            if top_k > 0 and top_k < n_cands:
                vals, idx = probs.topk(top_k)
                mask = torch.zeros_like(probs)
                mask.scatter_(0, idx, vals)
                probs = mask
            if top_p < 1.0:
                sp, si = probs.sort(descending=True)
                cum = sp.cumsum(0)
                cutoff = cum > top_p
                cutoff[0] = False  # always keep the most likely token
                sp[cutoff] = 0
                probs = torch.zeros_like(probs).scatter_(0, si, sp)
            if probs.sum() == 0:
                # All mass filtered away — fall back to the top candidate.
                next_tok = candidates[0].item()
            else:
                probs = probs / probs.sum()
                next_tok = candidates[probs.multinomial(1).item()].item()
            ids.append(next_tok)
            if next_tok == EOS:
                break
        return tok.decode(ids, skip_special_tokens=True)
    @classmethod
    def load(cls, path: str, mem_budget_gb: float = TOTAL_MEM_BUDGET_GB) -> 'MarkovLM':
        """Load a pickled count model and freeze it into lookup tables.

        FIX: this must be a @classmethod — it was previously a plain
        method, so the startup call `MarkovLM.load(model_path)` bound the
        path string to `cls` and crashed with a TypeError (the Space's
        "Runtime error").

        Strategy (memory-conservative for the 16GB free tier):
        1. Load the pickle, then immediately drop raw dicts for orders
           above MAX_ORDERS (default 4, i.e. the 5-gram dict) — this
           frees GB of RAM before any freezing starts.
        2. Freeze remaining orders low->high with fixed prune thresholds
           (keep all uni/bigrams, trigram count >= 2, 4-gram+ count >= 3),
           freeing each raw dict and malloc_trim-ing before the next.
        3. On OOM during a freeze, that order is skipped entirely
           (see `_freeze_order`) rather than left half-built.

        Args:
            path: model file path; `.cpu.pkl` / `.pkl` suffixes are probed.
            mem_budget_gb: reserved for adaptive budgeting (currently unused).
        """
        p = Path(path)
        for suffix in ['.cpu.pkl', '.pkl']:
            candidate = p.with_suffix(suffix)
            if candidate.exists():
                p = candidate
                break
        print(f"[load] {p} ({p.stat().st_size / 1e6:.1f} MB)...")
        # NOTE: pickle.load on a downloaded artifact — acceptable only
        # because the repo is our own; never point this at untrusted files.
        with open(p, 'rb') as f:
            data = pickle.load(f)
        max_order = data['max_order']
        model = cls(max_order=max_order)
        model.total_tokens = data['total_tokens']
        model.tokens_trained = data['tokens_trained']
        model.discounts = data.get('discounts', model.discounts)
        raw_counts = data['cpu_counts']
        del data
        _force_free()
        # Drop higher orders' raw dicts immediately — the pickle loads ALL
        # orders into memory, but the ones we will not freeze can be freed
        # right now, before any tensor allocation.
        MAX_ORDERS_TO_LOAD = int(os.environ.get("MAX_ORDERS", "4"))  # default: up to 4-gram
        for order in range(MAX_ORDERS_TO_LOAD, max_order):
            print(f" [mem] Dropping {order+1}-gram raw dict to save memory")
            raw_counts[order] = None
        _force_free()
        total_entries = 0
        print(f"[load] {max_order}-gram | {model.tokens_trained:,} tokens")
        print(f"[load] Loading up to {MAX_ORDERS_TO_LOAD}-gram (of {max_order})")
        print("[load] Freezing (one order at a time)...")
        t0 = time.time()
        tensor_bytes_used = 0
        for order in range(min(max_order, MAX_ORDERS_TO_LOAD)):
            raw = raw_counts[order]
            raw_counts[order] = None  # unhook from list so `del raw` frees it
            _force_free()
            if not raw:
                continue
            # Prune higher orders more aggressively: lower orders are
            # smaller AND matter more (every token hits uni/bigram).
            if order <= 1:
                prune_t = 1   # unigrams/bigrams: keep all
            elif order == 2:
                prune_t = 2   # trigrams: keep count >= 2
            else:
                prune_t = 3   # 4-gram+: keep count >= 3
            n_entries = sum(len(nexts) for nexts in raw.values())
            n_ctx = len(raw)
            print(f" [scan] {order+1}-gram: {n_ctx:,} contexts, {n_entries:,} raw entries")
            gt, ct = _freeze_order(order, raw, device=DEV, prune_threshold=prune_t)
            model.gpu_ngram_tables[order] = gt
            model.gpu_context_tables[order] = ct
            this_tensor_bytes = gt.memory_bytes() + ct.memory_bytes()
            tensor_bytes_used += this_tensor_bytes
            del raw
            _force_free()
            n_ent = gt.size
            print(f" {order+1}-gram: {n_ent:,} entries | {this_tensor_bytes/1e6:.1f} MB (prune>={prune_t})")
            total_entries += n_ent
        del raw_counts
        _force_free()
        print(f"[load] Total: {total_entries:,} entries | {tensor_bytes_used/1e6:.1f} MB tensors")
        # Refine KN discounts now that frozen counts are available.
        model._estimate_discounts()
        # Precompute the unigram distribution for fast order-0 lookups.
        gt0 = model.gpu_ngram_tables[0]
        if gt0 and gt0.size > 0:
            probs = torch.zeros(VOCAB, dtype=torch.float32, device=DEV)
            total = gt0.total
            if total > 0 and gt0.counts is not None:
                all_toks = torch.arange(VOCAB, dtype=torch.int64, device=DEV).unsqueeze(1)
                all_hashes = _hash_ngram_batch_gpu(all_toks)
                counts = gt0.batch_lookup(all_hashes).float()
                probs = counts / counts.sum().clamp(min=1)
            model.gpu_unigram_probs = probs
        model.frozen = True
        elapsed = time.time() - t0
        print(f"[load] Done in {elapsed:.1f}s")
        # Sanity check: generate a few tokens to verify output quality.
        try:
            test = model.generate("The meaning of life is", max_new=20, temperature=0.7, top_k=50, top_p=0.9)
            # Crude degradation check: the generated tail should contain
            # spaces (slice offset ~= the prompt length).
            has_spaces = " " in test[20:]
            print(f"[sanity] Test generation: {repr(test[:80])}...")
            if has_spaces:
                print("[sanity] PASSED — output contains spaces and looks reasonable")
            else:
                print("[sanity] WARNING — output may be degraded (no spaces in generated text)")
        except Exception as e:
            print(f"[sanity] Could not run test generation: {e}")
        return model
# ─── Load model at startup ───
# Runs once at import time; hf_hub_download caches under /tmp so Space
# restarts reuse the already-downloaded pickle.
print("Downloading model from HuggingFace...")
model_path = hf_hub_download(
    repo_id="OpenTransformer/markov-5gram-500m",
    filename="markov_5gram.cpu.pkl",
    cache_dir="/tmp/markov_cache"
)
print(f"Loading model from {model_path}...")
MODEL = MarkovLM.load(model_path)
print("Model ready!")
# ─── Gradio Interface ───
def generate_text(prompt, max_tokens, temperature, top_k, top_p):
    """Gradio callback: run generation and append a tokens/sec footer."""
    if not prompt.strip():
        return "Please enter a prompt."
    started = time.time()
    result = MODEL.generate(
        prompt=prompt,
        max_new=int(max_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
    )
    took = time.time() - started
    # Token throughput measured on the newly generated portion only.
    gen_tokens = len(tok.encode(result)) - len(tok.encode(prompt))
    rate = gen_tokens / max(took, 0.01)
    stats = f"\n\n---\n*Generated {gen_tokens} tokens in {took:.2f}s ({rate:.0f} tok/s)*"
    return result + stats
def get_model_info():
    """Build the markdown shown in the 'Model Information' accordion.

    Fix: repaired a mis-encoded em dash ('β') in the user-facing text.
    """
    total_entries = sum(
        gt.size for gt in MODEL.gpu_ngram_tables if gt
    )
    mem = sum(
        (gt.memory_bytes() if gt else 0) + (ct.memory_bytes() if ct else 0)
        for gt, ct in zip(MODEL.gpu_ngram_tables, MODEL.gpu_context_tables)
    ) / 1e6
    info = f"""## Model Statistics
- **Architecture**: {MODEL.max_order}-gram with Modified Kneser-Ney smoothing
- **Tokens trained**: {MODEL.tokens_trained:,}
- **Total n-gram entries**: {total_entries:,}
- **Memory usage**: {mem:.1f} MB
- **Tokenizer**: GPT-2 ({VOCAB:,} vocab)
- **Inference**: CPU (searchsorted batch lookup)
### How it works
This is a classical n-gram language model — no neural network, no parameters to learn via gradient descent.
Instead, it counts how often sequences of tokens appear in the training data and uses those counts
to predict the next token. Kneser-Ney smoothing interpolates between different context lengths
(unigram through {MODEL.max_order}-gram) so that even unseen contexts get reasonable predictions.
The n-gram counts are stored in sorted hash tables and looked up via binary search (`torch.searchsorted`),
making inference parallel and efficient even on CPU.
### Per-order breakdown"""
    for order in range(MODEL.max_order):
        gt = MODEL.gpu_ngram_tables[order]
        if gt and gt.size > 0:
            D1, D2, D3 = MODEL.discounts[order]
            info += f"\n- **{order+1}-gram**: {gt.size:,} entries (D1={D1:.3f}, D2={D2:.3f}, D3+={D3:.3f})"
    info += f"\n\n*Trained by [OpenTransformers Ltd](https://huggingface.co/OpenTransformer). Part of AGILLM research.*"
    return info
# Build the UI once at import time; `demo` is launched by the main guard.
# Fix: repaired mis-encoded em dashes ('β') in the title and markdown.
with gr.Blocks(
    title="Markov Chain LM — OpenTransformers",
    theme=gr.themes.Base(primary_hue="blue", neutral_hue="slate"),
) as demo:
    gr.Markdown("""# Markov Chain Language Model
### Classical N-gram LM with Modified Kneser-Ney Smoothing
*No neural network — pure statistical language modelling. [OpenTransformers Ltd](https://huggingface.co/OpenTransformer)*
""")
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter text to continue...",
                lines=3,
                value="The meaning of life is"
            )
            output = gr.Markdown(label="Generated Text")
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
        with gr.Column(scale=1):
            max_tokens = gr.Slider(10, 500, value=200, step=10, label="Max tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature")
            top_k = gr.Slider(1, 200, value=50, step=1, label="Top-K")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt, max_tokens, temperature, top_k, top_p],
        outputs=output,
    )
    with gr.Accordion("Model Information", open=False):
        # Rendered once at build time; stats do not change after load.
        gr.Markdown(get_model_info())
    gr.Examples(
        examples=[
            ["The meaning of life is"],
            ["In the beginning, there was"],
            ["The president of the United States"],
            ["Machine learning is a field of"],
            ["Once upon a time in a land far away"],
            ["The quick brown fox"],
        ],
        inputs=prompt,
    )
if __name__ == "__main__":
    demo.launch()