import torch
import math
import random


def log(t, eps=1e-20):
    """Numerically safe log: clamp ``t`` to ``eps`` before taking the log."""
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t):
    """Return standard Gumbel noise with the same shape/device/dtype as ``t``."""
    noise = torch.zeros_like(t).uniform_(0, 1)
    # -log(-log(U)) with U ~ Uniform(0, 1); the clamped log keeps U == 0 finite
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=-1, keepdim=True):
    """Sample indices from logits ``t`` via the Gumbel-max trick.

    Dividing by ``temperature`` (floored at 1e-10 to avoid division by zero)
    and taking the argmax of logits-plus-Gumbel-noise is equivalent to
    sampling from the softmax distribution at that temperature.
    """
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim, keepdim=keepdim)


def top_k(logits, thres=0.9):
    """Keep the top ``ceil((1 - thres) * vocab)`` logits; set the rest to -inf.

    Returns a new tensor; ``logits`` itself is not modified.
    """
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs


def generate_text(model, tokenizer, prompt: torch.Tensor, seq_len: int):
    """Autoregressively generate ``seq_len`` tokens after consuming ``prompt``.

    The prompt is fed one token at a time through ``model.step`` to build up
    the recurrent hidden state, then ``seq_len`` tokens are sampled with
    top-k filtering (thres=0.9) and Gumbel sampling (temperature=0.7).

    Args:
        model: exposes ``step(token, h_states) -> (logits, h_states)``.
        tokenizer: exposes ``decode(int) -> str``.
        prompt: token-id tensor; the last dim is the sequence
            (assumed batch size 1 — see the ``(1, 1)`` slice below).
        seq_len: number of new tokens to generate.

    Returns:
        The decoded generated text (prompt text not included).

    Raises:
        ValueError: if ``prompt`` contains no tokens, since there would be
            no logits to sample the first generated token from.
    """
    prompt_seq_len = prompt.shape[-1]
    if prompt_seq_len == 0:
        # Previously this fell through to top_k(None) and crashed with an
        # opaque AttributeError; fail fast with a clear message instead.
        raise ValueError('prompt must contain at least one token')

    h_states = None
    logits = None
    text = ""

    # Warm up the recurrent state on the prompt, one token at a time.
    for i in range(prompt_seq_len):
        tok = prompt[:, i:i + 1]  # (1, 1)
        logits, h_states = model.step(tok, h_states)

    for _ in range(seq_len):
        logits = top_k(logits, thres=.9)
        # NOTE(review): the trailing [0] drops the batch dim, so model.step
        # receives a (1,)-shaped token here versus (1, 1) during the prompt
        # pass above — confirm the model accepts both shapes.
        token = gumbel_sample(logits, temperature=.7, dim=-1)[0]
        logits, h_states = model.step(token, h_states)
        text += tokenizer.decode(token.item())

    return text


def generate_name():
    """Return a run name like ``mingru-1a2b`` (random 4-hex-digit suffix)."""
    random_number = random.randint(0, 0xFFFF)
    return f"mingru-{random_number:04x}"