Spaces:
Running
Running
| """Standalone streaming Lilylet generator (int8 ONNX + two-level KV cache). | |
| Torch-free adaptation of deep-starry's ORTGeneratorKV | |
| (tests/bench_lilylet_int8_ort.py): the patch-level decoder runs incrementally | |
| through `patch_kv_int8.onnx` (O(1) per step) and the token decoder inside each | |
| patch through `token_kv_int8.onnx`, both via onnxruntime. The token embedding | |
| table and model geometry are loaded from vendored assets (`wte.npy`, | |
| `geometry.json`) instead of a torch model, and sampling is reimplemented in | |
| numpy — so the only runtime deps are onnxruntime + numpy. | |
| `generate_stream(...)` is a Python generator: it yields `(raw, pretty, done)` | |
| after every patch, where `raw` is the accumulated decoded text (for the run | |
| log) and `pretty = postprocess(raw)` (for the editor, segmented by measure). | |
| """ | |
| import os | |
| import json | |
| import numpy as np | |
| import onnxruntime as ort | |
| from .postprocess import postprocess | |
def sample_next (logits, rng, temperature=1.0, top_k=0, top_p=1.0):
    '''Sample one token id from a 1-D logits vector (numpy) with temperature/top-k/top-p.

    Args:
        logits: 1-D array of unnormalized scores, one per vocabulary id.
        rng: numpy random Generator (anything exposing `choice(n, p=...)`).
        temperature: when != 1.0, logits are divided by max(temperature, 1e-6),
            so temperature <= 0 degenerates to (near-)greedy sampling.
        top_k: when > 0, logits strictly below the k-th largest value are
            dropped (ties at the threshold are all kept).
        top_p: when non-zero and < 1.0, nucleus filtering: tokens are taken in
            descending probability order and kept up to and including the first
            one whose cumulative probability exceeds top_p; the most likely
            token is always kept.

    Returns:
        The sampled token id as a Python int.
    '''
    logits = logits.astype(np.float64)
    if temperature != 1.0:
        logits = logits / max(temperature, 1e-6)
    if top_k and top_k > 0:
        k = min(top_k, logits.shape[-1])
        # k-th largest value via an O(V) partition rather than a full O(V log V) sort
        kth = np.partition(logits, -k)[-k]
        logits = np.where(logits < kth, -np.inf, logits)
    if top_p and top_p < 1.0:
        order = np.argsort(logits)[::-1]
        sorted_logits = logits[order]
        cdf = np.cumsum(_softmax(sorted_logits))
        # shift the mask right by one so the token that crosses top_p is kept too,
        # and never remove the single most likely token
        remove = cdf > top_p
        remove[1:] = remove[:-1].copy()
        remove[0] = False
        sorted_logits = np.where(remove, -np.inf, sorted_logits)
        logits = np.full_like(logits, -np.inf)
        logits[order] = sorted_logits
    probs = _softmax(logits)
    return int(rng.choice(len(probs), p=probs))
