MihaiPopa-1 committed on
Commit
a64d7cf
·
verified ·
1 Parent(s): 826146f

Upload cinnabar_compress.py

Browse files
Files changed (1) hide show
  1. cinnabar_compress.py +51 -35
cinnabar_compress.py CHANGED
@@ -190,42 +190,64 @@ def load_model(tag: str):
190
  print(f" Loading {MODEL_NAMES[tag]} from {repo} …", flush=True)
191
  from transformers import AutoTokenizer, AutoModelForCausalLM
192
 
193
- tok = AutoTokenizer.from_pretrained(repo)
194
- model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.float32)
 
 
 
 
 
 
 
 
 
195
  model.eval()
 
 
 
 
 
 
 
196
  _model_cache[tag] = (tok, model)
197
  print(f" {MODEL_NAMES[tag]} loaded.", flush=True)
198
  return tok, model
199
 
200
 
201
  # ─────────────────────────────────────────────────────────────────────────────
202
- # Probability helpers (KV-cache enabled)
203
  # ─────────────────────────────────────────────────────────────────────────────
204
 
205
  SCALE = (1 << 16) # cumulative frequency scale for arithmetic coder
206
 
207
- def _prefill(model, token_ids: list):
 
208
  """
209
- Run the model on a prompt and return (probs, past_key_values).
210
- Called once at the start to prime the KV cache.
 
211
  """
 
 
 
 
 
 
 
 
212
  inp = torch.tensor([token_ids], dtype=torch.long)
213
  with torch.no_grad():
214
  out = model(inp, use_cache=True)
215
- probs = F.softmax(out.logits[0, -1, :], dim=-1)
216
  return probs, out.past_key_values
217
 
218
 
219
- def _get_probs_cached(model, token_id: int, past_key_values):
220
- """
221
- Run ONE new token through the model, reusing past_key_values.
222
- Returns (probs, updated_past_key_values).
223
- O(1) in context length — this is the KV-cache speedup.
224
- """
225
  inp = torch.tensor([[token_id]], dtype=torch.long)
226
  with torch.no_grad():
227
  out = model(inp, past_key_values=past_key_values, use_cache=True)
228
- probs = F.softmax(out.logits[0, -1, :], dim=-1)
229
  return probs, out.past_key_values
230
 
231
 
@@ -280,36 +302,31 @@ def _probs_to_cum_freqs(probs: torch.Tensor):
280
  def encode(text: bytes, tag: str = "4", verbose: bool = True) -> bytes:
281
  tok, model = load_model(tag)
282
 
283
- # Tokenise input text
284
  token_ids = tok.encode(text.decode("utf-8", errors="replace"))
285
  n_tokens = len(token_ids)
286
- if verbose:
287
- print(f" Tokens: {n_tokens}")
288
-
289
- # Use EOS as the stream terminator; fall back to 0 if undefined
290
  eos = tok.eos_token_id if tok.eos_token_id is not None else 0
291
  bos = tok.bos_token_id if tok.bos_token_id is not None else 0
292
 
293
- enc = ArithmeticEncoder()
 
294
 
295
- # Prefill with BOS → get P(first token | BOS)
296
- probs, past = _prefill(model, [bos])
 
 
297
 
 
298
  for step, tid in enumerate(token_ids):
299
- cum, total = _probs_to_cum_freqs(probs)
300
  lo, hi = cum[tid]
301
  enc.encode_symbol(lo, hi, total)
302
- probs, past = _get_probs_cached(model, tid, past)
303
- if verbose and (step % 50 == 0 or step == n_tokens - 1):
304
- print(f"\r Encoding token {step+1}/{n_tokens} …", end="", flush=True)
305
 
306
- # Encode EOS sentinel so the decoder knows exactly where to stop
307
- cum, total = _probs_to_cum_freqs(probs)
308
- lo, hi = cum[eos]
309
- enc.encode_symbol(lo, hi, total)
310
 
311
  if verbose:
312
- print()
313
 
314
  enc.flush()
315
  compressed = enc.get_bytes()
@@ -344,20 +361,19 @@ def decode(data: bytes, verbose: bool = True) -> bytes:
344
  dec = ArithmeticDecoder(compressed)
345
  out_ids = []
346
 
347
- # Prefill with BOS → get P(first token | BOS), matching encode exactly
348
- probs, past = _prefill(model, [bos])
349
 
350
  step = 0
351
  while True:
352
  cum, total = _probs_to_cum_freqs(probs)
353
  sym = dec.decode_symbol(cum, total)
354
 
355
- # EOS sentinel = end of stream
356
  if sym == eos:
357
  break
358
 
359
  out_ids.append(sym)
360
- probs, past = _get_probs_cached(model, sym, past)
361
 
362
  step += 1
363
  if verbose and (step % 50 == 0):
 
190
  print(f" Loading {MODEL_NAMES[tag]} from {repo} …", flush=True)
191
  from transformers import AutoTokenizer, AutoModelForCausalLM
