"""
Text generation (inference) script with temperature + top-p / top-k sampling.

Usage:
    python eval/generate.py \
        --checkpoint checkpoints/checkpoint-0100000 \
        --prompt "Once upon a time" \
        --max_new_tokens 200 \
        --temperature 0.8 \
        --top_p 0.9 \
        --top_k 50 \
        --device cuda:0
"""
|
|
from __future__ import annotations

import argparse
from pathlib import Path
from typing import Generator

import torch
import torch.nn.functional as F
from model.transformer import LLM
from tokenizers import Tokenizer
|
|
|
|
def top_p_filtering(
    logits: torch.Tensor,
    top_p: float = 0.9,
    top_k: int = 0,
    filter_value: float = float("-inf"),
) -> torch.Tensor:
    """
    Apply top-k and / or top-p (nucleus) filtering to a logits tensor.

    Args:
        logits: 1-D or 2-D tensor of raw (un-normalised) logits.
            Shape: [vocab_size] or [batch, vocab_size].
        top_p: Keep the smallest set of tokens whose cumulative
            probability is >= top_p (1.0 = disabled).
        top_k: Keep only the top-k tokens (0 = disabled).
        filter_value: Value assigned to filtered positions (-inf by default).

    Returns:
        Filtered logits with the same shape as the input.
    """
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
        squeeze_output = True
    else:
        squeeze_output = False

    # Top-k: mask every logit below the k-th largest value.
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_values = torch.topk(logits, k, dim=-1).values[:, -1, None]
        logits = logits.masked_fill(logits < kth_values, filter_value)

    # Top-p (nucleus): keep the smallest prefix of the probability-sorted
    # distribution whose cumulative mass reaches top_p.
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Remove a token only when the cumulative mass *before* it already
        # reaches top_p, so the token that crosses the threshold is kept.
        sorted_indices_to_remove = (cumulative_probs - sorted_probs) >= top_p
        sorted_logits = sorted_logits.masked_fill(
            sorted_indices_to_remove, filter_value
        )
        # Scatter the filtered logits back to their original vocabulary positions.
        logits = torch.zeros_like(logits).scatter_(
            -1, sorted_indices, sorted_logits
        )

    if squeeze_output:
        logits = logits.squeeze(0)

    return logits
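
# Usage sketch for the filter above (hypothetical values, for illustration only):
#
#     logits = torch.tensor([3.0, 2.0, 1.0, -1.0])
#     filtered = top_p_filtering(logits, top_p=0.9, top_k=3)
#     probs = F.softmax(filtered, dim=-1)
#
# Candidates outside the top-3, or outside the 0.9 probability nucleus, are set
# to -inf, so the softmax renormalises over the surviving tokens only.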
|
|
|
|
@torch.inference_mode()
def generate(
    model: torch.nn.Module,
    tokenizer: Tokenizer,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_p: float = 0.9,
    top_k: int = 50,
    device: str = "cuda:0",
) -> Generator[str, None, None]:
    """
    Auto-regressive token generation with streaming output.

    Yields decoded string fragments (one token at a time) so callers can
    stream output to stdout without waiting for the full sequence.

    Args:
        model: A causal LM whose forward pass returns logits
            (last dim = vocab_size).
        tokenizer: Matching tokenizer; must expose encode / decode.
        prompt: Text prompt to condition on.
        max_new_tokens: Maximum number of new tokens to generate.
        temperature: Softmax temperature (1.0 = neutral, <1 = sharper).
        top_p: Nucleus sampling probability threshold.
        top_k: Top-k token candidates (0 = disabled).
        device: Torch device string.

    Yields:
        Decoded string for each newly generated token.
    """
    model.eval()

    # Encode the prompt and look up the end-of-sequence token id (if any).
    input_ids = torch.tensor(
        [tokenizer.encode(prompt).ids], dtype=torch.long, device=device
    )
    eos_token_id: int | None = tokenizer.token_to_id("</s>")

    generated_ids = input_ids

    for _ in range(max_new_tokens):
        # Forward pass over the whole sequence so far; keep the last position's
        # logits. The sequence is not truncated here, so prompt + new tokens
        # should stay within the model's context window.
        logits_all, _ = model(generated_ids)
        logits: torch.Tensor = logits_all[:, -1, :]

        # Temperature scaling (guarded against division by zero).
        if temperature != 1.0:
            logits = logits / max(temperature, 1e-8)

        # Restrict the candidate set with top-k / top-p filtering.
        logits = top_p_filtering(logits, top_p=top_p, top_k=top_k)

        # Sample the next token from the filtered distribution.
        probs = F.softmax(logits, dim=-1)
        next_token_id = torch.multinomial(probs, num_samples=1)

        generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

        # Decode and stream the newly generated token.
        token_str: str = tokenizer.decode([next_token_id.item()])
        yield token_str

        # Stop early once the end-of-sequence token is produced.
        if eos_token_id is not None and next_token_id.item() == eos_token_id:
            break
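
# Programmatic usage sketch (assumes `model` and `tokenizer` are already loaded,
# e.g. with load_model_and_tokenizer below, on a matching device):
#
#     for piece in generate(model, tokenizer, "Once upon a time", max_new_tokens=50):
#         print(piece, end="", flush=True)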
|
|
|
|
def load_model_and_tokenizer(
    checkpoint_dir: str, device: str
) -> tuple[torch.nn.Module, Tokenizer]:
    """
    Load a model and tokenizer from a checkpoint directory.

    Expects:
        - <checkpoint_dir>/model.pt       — model weights
        - <checkpoint_dir>/config.yaml    — LMConfig
        - <checkpoint_dir>/tokenizer.json — HuggingFace tokenizers format
    """
    ckpt_path = Path(checkpoint_dir)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")

    print(f"Loading model from: {ckpt_path}")
    model = LLM.from_pretrained(str(ckpt_path)).to(device=device, dtype=torch.float16)
    model.eval()

    tokenizer_path = ckpt_path / "tokenizer.json"
    if not tokenizer_path.exists():
        # tokenizer.json not found in the checkpoint; fall back to a default path.
        tokenizer_path = Path("tokenizer/korean_sp/tokenizer.json")
    print(f"Loading tokenizer from: {tokenizer_path}")
    tokenizer = Tokenizer.from_file(str(tokenizer_path))

    return model, tokenizer
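
# Expected checkpoint layout (mirrors the docstring above; the directory name is
# taken from the module-level usage example):
#
#     checkpoints/checkpoint-0100000/
#         model.pt
#         config.yaml
#         tokenizer.json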
|
|
|
|
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Generate text from a trained LLM checkpoint."
    )
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Path to the checkpoint directory.",
    )
    parser.add_argument(
        "--prompt",
        required=True,
        help="Input prompt text.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximum number of new tokens to generate (default: 200).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature (default: 0.8).",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Top-p nucleus sampling threshold (default: 0.9).",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=50,
        help="Top-k token candidates; 0 disables top-k (default: 50).",
    )
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="Torch device to run inference on (default: cuda:0).",
    )
    return parser.parse_args()
|
|
|
|
def main() -> None:
    args = parse_args()

    model, tokenizer = load_model_and_tokenizer(args.checkpoint, args.device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {num_params / 1e6:.1f}M")
    print(f"\nPrompt: {args.prompt!r}")
    print("-" * 60)
    print(args.prompt, end="", flush=True)

    generated_tokens = 0
    for token_str in generate(
        model=model,
        tokenizer=tokenizer,
        prompt=args.prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        device=args.device,
    ):
        print(token_str, end="", flush=True)
        generated_tokens += 1

    print()
    print("-" * 60)
    print(f"Generated {generated_tokens} token(s).")
|
|
|
|
if __name__ == "__main__":
    main()
|
|