Update README.md
Browse files- .gitignore +0 -1
- README.md +16 -1
- custom_generate/generate.py +3 -3
.gitignore
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
.env
|
|
|
|
|
|
README.md
CHANGED
|
@@ -4,6 +4,8 @@ tags:
|
|
| 4 |
- custom_generate
|
| 5 |
---
|
| 6 |
## Overview
|
|
|
|
|
|
|
| 7 |
|
| 8 |
Most output-token sampling techniques operate on the probability scores obtained after temperature scaling is applied. The softmax function distorts the underlying distribution of logit scores, making it hard to choose a meaningful top-p/top-k value.
|
| 9 |
|
|
@@ -35,6 +37,8 @@ This implementation of Top-NSigma requires the user to pass in a new argument `n
|
|
| 35 |
|
| 36 |
We'll use this to filter out tokens whose logit scores are more than `n_sigma` standard deviations below the maximum logit score.
|
| 37 |
|
|
|
|
|
|
|
| 38 |
## Output Type changes
|
| 39 |
(none)
|
| 40 |
|
|
@@ -48,6 +52,17 @@ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", devic
|
|
| 48 |
|
| 49 |
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
|
| 50 |
# There is a print message hardcoded in the custom generation method
|
| 51 |
-
gen_out = model.generate(**inputs,
|
| 52 |
print(tokenizer.batch_decode(gen_out))
|
| 53 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
- custom_generate
|
| 5 |
---
|
| 6 |
## Overview
|
| 7 |
+
This generation sampling method is based on the paper [Top-N Sigma: A Simple and Effective Sampling Method for Language Models](https://openreview.net/pdf/1e221c8eedaf42558abc5dca4637b3378297582b.pdf).
|
| 8 |
+
|
| 9 |
|
| 10 |
Most output-token sampling techniques operate on the probability scores obtained after temperature scaling is applied. The softmax function distorts the underlying distribution of logit scores, making it hard to choose a meaningful top-p/top-k value.
|
| 11 |
|
|
|
|
| 37 |
|
| 38 |
We'll use this to filter out tokens whose logit scores are more than `n_sigma` standard deviations below the maximum logit score.
|
| 39 |
|
| 40 |
+
The authors recommend using `n_sigma=1.0` for most use cases, but you can experiment with values in the range **(0.0, 2√3]**.
|
| 41 |
+
|
| 42 |
## Output Type changes
|
| 43 |
(none)
|
| 44 |
|
|
|
|
| 52 |
|
| 53 |
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
|
| 54 |
# There is a print message hardcoded in the custom generation method
|
| 55 |
+
gen_out = model.generate(**inputs, n_sigma=1.0, custom_generate="Pramodith/topN_sigma_generation", trust_remote_code=True)
|
| 56 |
print(tokenizer.batch_decode(gen_out))
|
| 57 |
```
|
| 58 |
+
|
| 59 |
+
### Citation
|
| 60 |
+
```bibtex
|
| 61 |
+
@inproceedings{tang2025top,
|
| 62 |
+
title={Top-n𝜎: Eliminating Noise in Logit Space for Robust Token Sampling of LLM},
|
| 63 |
+
author={Tang, Chenxia and Liu, Jianchun and Xu, Hongli and Huang, Liusheng},
|
| 64 |
+
booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
|
| 65 |
+
pages={10758--10774},
|
| 66 |
+
year={2025}
|
| 67 |
+
}
|
| 68 |
+
```
|
custom_generate/generate.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
-
def top_n_sigma_sampling(logits, temperature, n_sigma
|
| 4 |
"""
|
| 5 |
Perform topN-sigma sampling on the logits.
|
| 6 |
|
| 7 |
Args:
|
| 8 |
logits (torch.Tensor): The logits from the model of shape (batch_size, vocab_size).
|
| 9 |
temperature (float): The temperature to apply to the logits.
|
| 10 |
-
n_sigma (
|
| 11 |
|
| 12 |
Returns:
|
| 13 |
torch.Tensor: The filtered logits after applying topN-sigma sampling.
|
|
@@ -20,7 +20,7 @@ def top_n_sigma_sampling(logits, temperature, n_sigma=4):
|
|
| 20 |
return filtered_logits
|
| 21 |
|
| 22 |
@torch.inference_mode()
|
| 23 |
-
def generate(model, input_ids, generation_config=None, n_sigma=
|
| 24 |
"""
|
| 25 |
Generate text using topN-sigma sampling based on the paper:
|
| 26 |
https://openreview.net/pdf/1e221c8eedaf42558abc5dca4637b3378297582b.pdf
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
+
def top_n_sigma_sampling(logits:torch.Tensor, temperature:float, n_sigma:float) -> torch.Tensor:
|
| 4 |
"""
|
| 5 |
Perform topN-sigma sampling on the logits.
|
| 6 |
|
| 7 |
Args:
|
| 8 |
logits (torch.Tensor): The logits from the model of shape (batch_size, vocab_size).
|
| 9 |
temperature (float): The temperature to apply to the logits.
|
| 10 |
+
n_sigma (float): The number of standard deviations to use for filtering.
|
| 11 |
|
| 12 |
Returns:
|
| 13 |
torch.Tensor: The filtered logits after applying topN-sigma sampling.
|
|
|
|
| 20 |
return filtered_logits
|
| 21 |
|
| 22 |
@torch.inference_mode()
|
| 23 |
+
def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwargs):
|
| 24 |
"""
|
| 25 |
Generate text using topN-sigma sampling based on the paper:
|
| 26 |
https://openreview.net/pdf/1e221c8eedaf42558abc5dca4637b3378297582b.pdf
|