Pramodith committed on
Commit 95d9c07 · 1 Parent(s): 6445bb3

Update Readme.md

Files changed (3):
  1. .gitignore +0 -1
  2. README.md +16 -1
  3. custom_generate/generate.py +3 -3
.gitignore DELETED
```diff
@@ -1 +0,0 @@
-.env
```
README.md CHANGED
````diff
@@ -4,6 +4,8 @@ tags:
 - custom_generate
 ---
 ## Overview
+This generation sampling method is based on the paper [Top-N Sigma: A Simple and Effective Sampling Method for Language Models](https://openreview.net/pdf/1e221c8eedaf42558abc5dca4637b3378297582b.pdf).
+
 
 Most output token sampling techniques operate on the probability scores after temperature has been applied. The softmax function distorts the underlying logit score distribution, making it hard to choose a meaningful top-p/top-k value.
 
@@ -35,6 +37,8 @@ This implementation of Top-NSigma requires the user to pass in a new argument `n
 
 We'll use this to filter out tokens whose logit scores are more than `n_sigma` standard deviations below the max logit score.
 
+The authors recommend using `n_sigma=1.0` for most use cases, but you can experiment with values in the range **(0.0, 2√3]**.
+
 ## Output Type changes
 (none)
 
@@ -48,6 +52,17 @@ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", devic
 
 inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
 # There is a print message hardcoded in the custom generation method
-gen_out = model.generate(**inputs, left_padding=5, custom_generate="Pramodith/topN_sigma_generation", trust_remote_code=True)
+gen_out = model.generate(**inputs, n_sigma=1.0, custom_generate="Pramodith/topN_sigma_generation", trust_remote_code=True)
 print(tokenizer.batch_decode(gen_out))
 ```
+
+### Citation
+```bibtex
+@inproceedings{tang2025top,
+  title={Top-n𝜎: Eliminating Noise in Logit Space for Robust Token Sampling of LLM},
+  author={Tang, Chenxia and Liu, Jianchun and Xu, Hongli and Huang, Liusheng},
+  booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
+  pages={10758--10774},
+  year={2025}
+}
+```
````
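The filtering rule the README describes — drop tokens whose logits fall more than `n_sigma` standard deviations below the max logit — can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repo's exact implementation; the function name, the `-inf` masking value, and applying the mask before temperature scaling are assumptions (the threshold itself is temperature-invariant, since the max and the standard deviation scale together):

```python
import torch

def top_n_sigma_filter(logits: torch.Tensor, n_sigma: float = 1.0) -> torch.Tensor:
    """Mask tokens whose logits are more than n_sigma standard
    deviations below the max logit (shape: batch_size x vocab_size)."""
    max_logit = logits.max(dim=-1, keepdim=True).values
    threshold = max_logit - n_sigma * logits.std(dim=-1, keepdim=True)
    return logits.masked_fill(logits < threshold, float("-inf"))

# With one clear cluster near the max, only those tokens survive.
scores = torch.tensor([[10.0, 9.5, 0.0, -3.0]])
filtered = top_n_sigma_filter(scores, n_sigma=1.0)
# The last two logits are masked to -inf; the first two pass through.
```

Because the cutoff is computed in logit space, it adapts per step: a peaked distribution keeps few tokens, a flat one keeps many, without retuning a probability threshold.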
custom_generate/generate.py CHANGED
```diff
@@ -1,13 +1,13 @@
 import torch
 
-def top_n_sigma_sampling(logits, temperature, n_sigma=4):
+def top_n_sigma_sampling(logits: torch.Tensor, temperature: float, n_sigma: float) -> torch.Tensor:
     """
     Perform topN-sigma sampling on the logits.
 
     Args:
         logits (torch.Tensor): The logits from the model of shape (batch_size, vocab_size).
         temperature (float): The temperature to apply to the logits.
-        n_sigma (int): The number of standard deviations to use for filtering.
+        n_sigma (float): The number of standard deviations to use for filtering.
 
     Returns:
         torch.Tensor: The filtered logits after applying topN-sigma sampling.
@@ -20,7 +20,7 @@ def top_n_sigma_sampling(logits, temperature, n_sigma=4):
     return filtered_logits
 
 @torch.inference_mode()
-def generate(model, input_ids, generation_config=None, n_sigma=4, **kwargs):
+def generate(model, input_ids, generation_config=None, n_sigma: float = 1.0, **kwargs):
     """
     Generate text using topN-sigma sampling based on the paper:
     https://openreview.net/pdf/1e221c8eedaf42558abc5dca4637b3378297582b.pdf
```
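The diff above touches only signatures and docstrings; the body of `top_n_sigma_sampling` and the decoding loop are elided. A hypothetical single decoding step combining the filter with sampling might look like this — the function name `sample_next_token` and the exact order of operations are assumptions for illustration, not the repo's code:

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                      n_sigma: float = 1.0) -> torch.Tensor:
    # Threshold relative to the max logit; tokens below it are removed.
    max_logit = logits.max(dim=-1, keepdim=True).values
    threshold = max_logit - n_sigma * logits.std(dim=-1, keepdim=True)
    filtered = logits.masked_fill(logits < threshold, float("-inf"))
    # Softmax over the surviving logits, then sample one token id per row.
    probs = torch.softmax(filtered / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Only the two logits near the max can ever be sampled here.
next_token = sample_next_token(torch.tensor([[10.0, 9.5, 0.0, -3.0]]))
```

Masked tokens receive zero probability after the softmax, so `torch.multinomial` can never select them, regardless of temperature.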