| def _softmax (x): | |
| x = x - np.max(x) | |
| e = np.exp(x) | |
| return e / e.sum() | |
class StreamingLilyletGenerator:
    '''Loads the int8 KV ONNX sessions + vendored assets and streams generation.

    Two onnxruntime sessions are driven incrementally with explicit KV caches:
    `patch_kv_int8.onnx` (patch-level decoder, fed rows of token ids) and
    `token_kv_int8.onnx` (token-level decoder inside one patch, fed embeddings
    looked up in `wte.npy`). Each cache is a flat list laid out as
    [k_0, v_0, k_1, v_1, ...], one k/v pair per layer.
    '''

    def __init__ (self, model_dir, asset_dir, threads=None):
        '''Load geometry, tokenizer, token embedding table and both ONNX sessions.

        Args:
            model_dir: directory with `geometry.json`, `wte.npy`,
                `patch_kv_int8.onnx` and `token_kv_int8.onnx`.
            asset_dir: directory with `lilylet-tokenizer.json`.
            threads: onnxruntime intra-op thread count; falsy keeps the ORT default.
        '''
        from .tokenizer import LilyletTokenizer

        # context manager so the file handle is closed promptly, not left to the GC
        with open(os.path.join(model_dir, 'geometry.json'), encoding='utf-8') as f:
            geo = json.load(f)
        self.patch_size = geo['patch_size']
        self.pad_id = geo['pad_id']
        self.bos_id = geo['bos_id']
        self.eos_id = geo['eos_id']
        # patch-level KV geometry
        self.n_layers = geo['patch']['n_layers']
        self.n_kv = geo['patch']['n_kv_heads']
        self.head_dim = geo['patch']['head_dim']
        # token-level KV geometry
        self.t_layers = geo['token']['n_layers']
        self.t_kv = geo['token']['n_kv_heads']
        self.t_head_dim = geo['token']['head_dim']
        self.tokenizer = LilyletTokenizer(os.path.join(asset_dir, 'lilylet-tokenizer.json'))
        self.wte = np.load(os.path.join(model_dir, 'wte.npy'))  # [vocab, hidden] — model weight, lives with the onnx
        so = ort.SessionOptions()
        if threads:
            so.intra_op_num_threads = threads
        self.patch_kv_sess = ort.InferenceSession(
            os.path.join(model_dir, 'patch_kv_int8.onnx'), so, providers=['CPUExecutionProvider'])
        self.token_kv_sess = ort.InferenceSession(
            os.path.join(model_dir, 'token_kv_int8.onnx'), so, providers=['CPUExecutionProvider'])
        # output names in session order, used to map `run()` results back to names
        self.patch_out_names = [o.name for o in self.patch_kv_sess.get_outputs()]
        self.token_out_names = [o.name for o in self.token_kv_sess.get_outputs()]

    # ---- text helpers (mirror LilyletPatchyGenerator.patch_to_text) ----
    def patch_to_text (self, patch):
        '''Decode one patch: stop at the first EOS, skip PAD/BOS, unknown ids map to ''.'''
        out = []
        for tid in patch:
            tid = int(tid)
            if tid == self.eos_id:
                break
            if tid in (self.pad_id, self.bos_id):
                continue
            out.append(self.tokenizer.text_by_id.get(tid, ''))
        return ''.join(out)

    def patches_to_text (self, patches):
        '''Concatenate `patch_to_text` over a sequence of patches.'''
        return ''.join(self.patch_to_text(p) for p in patches)

    # ---- KV plumbing ----
    def _empty_patch_past (self):
        '''Zero-length patch-level cache: 2 * n_layers arrays of shape (1, n_kv, 0, head_dim).'''
        return [np.zeros((1, self.n_kv, 0, self.head_dim), dtype=np.float32) for _ in range(2 * self.n_layers)]

    def _empty_token_past (self):
        '''Zero-length token-level cache: 2 * t_layers arrays of shape (1, t_kv, 0, t_head_dim).'''
        return [np.zeros((1, self.t_kv, 0, self.t_head_dim), dtype=np.float32) for _ in range(2 * self.t_layers)]

    @staticmethod
    def _run_kv (sess, out_names, n_layers, feed, past, result_key):
        '''Run one KV-cached ONNX step shared by the patch and token sessions.

        Adds `past_k_<i>` / `past_v_<i>` from the flat `past` list to `feed`,
        runs `sess`, and returns (outputs[result_key], new_past), where new_past
        is rebuilt in the same flat [k_0, v_0, ...] order from `new_k_<i>` /
        `new_v_<i>`.
        '''
        for i in range(n_layers):
            feed[f'past_k_{i}'] = past[2 * i]
            feed[f'past_v_{i}'] = past[2 * i + 1]
        out = dict(zip(out_names, sess.run(None, feed)))
        new_past = []
        for i in range(n_layers):
            new_past.append(out[f'new_k_{i}'])
            new_past.append(out[f'new_v_{i}'])
        return out[result_key], new_past

    def _patch_kv_step (self, patch_rows, past):
        '''Feed L new patches (list of patch_size-length id rows) + past KV.
        Returns (last_hidden [hidden], new_past list).'''
        feed = {'patches': np.asarray([patch_rows], dtype=np.int64)}
        hidden, new_past = self._run_kv(
            self.patch_kv_sess, self.patch_out_names, self.n_layers, feed, past, 'hidden')
        return hidden[0, -1], new_past

    def _token_kv_step (self, emb_np, past):
        '''Feed L new token embeddings [1,L,hidden] + past KV. Returns (logits[-1], new_past).'''
        feed = {'inputs_embeds': emb_np.astype(np.float32)}
        logits, new_past = self._run_kv(
            self.token_kv_sess, self.token_out_names, self.t_layers, feed, past, 'logits')
        return logits[0, -1], new_past

    def _generate_patch (self, last_hidden, rng, prefix_ids=None, temperature=1.0, top_k=0, top_p=1.0):
        '''Token-level decode of one patch through the token-KV session.

        `last_hidden` (the patch-level state) is fed as the first input
        embedding, taking the position of the BOS token, whose own embedding
        is never fed. `prefix_ids` (list or array) are then fed as forced
        tokens, after which tokens are drawn with `sample_next` until the patch
        holds `patch_size` ids. Sampling does not stop at EOS; callers mask
        the tail.

        Returns:
            list of exactly `patch_size` token ids (prefix included).
        Raises:
            ValueError: if `prefix_ids` is longer than `patch_size`.
        '''
        # explicit None check: `prefix_ids or []` is ambiguous for numpy arrays
        generated = [] if prefix_ids is None else [int(t) for t in prefix_ids]
        if len(generated) > self.patch_size:
            raise ValueError(
                f'prefix of {len(generated)} tokens does not fit in a patch of {self.patch_size}')
        past = self._empty_token_past()
        enc = last_hidden.reshape(1, 1, -1).astype(np.float32)
        logits, past = self._token_kv_step(enc, past)
        for tid in generated:
            logits, past = self._token_kv_step(self.wte[tid].reshape(1, 1, -1), past)
        while len(generated) < self.patch_size:
            nxt = sample_next(logits, rng, temperature=temperature, top_k=top_k, top_p=top_p)
            generated.append(nxt)
            # the logits after the final token are never used, so skip that step
            if len(generated) < self.patch_size:
                logits, past = self._token_kv_step(self.wte[nxt].reshape(1, 1, -1), past)
        return generated

    def _prompt_patches (self, prompt_text):
        '''Tokenize `prompt_text` line by line (each with its newline re-appended)
        and split each line's ids into PAD-padded rows of `patch_size` ids.'''
        rows = []
        if not prompt_text:
            return rows
        for line in prompt_text.splitlines():
            # list() so padding concatenates even if encode() returns an array
            ids = list(self.tokenizer.encode(line + '\n'))
            for start in range(0, len(ids), self.patch_size):
                chunk = ids[start:start + self.patch_size]
                rows.append(chunk + [self.pad_id] * (self.patch_size - len(chunk)))
        return rows

    def _mask_after_eos (self, patch):
        '''Return a copy of `patch` with every id after the first EOS replaced by PAD
        (the EOS itself is kept; patches without EOS are returned unchanged).'''
        clean = list(patch)
        if self.eos_id in clean:
            cut = clean.index(self.eos_id) + 1
            clean[cut:] = [self.pad_id] * (len(clean) - cut)
        return clean

    def generate_stream (self, prompt_text='', max_patches=256, temperature=1.0, top_k=0,
                         top_p=0.9, measures=None, seed=0):
        '''Autoregressive generation, yielding after every patch.
        Yields (raw, pretty, done):
            raw    -- accumulated decoded patch text (with `[r:x/y]` markers), for the log
            pretty -- postprocess(raw), measure-segmented, for the editor
            done   -- True on the final yield (EOS patch or max_patches reached)
        The first yield (done=False) reflects the prompt only; the final yield
        repeats the last text with done=True. `seed` seeds numpy's default_rng,
        so identical arguments reproduce identical output.
        '''
        rng = np.random.default_rng(seed)
        sampling = dict(temperature=temperature, top_k=top_k, top_p=top_p)
        bos_patch = [self.bos_id] * (self.patch_size - 1) + [self.eos_id]
        patches = [bos_patch] + self._prompt_patches(prompt_text)
        out_text = self.patches_to_text(patches[1:])
        prime_ids = self.tokenizer.encode(f'[r:0/{measures}]') if measures is not None else None
        primed = False
        # prefill: run all seed patches through the patch-KV decoder in one call
        past = self._empty_patch_past()
        last, past = self._patch_kv_step(patches, past)
        yield out_text, postprocess(out_text), False
        for _ in range(max_patches):
            patch_ids = self._generate_patch(last, rng, **sampling)
            # first time the model emits a stream patch, re-sample with the forced
            # `[r:0/<measures>]` prefix so the body starts at the requested measure count.
            if prime_ids is not None and not primed and self.patch_to_text(patch_ids).startswith('[r:'):
                primed = True
                patch_ids = self._generate_patch(last, rng, prefix_ids=prime_ids, **sampling)
            # a patch starting BOS, EOS marks the end of the piece -> done
            if patch_ids[0] == self.bos_id and patch_ids[1] == self.eos_id:
                break
            out_text += self.patch_to_text(patch_ids)
            # sampling does not stop at EOS, so the ids after it are filler;
            # PAD them before the patch enters the patch-level cache
            last, past = self._patch_kv_step([self._mask_after_eos(patch_ids)], past)
            yield out_text, postprocess(out_text), False
        yield out_text, postprocess(out_text), True