Add 128K validation results and chunked prefill usage example
README.md (CHANGED)

@@ -17,24 +17,70 @@ Evaluated against the uncompressed baseline on standard benchmarks:

In practice, generation quality is nearly indistinguishable from the original model for typical instruction-following and conversational workloads.

## 128K context validation

Verified on a single NVIDIA A40 (45 GB) with a needle-in-haystack retrieval task at full context length:

| Metric | Result |
|---|---|
| Input tokens | 126,239 |
| Prefill | 132.5 s (953 tok/s) |
| Generation | 64 tokens in 8.9 s (7.2 tok/s) |
| Peak GPU memory | 37.97 GB |
| Needle retrieved | Yes |
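
The derived rates in the table follow directly from the raw counts; a quick arithmetic check (plain Python, no model required):

```python
# Derived throughput from the raw measurements in the table above.
prefill_rate = 126_239 / 132.5   # input tokens / prefill seconds
decode_rate = 64 / 8.9           # generated tokens / generation seconds

print(round(prefill_rate), round(decode_rate, 1))  # → 953 7.2
```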

For long-context inference, use chunked prefill with `logits_to_keep=0` on intermediate chunks to avoid materializing the full logits tensor. See the usage example below.

## Usage

**Basic generation:**

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "miike-ai/LeanLlama-8B",
    trust_remote_code=True,
    dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("miike-ai/LeanLlama-8B")

inputs = tokenizer("What is the capital of France?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

**Long-context generation (chunked prefill):**

```python
import torch

# `long_text` is your long input string; `model` and `tokenizer` are loaded as above.
CHUNK = 4096
max_new_tokens = 64
input_ids = tokenizer(long_text, return_tensors="pt").input_ids.to(model.device)
seq_len = input_ids.shape[1]

with torch.no_grad():
    # Prefill the KV cache chunk by chunk; keep logits only for the final chunk.
    past_kv = None
    for start in range(0, seq_len, CHUNK):
        end = min(start + CHUNK, seq_len)
        keep = 1 if end == seq_len else 0
        out = model(
            input_ids=input_ids[:, start:end],
            past_key_values=past_kv,
            use_cache=True,
            logits_to_keep=keep,
        )
        past_kv = out.past_key_values

    # Greedy-decode from the prefilled cache.
    next_id = out.logits[:, -1:, :].argmax(dim=-1)
    generated = [next_id]
    for _ in range(max_new_tokens - 1):
        out = model(input_ids=next_id, past_key_values=past_kv, use_cache=True)
        past_kv = out.past_key_values
        next_id = out.logits[:, -1:, :].argmax(dim=-1)
        generated.append(next_id)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0], skip_special_tokens=True))
```
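
To see why `logits_to_keep=0` matters at this scale, consider the logits tensor the intermediate chunks would otherwise produce. A back-of-envelope sketch; the 128,256-entry vocabulary and 2-byte bf16 logits are assumptions (Llama-3-style figures, not stated in this README):

```python
# Assumed sizes: Llama-3-style vocab (128,256 entries) and bf16 logits (2 bytes)
# are assumptions; the 126,239-token input comes from the validation table above.
seq_len, vocab, bytes_per_logit = 126_239, 128_256, 2

full_logits_gib = seq_len * vocab * bytes_per_logit / 2**30  # logits for every position
last_only_gib = 1 * vocab * bytes_per_logit / 2**30          # logits_to_keep=1 on the final chunk

print(f"{full_logits_gib:.1f} GiB vs {last_only_gib * 1024:.2f} MiB")  # → 30.2 GiB vs 0.24 MiB
```

Against the 45 GB card and the 37.97 GB peak in the table above, a roughly 30 GiB full-sequence logits tensor would not fit alongside weights and KV cache, which is why intermediate chunks drop their logits entirely.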

No special configuration or post-processing is needed. The compression runs transparently inside the model's forward pass.

## Base model