import torch
import torch.nn.functional as F


def concat(prev, new):
    if prev and prev[-1].isalnum() and new and new[0].isalnum():
        return prev + " " + new
    return prev + new
class GPTInfer:
    def __init__(self, model, token_encoder, device):
        self.model = model
        self.token_encoder = token_encoder
        self.device = device
        self.device_type = 'cuda' if device.startswith('cuda') else 'cpu'

    def get_token_length(self, text):
        return len(self.token_encoder.encode(text, allowed_special={"<|endoftext|>"}))
    def apply_frequency_penalty_and_blocking(
        self,
        logits,
        gen_tokens,
        frequency_penalty=0.5,
        no_repeat_ngram_size=3,
    ):
        logits = logits.clone().float()
        # Frequency penalty: subtract a penalty proportional to how many times
        # each token already appears in the generated sequence.
        if frequency_penalty and frequency_penalty > 0.0:
            counts = {}
            for t in gen_tokens[0].tolist():
                counts[t] = counts.get(t, 0) + 1
            if counts:
                vocab_size = logits.shape[-1]
                penalty = torch.zeros(vocab_size, dtype=logits.dtype, device=logits.device)
                for tok, c in counts.items():
                    if 0 <= tok < vocab_size:
                        penalty[tok] = float(c) * float(frequency_penalty)
                logits = logits - penalty.unsqueeze(0)
        # No-repeat-ngram blocking: if emitting a token would complete an n-gram
        # that already occurs in the sequence, push its logit to -1e9.
        if no_repeat_ngram_size and no_repeat_ngram_size > 0:
            n = no_repeat_ngram_size
            cur = gen_tokens[0].tolist()
            if len(cur) >= n - 1:
                last_prefix = tuple(cur[-(n - 1):]) if n > 1 else tuple()
                for i in range(len(cur) - (n - 1)):
                    if tuple(cur[i:i + (n - 1)]) == last_prefix:
                        banned_token = cur[i + (n - 1)]
                        if 0 <= banned_token < logits.shape[-1]:
                            logits[0, banned_token] = -1e9
        return logits
    def sample_next_token(
        self,
        logits,
        gen_tokens,
        seed_rng,
        temperature=0.8,
        top_k=None,
        top_p=0.9,
        repetition_penalty=1.2,
        frequency_penalty=0.5,
        no_repeat_ngram_size=3,
        recent_tokens_window=200,
    ):
        logits = logits.clone().float()
        recent = gen_tokens[0, -recent_tokens_window:].tolist()
        # Repetition penalty (CTRL-style): positive logits are divided by the
        # penalty and negative logits multiplied, so recently used tokens are
        # always made less likely regardless of sign.
        if repetition_penalty is not None and repetition_penalty != 1.0:
            for t in set(recent):
                if 0 <= t < logits.shape[-1]:
                    if logits[0, t] > 0:
                        logits[0, t] /= float(repetition_penalty)
                    else:
                        logits[0, t] *= float(repetition_penalty)
        logits = self.apply_frequency_penalty_and_blocking(
            logits,
            gen_tokens,
            frequency_penalty=frequency_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        if temperature is not None and temperature != 1.0:
            logits = logits / float(temperature)
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Top-k: keep only the k highest-probability tokens.
        if top_k is not None:
            k = min(int(top_k), sorted_logits.shape[-1])
            sorted_logits = sorted_logits[:, :k]
            sorted_idx = sorted_idx[:, :k]
            sorted_probs = sorted_probs[:, :k]
        # Top-p (nucleus): keep the smallest prefix whose cumulative probability
        # reaches top_p, always retaining at least one token.
        if top_p is not None and 0.0 < top_p < 1.0:
            cum_probs = torch.cumsum(sorted_probs, dim=-1)
            keep_count = max(1, int((cum_probs < top_p).sum(dim=-1).item()) + 1)
            keep_count = min(keep_count, sorted_probs.shape[-1])
            sorted_probs = sorted_probs[:, :keep_count]
            sorted_idx = sorted_idx[:, :keep_count]
        # Renormalise the surviving probabilities and sample one token.
        sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-12)
        next_index_in_sorted = torch.multinomial(sorted_probs, 1, generator=seed_rng)
        next_tok = sorted_idx.gather(-1, next_index_in_sorted)
        return int(next_tok.item())
    def generate(
        self,
        prompt,
        max_new_tokens=50,
        seed=42,
        longer_story=True,
        temperature=0.8,
        top_k=None,
        top_p=0.9,
        repetition_penalty=1.2,
        frequency_penalty=0.5,
        no_repeat_ngram_size=3,
        context_window=None,
        stream=True,
    ):
        self.model.eval()
        tokens = self.token_encoder.encode(prompt)
        # Trim the prompt to the requested context window, keeping its tail.
        if context_window is not None and len(tokens) > context_window:
            tokens = tokens[-context_window:]
        tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(self.device)
        gen_tokens = tokens.clone()
        # A seeded generator makes sampling reproducible across runs.
        if seed is not None:
            sample_rng = torch.Generator(device=self.device).manual_seed(seed)
        else:
            sample_rng = torch.Generator(device=self.device)
        eos_id = self.token_encoder.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]
        context_len = self.model.config.context_length
        new_tokens_generated = 0
        HARD_MAX_TOTAL = context_len + max_new_tokens + 10
        while new_tokens_generated < max_new_tokens and gen_tokens.shape[1] < HARD_MAX_TOTAL:
            # Condition only on the most recent context_len tokens.
            idx_cond = gen_tokens[:, -context_len:] if gen_tokens.shape[1] > context_len else gen_tokens
            with torch.no_grad():
                try:
                    # Prefer bfloat16 autocast; fall back to full precision if the
                    # device or build does not support it.
                    with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                        logits, _ = self.model(idx_cond)
                except Exception:
                    logits, _ = self.model(idx_cond)
            next_logits = logits[:, -1, :]
            if longer_story and new_tokens_generated < 5:
                # Discourage <|endoftext|> for the first few tokens; scale the
                # logit toward improbability whether it is positive or negative.
                eos_logit = next_logits[0, eos_id]
                next_logits[0, eos_id] = eos_logit / 4.0 if eos_logit > 0 else eos_logit * 4.0
            next_token_id = self.sample_next_token(
                logits=next_logits,
                gen_tokens=gen_tokens,
                seed_rng=sample_rng,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                frequency_penalty=frequency_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                recent_tokens_window=200,
            )
            if next_token_id == eos_id:
                break
            next_tok_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=self.device)
            gen_tokens = torch.cat([gen_tokens, next_tok_tensor], dim=1)
            new_tokens_generated += 1
            if stream:
                yield self.token_encoder.decode([next_token_id], errors='ignore')
        if not stream:
            yield self.token_encoder.decode(gen_tokens[0, :].tolist(), errors='ignore')
    def print_stream(
        self,
        prompt,
        max_new_tokens=200,
        seed=42,
        longer_story=True,
        temperature=0.8,
        top_k=None,
        top_p=0.9,
        repetition_penalty=1.2,
        frequency_penalty=0.6,
        no_repeat_ngram_size=3,
        context_window=512,
    ):
        text = prompt
        last_piece = ""
        print(prompt, end="", flush=True)
        for piece in self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            seed=seed,
            longer_story=longer_story,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            context_window=context_window,
        ):
            # Skip a piece identical to the one just emitted, a crude guard
            # against stuttering output.
            if piece == last_piece:
                continue
            last_piece = piece
            text = concat(text, piece)
            print(piece, end="", flush=True)
        return text
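

# Minimal usage sketch, not part of the class above. It assumes a tiktoken
# encoding and a GPT-style model whose forward returns (logits, loss) and whose
# config exposes context_length; the GPT/GPTConfig names and the checkpoint
# path below are placeholders, adjust them to the actual project.
if __name__ == "__main__":
    import tiktoken

    device = "cuda" if torch.cuda.is_available() else "cpu"
    token_encoder = tiktoken.get_encoding("gpt2")

    # Hypothetical model setup; replace with the project's real model class
    # and checkpoint loading code.
    # model = GPT(GPTConfig(context_length=512))
    # model.load_state_dict(torch.load("model.pt", map_location=device))
    # model.to(device)

    # infer = GPTInfer(model, token_encoder, device)
    # story = infer.print_stream("Once upon a time", max_new_tokens=200, seed=42)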