File size: 2,838 Bytes
7d03cad
 
95d9c07
7d03cad
 
 
 
 
 
95d9c07
7d03cad
 
 
 
 
 
 
 
2c23dfb
7d03cad
 
 
95d9c07
7d03cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c23dfb
7d03cad
cf9e688
7d03cad
 
d031870
cf9e688
 
 
7d03cad
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch

def top_n_sigma_sampling(logits: torch.Tensor, temperature: float, n_sigma: float) -> torch.Tensor:
    """
    Apply top-nσ filtering to a batch of logits.

    Tokens whose (temperature-scaled) logit falls more than ``n_sigma``
    standard deviations below the row maximum are masked to ``-inf`` so
    they receive zero probability after a subsequent softmax.

    Args:
        logits (torch.Tensor): Model logits of shape (batch_size, vocab_size).
        temperature (float): Temperature used to scale the logits.
        n_sigma (float): Width of the acceptance band, in standard deviations.

    Returns:
        torch.Tensor: Scaled logits with out-of-band entries set to ``-inf``.
    """
    scaled = logits / temperature  # temperature scaling
    peak = torch.max(scaled, dim=-1, keepdim=True).values
    spread = torch.std(scaled, dim=-1, keepdim=True)
    cutoff = peak - n_sigma * spread
    # Keep everything within n_sigma deviations of the row maximum.
    keep = scaled >= cutoff
    neg_inf = torch.full_like(scaled, float("-inf"))
    return torch.where(keep, scaled, neg_inf)

@torch.inference_mode()
def generate(model, input_ids, generation_config=None, n_sigma: float = 1.0, **kwargs):
    """
    Generate text using top-nσ sampling based on the paper:
    https://openreview.net/pdf/1e221c8eedaf42558abc5dca4637b3378297582b.pdf


    @inproceedings{tang2025top,
        title={Top-n𝜎: Eliminating Noise in Logit Space for Robust Token Sampling of LLM},
        author={Tang, Chenxia and Liu, Jianchun and Xu, Hongli and Huang, Liusheng},
        booktitle={Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
        pages={10758--10774},
        year={2025}
    }

    Args:
        model: The model to use for generation. Must return an object with a
               ``.logits`` tensor of shape (batch, seq_len, vocab) when called,
               and expose ``.generation_config``.
        input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
        generation_config (optional): Configuration for generation, such as max_length,
                                      max_new_tokens, temperature, eos_token_id and
                                      pad_token_id.
        n_sigma (float): The number of standard deviations to use for topN-sigma sampling.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        torch.Tensor: ``input_ids`` extended with the sampled tokens, of shape
        (batch_size, final_length) with final_length <= max_length. Generation
        stops early once every sequence has emitted an EOS token (if one is
        configured); finished sequences are padded with ``pad_token_id``.
    """
    generation_config = generation_config or model.generation_config  # default to the model generation config
    cur_length = input_ids.shape[1]
    if generation_config.max_new_tokens:
        max_length = cur_length + generation_config.max_new_tokens
    else:
        max_length = generation_config.max_length

    # EOS bookkeeping: eos_token_id may be absent, a single id, or a list of ids.
    eos_ids = getattr(generation_config, "eos_token_id", None)
    if eos_ids is not None and not isinstance(eos_ids, (list, tuple)):
        eos_ids = [eos_ids]
    pad_id = getattr(generation_config, "pad_token_id", None)
    if pad_id is None and eos_ids:
        pad_id = eos_ids[0]  # common fallback when no pad token is defined
    # (batch, 1) bool mask of sequences that have not yet produced an EOS token.
    unfinished = torch.ones(
        (input_ids.shape[0], 1), dtype=torch.bool, device=input_ids.device
    )

    while cur_length < max_length:
        # NOTE: full forward pass over the whole sequence each step (no KV
        # cache) — simple but O(len^2) overall; fine for short generations.
        logits = model(input_ids).logits
        logits = logits[:, -1, :]
        # Filter logits using topN-sigma sampling
        filtered_logits = top_n_sigma_sampling(logits, generation_config.temperature, n_sigma=n_sigma)
        # sample from the filtered logits
        next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
        if eos_ids:
            # Sequences that already finished keep emitting the pad token
            # instead of freshly sampled (meaningless) tokens.
            next_tokens = torch.where(
                unfinished, next_tokens, torch.full_like(next_tokens, pad_id)
            )
            hit_eos = torch.zeros_like(unfinished)
            for eos in eos_ids:
                hit_eos |= next_tokens == eos
            unfinished &= ~hit_eos
        input_ids = torch.cat((input_ids, next_tokens), dim=-1)
        cur_length += 1
        if eos_ids and not unfinished.any():
            break  # every sequence in the batch has produced EOS — stop early

    return input_ids