miike-ai committed (verified) · Commit fa4671c · Parent: d9fe3f3

Add 128K validation results and chunked prefill usage example

Files changed (1): README.md (+48 −2)

README.md CHANGED
@@ -17,24 +17,70 @@ Evaluated against the uncompressed baseline on standard benchmarks:
 
 In practice, generation quality is nearly indistinguishable from the original model for typical instruction-following and conversational workloads.
 
+## 128K context validation
+
+Verified on a single NVIDIA A40 (45 GB) with a needle-in-haystack retrieval task at full context length:
+
+| Metric | Result |
+|---|---|
+| Input tokens | 126,239 |
+| Prefill | 132.5s (953 tok/s) |
+| Generation | 64 tokens in 8.9s (7.2 tok/s) |
+| Peak GPU memory | 37.97 GB |
+| Needle retrieved | Yes |
+
+For long-context inference, use chunked prefill with `logits_to_keep=0` on intermediate chunks to avoid materializing the full logits tensor. See the usage example below.
+
 ## Usage
 
+**Basic generation:**
+
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 model = AutoModelForCausalLM.from_pretrained(
-    "LeanLlama-8B",
+    "miike-ai/LeanLlama-8B",
     trust_remote_code=True,
     dtype="auto",
     device_map="auto",
 )
-tokenizer = AutoTokenizer.from_pretrained("LeanLlama-8B")
+tokenizer = AutoTokenizer.from_pretrained("miike-ai/LeanLlama-8B")
 
 inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)
 output = model.generate(**inputs, max_new_tokens=128)
 print(tokenizer.decode(output[0], skip_special_tokens=True))
 ```
 
+**Long-context generation (chunked prefill):**
+
+```python
+import torch
+
+CHUNK = 4096
+input_ids = tokenizer(long_text, return_tensors="pt").input_ids.to(model.device)
+seq_len = input_ids.shape[1]
+
+with torch.no_grad():
+    past_kv = None
+    for start in range(0, seq_len, CHUNK):
+        end = min(start + CHUNK, seq_len)
+        keep = 1 if end == seq_len else 0
+        out = model(
+            input_ids=input_ids[:, start:end],
+            past_key_values=past_kv,
+            use_cache=True,
+            logits_to_keep=keep,
+        )
+        past_kv = out.past_key_values
+
+    # Generate from the prefilled cache
+    next_id = out.logits[:, -1:, :].argmax(dim=-1)
+    for _ in range(max_new_tokens):
+        out = model(input_ids=next_id, past_key_values=past_kv, use_cache=True)
+        past_kv = out.past_key_values
+        next_id = out.logits[:, -1:, :].argmax(dim=-1)
+```
+
 No special configuration or post-processing is needed. The compression runs transparently inside the model's forward pass.
 
 ## Base model
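
The chunked prefill added in this commit relies on a property of causal self-attention: feeding the sequence through in chunks while accumulating the key/value cache produces exactly the same activations as one full-length forward pass, since each chunk's queries attend to all cached keys plus their own causal prefix. A minimal single-head NumPy sketch of that equivalence (toy random weights; all names here are mine, not the model's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
d, T, CHUNK = 16, 12, 5  # head dim, sequence length, prefill chunk size

# Random projections standing in for a trained attention layer.
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
x = rng.standard_normal((T, d))  # embeddings for the full sequence

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(xs, k_cache, v_cache):
    """Attend chunk `xs` against the K/V cache plus its own K/V (causal within the chunk)."""
    q, k, v = xs @ Wq, xs @ Wk, xs @ Wv
    K = k if k_cache is None else np.concatenate([k_cache, k])
    V = v if v_cache is None else np.concatenate([v_cache, v])
    past = K.shape[0] - k.shape[0]  # how many cached positions precede this chunk
    scores = q @ K.T / np.sqrt(d)
    # Chunk position i may see the cache plus chunk positions 0..i, nothing later.
    n = xs.shape[0]
    mask = np.arange(K.shape[0])[None, :] > (past + np.arange(n))[:, None]
    scores[mask] = -np.inf
    return softmax(scores) @ V, K, V

# Full prefill in a single pass.
full_out, _, _ = causal_attention(x, None, None)

# Chunked prefill: same math, K/V accumulated chunk by chunk.
k_cache = v_cache = None
outs = []
for s in range(0, T, CHUNK):
    o, k_cache, v_cache = causal_attention(x[s:s + CHUNK], k_cache, v_cache)
    outs.append(o)
chunked_out = np.concatenate(outs)

assert np.allclose(full_out, chunked_out)  # chunking is lossless
```

Because the outputs match, looping over `past_key_values` as in the committed snippet loses nothing relative to a single forward pass; only the logits of the final position are ever needed to start generation.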