| """ | |
| Model wrapper for SmolLM2-135M language model. | |
| Primary backend: llama.cpp (via llama-cpp-python) for fast inference | |
| by eliminating Python/PyTorch dispatch overhead (~7x faster than | |
| PyTorch + CUDA Graphs for single-token incremental decode). | |
| Fallback: PyTorch with StaticCache + CUDA Graphs on CUDA, | |
| DynamicCache on CPU. Used automatically when llama-cpp-python | |
| is not installed or no GGUF model file is found. | |
| """ | |
import os
import sys

import numpy as np
import torch
from transformers import AutoTokenizer

try:
    from llama_cpp import Llama
    _HAS_LLAMA_CPP = True
except ImportError:
    _HAS_LLAMA_CPP = False


class ModelWrapper:
    """Wraps SmolLM2-135M for next-token probability prediction.

    Primary backend: llama.cpp (~7x faster than PyTorch for
    single-token incremental decode on CUDA).
    Fallback: PyTorch with StaticCache + CUDA Graphs (CUDA) or
    DynamicCache (CPU).
    """

    MODEL_NAME = "HuggingFaceTB/SmolLM2-135M"
    MAX_CONTEXT = 2048
    # When context exceeds MAX_CONTEXT, drop this many old tokens at once.
    SLIDE_CHUNK = 512
    _GGUF_NAMES = [
        "smollm2-135m-f32.gguf",
        "smollm2-135m-f16.gguf",
        "smollm2-135m.gguf",
    ]
    # PyTorch CUDA graph settings (fallback only)
    _GRAPH_WARMUP = 3

    def __init__(self, model_name: str = None, gguf_path: str = None,
                 verbose: bool = True):
        self.model_name = model_name or self.MODEL_NAME
        self.verbose = verbose
        self._cache_len = 0
        # Try llama.cpp first, fall back to PyTorch.
        gguf = gguf_path or self._find_gguf()
        if _HAS_LLAMA_CPP and gguf:
            self._backend = "llama.cpp"
            self._init_llama_cpp(gguf)
        else:
            self._backend = "pytorch"
            self._init_pytorch()
        if self.verbose:
            print(
                f"Backend: {self._backend}, device: {self.device}, "
                f"vocab_size: {self.vocab_size}",
                file=sys.stderr,
            )

    # ------------------------------------------------------------------
    # GGUF discovery
    # ------------------------------------------------------------------
    def _find_gguf(self) -> str | None:
        """Search for a GGUF model file next to this script."""
        base = os.path.dirname(os.path.abspath(__file__))
        for name in self._GGUF_NAMES:
            path = os.path.join(base, name)
            if os.path.isfile(path):
                return path
        return None

    # ------------------------------------------------------------------
    # llama.cpp backend
    # ------------------------------------------------------------------
    def _init_llama_cpp(self, gguf_path: str):
        if self.verbose:
            print(f"Loading GGUF: {gguf_path}", file=sys.stderr)
        self._llm = Llama(
            model_path=gguf_path,
            n_gpu_layers=-1,
            n_ctx=self.MAX_CONTEXT,
            seed=42,
            logits_all=True,
            verbose=self.verbose,
        )
        self.vocab_size = self._llm.n_vocab()
        # Use the HuggingFace tokenizer for encode/decode: llama.cpp's
        # detokenize drops content for 47 long whitespace/repeat tokens.
        # Token IDs are identical between both tokenizers.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = self._llm  # tests check `model is not None`
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._n_tokens = 0

    # ------------------------------------------------------------------
    # PyTorch backend (fallback)
    # ------------------------------------------------------------------
    def _init_pytorch(self):
        from transformers import AutoModelForCausalLM
        self.device = self._select_device()
        if self.verbose:
            print(f"Using device: {self.device}", file=sys.stderr)
            print(f"Loading model {self.model_name}...", file=sys.stderr)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.float32,
        ).to(self.device)
        self.model.eval()
        self.vocab_size = self.model.config.vocab_size
        self._use_cuda_graph = (self.device == "cuda")
        if self._use_cuda_graph:
            self._setup_cuda_graph()
        else:
            self._past_key_values = None
        if self.verbose:
            print(
                f"Model loaded: vocab_size={self.vocab_size}, "
                f"device={self.device}",
                file=sys.stderr,
            )

    def _select_device(self) -> str:
        if torch.cuda.is_available():
            try:
                # Probe CUDA with a tiny allocation; fall back to CPU on failure.
                t = torch.tensor([1.0], device="cuda")
                del t
                return "cuda"
            except Exception:
                pass
        return "cpu"

    def _setup_cuda_graph(self):
        from transformers import StaticCache
        self._static_cache = StaticCache(
            config=self.model.config,
            max_batch_size=1,
            max_cache_len=self.MAX_CONTEXT,
            device=self.device,
            dtype=torch.float32,
        )
        self._graph = None
        self._graph_input = torch.zeros(
            1, 1, dtype=torch.long, device=self.device,
        )
        self._graph_position = torch.zeros(
            1, dtype=torch.long, device=self.device,
        )
        self._graph_output = None
        if self.verbose:
            print("CUDA graph acceleration enabled", file=sys.stderr)

    def _capture_graph(self):
        self._graph_position.fill_(self._cache_len)
        self._graph_input.fill_(0)
        # Warm-up passes so allocations and kernel selection settle
        # before capture.
        for _ in range(self._GRAPH_WARMUP):
            self.model(
                self._graph_input,
                past_key_values=self._static_cache,
                cache_position=self._graph_position,
                use_cache=True,
            )
        torch.cuda.synchronize()
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph):
            self._graph_output = self.model(
                self._graph_input,
                past_key_values=self._static_cache,
                cache_position=self._graph_position,
                use_cache=True,
            )
        torch.cuda.synchronize()
        if self.verbose:
            print("CUDA graph captured", file=sys.stderr)

    # ------------------------------------------------------------------
    # Cache management
    # ------------------------------------------------------------------
    def reset_cache(self):
        """Clear the KV-cache. Call when starting a new sequence."""
        self._cache_len = 0
        if self._backend == "llama.cpp":
            self._llm.reset()
            self._n_tokens = 0
        elif self._use_cuda_graph:
            self._static_cache.reset()
            self._graph = None
        else:
            self._past_key_values = None
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def _slide_kv_cache(self, keep: int):
        """Shift the llama.cpp KV cache: drop the oldest tokens, shift positions.

        Instead of reset() + eval(all_kept_tokens), this removes the
        oldest SLIDE_CHUNK tokens from the KV cache and shifts the
        remaining positions down. The last position is left for
        _forward_llama_cpp to re-evaluate incrementally (1 token eval
        instead of re-processing all *keep* tokens from scratch).
        """
        drop = self._n_tokens - keep
        if drop <= 0:
            return
        # Remove positions [0, drop) from the KV cache.
        self._llm._ctx.kv_cache_seq_rm(0, 0, drop)
        # Shift the remaining positions [drop, n_tokens) down by -drop.
        self._llm._ctx.kv_cache_seq_shift(0, drop, -1, -drop)
        # Set to keep - 1 so _forward_llama_cpp sees ctx_len == _n_tokens + 1,
        # triggering incremental eval of just the last token. Also sync
        # llama.cpp's internal counter so eval() places it at the right pos.
        self._n_tokens = keep - 1
        self._cache_len = keep - 1
        self._llm.n_tokens = keep - 1

    # ------------------------------------------------------------------
    # Probability prediction
    # ------------------------------------------------------------------
    def get_probs(self, token_ids: list[int]) -> torch.Tensor:
        """Get the next-token probability distribution given the context.

        Uses the KV-cache for incremental inference. On the first call (or
        after reset_cache), processes the full context. On subsequent
        calls where only one token was appended, processes just that
        single new token using the cached states.

        Args:
            token_ids: List of token IDs as context. Can be empty,
                in which case a BOS/default context is used.

        Returns:
            Tensor of shape (vocab_size,) with probabilities summing to 1.
        """
        if len(token_ids) == 0:
            bos = getattr(self.tokenizer, "bos_token_id", None)
            token_ids = [bos if bos is not None else 0]
        if len(token_ids) > self.MAX_CONTEXT:
            keep = self.MAX_CONTEXT - self.SLIDE_CHUNK
            token_ids = token_ids[-keep:]
            if self._backend == "llama.cpp":
                self._slide_kv_cache(keep)
            else:
                self.reset_cache()
        ctx_len = len(token_ids)
        if self._backend == "llama.cpp":
            logits = self._forward_llama_cpp(token_ids, ctx_len)
        elif self._use_cuda_graph:
            logits = self._forward_cuda_graph(token_ids, ctx_len)
        else:
            logits = self._forward_eager_cpu(token_ids, ctx_len)
        self._cache_len = ctx_len
        probs = torch.softmax(logits, dim=0)
        return probs.cpu()
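
    # Sliding-window example (illustrative, using the constants above): once a
    # context grows past MAX_CONTEXT = 2048 tokens it is trimmed to its last
    # MAX_CONTEXT - SLIDE_CHUNK = 1536 tokens; on the llama.cpp backend the
    # oldest KV entries are dropped and the rest shifted, so only the newest
    # token is re-evaluated instead of re-processing all kept tokens.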

    # ------------------------------------------------------------------
    # llama.cpp forward
    # ------------------------------------------------------------------
    def _forward_llama_cpp(
        self, token_ids: list[int], ctx_len: int,
    ) -> torch.Tensor:
        """llama.cpp forward pass with incremental KV-cache.

        Three cases:
        1. ctx_len == _n_tokens + 1 -> append one token (common path)
        2. 0 < ctx_len <= _n_tokens -> context was trimmed from the front
           (sliding window); shift the KV cache and eval just the last token
        3. otherwise -> cold start; reset and eval everything
        """
        if self._n_tokens > 0 and ctx_len == self._n_tokens + 1:
            # Case 1: incremental -> eval just the new token.
            self._llm.eval([token_ids[-1]])
        elif self._n_tokens > 0 and 0 < ctx_len <= self._n_tokens:
            # Case 2: context trimmed from the front (sliding window).
            # The caller dropped tokens from the beginning of the context.
            # Shift the KV cache instead of re-processing everything.
            drop = self._n_tokens - (ctx_len - 1)
            self._llm._ctx.kv_cache_seq_rm(0, 0, drop)
            self._llm._ctx.kv_cache_seq_shift(0, drop, -1, -drop)
            self._llm.n_tokens = ctx_len - 1
            self._n_tokens = ctx_len - 1
            # Now eval just the last token (the one appended after the trim).
            self._llm.eval([token_ids[-1]])
        else:
            # Case 3: cold start -> reset and eval everything.
            self._llm.reset()
            self._llm.eval(token_ids)
        self._n_tokens = ctx_len
        logits = self._llm.scores[ctx_len - 1]
        return torch.from_numpy(logits.copy())

    # ------------------------------------------------------------------
    # PyTorch CUDA Graph forward (fallback)
    # ------------------------------------------------------------------
    def _forward_cuda_graph(
        self, token_ids: list[int], ctx_len: int,
    ) -> torch.Tensor:
        if self._cache_len > 0 and ctx_len == self._cache_len + 1:
            # Incremental decode: refresh the fixed buffers and replay the graph.
            if self._graph is None:
                self._capture_graph()
            self._graph_input.fill_(token_ids[-1])
            self._graph_position.fill_(ctx_len - 1)
            self._graph.replay()
            return self._graph_output.logits[0, -1, :]
        else:
            # Cold start or non-incremental context: rebuild the static cache.
            self._static_cache.reset()
            self._graph = None
            input_ids = torch.tensor(
                [token_ids], dtype=torch.long, device=self.device,
            )
            cache_position = torch.arange(
                ctx_len, device=self.device, dtype=torch.long,
            )
            outputs = self.model(
                input_ids,
                past_key_values=self._static_cache,
                cache_position=cache_position,
                use_cache=True,
            )
            return outputs.logits[0, -1, :]

    # ------------------------------------------------------------------
    # PyTorch CPU eager forward (fallback)
    # ------------------------------------------------------------------
    def _forward_eager_cpu(
        self, token_ids: list[int], ctx_len: int,
    ) -> torch.Tensor:
        if (self._past_key_values is not None
                and ctx_len == self._cache_len + 1):
            # Incremental decode: feed only the newly appended token.
            new_ids = torch.tensor(
                [[token_ids[-1]]], dtype=torch.long, device=self.device,
            )
            outputs = self.model(
                new_ids,
                past_key_values=self._past_key_values,
                use_cache=True,
            )
        else:
            # Cold start: process the full context.
            self._past_key_values = None
            input_ids = torch.tensor(
                [token_ids], dtype=torch.long, device=self.device,
            )
            outputs = self.model(
                input_ids,
                use_cache=True,
            )
        self._past_key_values = outputs.past_key_values
        return outputs.logits[0, -1, :]

    # ------------------------------------------------------------------
    # Batch forward (stateless)
    # ------------------------------------------------------------------
    def forward_window(self, token_ids: list[int]) -> torch.Tensor:
        """Run a single forward pass on a token window.

        Returns softmax probabilities for ALL positions in the window.
        Does NOT use or update the KV-cache; purely stateless.
        Note: on the llama.cpp backend this invalidates the incremental
        cache. Call reset_cache() before resuming get_probs().

        Args:
            token_ids: Up to MAX_CONTEXT token IDs.

        Returns:
            Tensor of shape (len(token_ids), vocab_size) on CPU.
            result[j] = P(next_token | token_ids[0], ..., token_ids[j]).
        """
        if self._backend == "llama.cpp":
            self._llm.reset()
            self._llm.eval(token_ids)
            logits = self._llm.scores[:len(token_ids)].copy()
            self._n_tokens = 0
            self._cache_len = 0
            self._llm.reset()
            return torch.softmax(torch.from_numpy(logits), dim=-1)
        input_ids = torch.tensor(
            [token_ids], dtype=torch.long, device=self.device,
        )
        outputs = self.model(input_ids, use_cache=False)
        return torch.softmax(outputs.logits[0], dim=-1).cpu()
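

# ----------------------------------------------------------------------
# Usage sketch (illustrative): incremental next-token prediction with
# get_probs() and a stateless per-position pass with forward_window().
# Assumes either a GGUF file next to this script or network access so
# transformers can download HuggingFaceTB/SmolLM2-135M; the prompt text
# is arbitrary.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    wrapper = ModelWrapper(verbose=True)

    # Incremental decoding: the first call processes the whole context;
    # later calls that append exactly one token reuse the KV-cache.
    tokens = wrapper.tokenizer.encode("The quick brown fox")
    wrapper.reset_cache()
    probs = wrapper.get_probs(tokens)
    next_id = int(torch.argmax(probs))
    tokens.append(next_id)
    probs = wrapper.get_probs(tokens)  # single-token incremental step

    # Stateless window pass: probabilities for every position at once.
    window_probs = wrapper.forward_window(tokens)
    wrapper.reset_cache()  # required before resuming get_probs()
    print(
        f"next token: {wrapper.tokenizer.decode([next_id])!r}, "
        f"window shape: {tuple(window_probs.shape)}",
        file=sys.stderr,
    )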