"""Generate a few text samples from a saved checkpoint.""" from __future__ import annotations import argparse import json import sys from pathlib import Path import torch REPO_ROOT = Path(__file__).resolve().parents[2] SRC_ROOT = REPO_ROOT / "src" if str(SRC_ROOT) not in sys.path: sys.path.insert(0, str(SRC_ROOT)) from taoTrain.checkpointing.checkpoint import CheckpointManager from taoTrain.config import ModelConfig from taoTrain.inference.inferencer import Inferencer from taoTrain.models import get_model def clear_kernel_caches(model) -> None: for module in model.modules(): clear = getattr(module, "clear_kernel_cache", None) if callable(clear): clear() def generate_once( model, tokenizer, prompt: str, *, device: torch.device, max_new_tokens: int, temperature: float, top_p: float, dtype: torch.dtype, ) -> str: input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) generated = [] eos_token_id = getattr(tokenizer, "eos_token_id", None) model.eval() device_type = "cuda" if device.type == "cuda" else "cpu" autocast_enabled = device.type == "cuda" and dtype in {torch.float16, torch.bfloat16} with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=dtype, enabled=autocast_enabled): for _ in range(max_new_tokens): clear_kernel_caches(model) outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=None) logits = outputs["logits"][:, -1, :] / max(temperature, 1e-6) if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) sorted_probs = torch.softmax(sorted_logits, dim=-1) cumulative = torch.cumsum(sorted_probs, dim=-1) remove = cumulative > top_p remove[..., 1:] = remove[..., :-1].clone() remove[..., 0] = False indices_to_remove = sorted_indices[remove] logits[0, indices_to_remove] = float("-inf") probs = torch.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) token_id = int(next_token.item()) if eos_token_id is not None and token_id == eos_token_id: break generated.append(token_id) input_ids = torch.cat([input_ids, next_token], dim=-1) clear_kernel_caches(model) return tokenizer.decode(generated, skip_special_tokens=True) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", required=True) parser.add_argument("--tokenizer-path", required=True) parser.add_argument("--output", required=True) parser.add_argument("--prompt", action="append", default=[]) parser.add_argument("--max-new-tokens", type=int, default=80) parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--top-p", type=float, default=0.9) parser.add_argument("--device", default="cuda") parser.add_argument("--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16") args = parser.parse_args() prompts = args.prompt or [ "The purpose of artificial intelligence is", "In a small village,", "Hello, who are you?", ] device = torch.device(args.device if args.device == "cpu" or torch.cuda.is_available() else "cpu") dtype = { "float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16, }[args.dtype] tokenizer = Inferencer._load_tokenizer(args.tokenizer_path) checkpoint_path = Path(args.checkpoint) checkpoint = CheckpointManager(checkpoint_path.parent).load(checkpoint_path, device=device) model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {})) model = get_model(model_config, device=device) model.load_state_dict(checkpoint["model_state"], strict=False) samples = [] for prompt in prompts: text = generate_once( model, tokenizer, prompt, device=device, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_p=args.top_p, dtype=dtype, ) samples.append({"prompt": prompt, "completion": text}) result = { "checkpoint": args.checkpoint, "tokenizer_path": args.tokenizer_path, "device": str(device), "dtype": str(dtype), "max_new_tokens": args.max_new_tokens, "temperature": args.temperature, "top_p": args.top_p, "samples": samples, } output = Path(args.output) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(json.dumps(result, indent=2, ensure_ascii=False), encoding="utf-8") print(json.dumps(result, indent=2, ensure_ascii=False)) if __name__ == "__main__": main()