192
 
193
+ tok = AutoTokenizer.from_pretrained(repo)
194
+
195
+ # bfloat16 halves memory bandwidth on modern CPUs; fall back to float32
196
+ try:
197
+ model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)
198
+ # Verify bfloat16 actually works with a tiny test forward pass
199
+ test = torch.zeros(1, 1, dtype=torch.long)
200
+ model(test)
201
+ except Exception:
202
+ model = AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.float32)
203
+
204
  model.eval()
205
+
206
+ # torch.compile speeds up the repeated single-token decode loop
207
+ try:
208
+ model = torch.compile(model, mode="reduce-overhead")
209
+ except Exception:
210
+ pass
211
+
212
  _model_cache[tag] = (tok, model)
213
  print(f" {MODEL_NAMES[tag]} loaded.", flush=True)
214
  return tok, model
215
 
216
 
217
  # ─────────────────────────────────────────────────────────────────────────────
218
+ # Probability helpers
219
  # ─────────────────────────────────────────────────────────────────────────────
220
 
221
  SCALE = (1 << 16) # cumulative frequency scale for arithmetic coder
222
 
223
+
224
+ def _all_probs_batched(model, bos: int, token_ids: list) -> list:
225
  """
226
+ ENCODE fast-path: one forward pass over [BOS, t0, t1, ..., t_n-1].
227
+ Returns n+1 float32 probability tensors (one per position).
228
+ This is O(n) instead of the naive O(n^2) token-by-token approach.
229
  """
230
+ inp = torch.tensor([[bos] + token_ids], dtype=torch.long)
231
+ with torch.no_grad():
232
+ logits = model(inp).logits[0] # [n+1, vocab]
233
+ return [F.softmax(logits[i].float(), dim=-1) for i in range(logits.shape[0])]
234
+
235
+
236
+ def _prefill_cached(model, token_ids: list):
237
+ """Prime the KV cache with a prompt. Returns (last_probs, past_key_values)."""
238
  inp = torch.tensor([token_ids], dtype=torch.long)
239
  with torch.no_grad():
240
  out = model(inp, use_cache=True)
241
+ probs = F.softmax(out.logits[0, -1, :].float(), dim=-1)
242
  return probs, out.past_key_values
243
 
244
 
245
+ def _step_cached(model, token_id: int, past_key_values):
246
+ """One autoregressive decode step with KV cache. O(1) per step."""
 
 
 
 
247
  inp = torch.tensor([[token_id]], dtype=torch.long)
248
  with torch.no_grad():
249
  out = model(inp, past_key_values=past_key_values, use_cache=True)
250
+ probs = F.softmax(out.logits[0, -1, :].float(), dim=-1)
251
  return probs, out.past_key_values
252
 
253
 
 
302
  def encode(text: bytes, tag: str = "4", verbose: bool = True) -> bytes:
303
  tok, model = load_model(tag)
304
 
 
305
  token_ids = tok.encode(text.decode("utf-8", errors="replace"))
306
  n_tokens = len(token_ids)
 
 
 
 
307
  eos = tok.eos_token_id if tok.eos_token_id is not None else 0
308
  bos = tok.bos_token_id if tok.bos_token_id is not None else 0
309
 
310
+ if verbose:
311
+ print(f" Tokens: {n_tokens} — running single batched forward pass…", flush=True)
312
 
313
+ # ONE forward pass gives all n+1 probability distributions at once.
314
+ # probs_list[i] = P(next | BOS, t0..t_{i-1})
315
+ # probs_list[n] = P(next | BOS, t0..t_{n-1}) ← used to encode EOS
316
+ probs_list = _all_probs_batched(model, bos, token_ids)
317
 
318
+ enc = ArithmeticEncoder()
319
  for step, tid in enumerate(token_ids):
320
+ cum, total = _probs_to_cum_freqs(probs_list[step])
321
  lo, hi = cum[tid]
322
  enc.encode_symbol(lo, hi, total)
 
 
 
323
 
324
+ # EOS sentinel
325
+ cum, total = _probs_to_cum_freqs(probs_list[n_tokens])
326
+ enc.encode_symbol(*cum[eos], total)
 
327
 
328
  if verbose:
329
+ print(f" Done.", flush=True)
330
 
331
  enc.flush()
332
  compressed = enc.get_bytes()
 
361
  dec = ArithmeticDecoder(compressed)
362
  out_ids = []
363
 
364
+ # Prime KV cache with BOS → get P(first token | BOS)
365
+ probs, past = _prefill_cached(model, [bos])
366
 
367
  step = 0
368
  while True:
369
  cum, total = _probs_to_cum_freqs(probs)
370
  sym = dec.decode_symbol(cum, total)
371
 
 
372
  if sym == eos:
373
  break
374
 
375
  out_ids.append(sym)
376
+ probs, past = _step_cached(model, sym, past)
377
 
378
  step += 1
379
  if verbose and (step % 50 == 0):