File size: 1,368 Bytes
2fc11ed
 
f4a4361
2fc11ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602d956
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import math
import random



def log(t, eps = 1e-20):
    """Numerically safe elementwise log: inputs below `eps` are clamped first,
    so zeros never produce -inf/NaN."""
    clamped = t.clamp(min = eps)
    return torch.log(clamped)

def gumbel_noise(t):
    """Sample standard Gumbel noise with the same shape (and device/dtype) as `t`,
    via the inverse-CDF transform -log(-log(U)) with U ~ Uniform(0, 1)."""
    u = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(u)
    return -log(inner)

def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
    """Draw a categorical sample from logits `t` using the Gumbel-max trick.

    Returns the argmax indices along `dim` of the temperature-scaled,
    Gumbel-perturbed logits."""
    scale = max(temperature, 1e-10)  # guard against division by zero at temperature 0
    perturbed = (t / scale) + gumbel_noise(t)
    return perturbed.argmax(dim = dim, keepdim = keepdim)

def top_k(logits, thres = 0.9):
    """Keep only the top (1 - thres) fraction of logits along the last dim,
    masking everything else to -inf so it can never be sampled."""
    num_keep = math.ceil((1 - thres) * logits.shape[-1])
    topk_vals, topk_idx = logits.topk(num_keep, dim = -1)
    masked = torch.full_like(logits, float('-inf'))
    masked.scatter_(-1, topk_idx, topk_vals)
    return masked



def generate_text(model, tokenizer, prompt: torch.Tensor, seq_len: int):
    """Autoregressively sample `seq_len` tokens after consuming `prompt`.

    The recurrent state is primed by stepping the model through the prompt one
    token at a time, then `seq_len` new tokens are drawn with top-k filtering
    (thres=.9) and Gumbel sampling (temperature=.7), decoded, and concatenated.

    NOTE(review): assumes a batch size of 1 (the `[0]` index and `.item()` call)
    and a non-empty prompt (otherwise `logits` stays None) — confirm with callers.
    """
    hidden = None
    logits = None

    # Prime the hidden state on the prompt; keep the logits from the last step.
    for pos in range(prompt.shape[-1]):
        step_input = prompt[:, pos:pos + 1]  # shape (1, 1)
        logits, hidden = model.step(step_input, hidden)

    pieces = []
    for _ in range(seq_len):
        filtered = top_k(logits, thres=.9)
        token = gumbel_sample(filtered, temperature=.7, dim=-1)[0]
        logits, hidden = model.step(token, hidden)
        pieces.append(tokenizer.decode(token.item()))

    return "".join(pieces)



def generate_name():
    """Return a run identifier of the form 'mingru-<4 lowercase hex digits>',
    e.g. 'mingru-3fa2'. Uses `random`, so it is NOT cryptographically unique."""
    suffix = random.randint(0, 0xFFFF)
    return f"mingru-{suffix:04x}"