"""Generate text from a trained checkpoint.

Example invocation (paths are illustrative):
    python scripts/generate.py --checkpoint runs/model.pt --tokenizer-dir runs/tokenizer --prompt "Hello"
"""

from __future__ import annotations

import argparse
| import sys |
from pathlib import Path

import torch
from tokenizers import Tokenizer

# Make the repository's src/ directory importable before the sllm imports below.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))

from sllm.checkpoint import load_checkpoint
from sllm.config import ModelConfig, load_json
from sllm.model import SLLMForCausalLM
from sllm.utils import get_device, setup_logger


def build_parser() -> argparse.ArgumentParser:
| parser = argparse.ArgumentParser(description="Generate text from a trained checkpoint.") |
| parser.add_argument("--checkpoint", required=True, help="Path to model checkpoint.") |
| parser.add_argument("--tokenizer-dir", required=True, help="Directory with tokenizer.json.") |
| parser.add_argument("--prompt", required=True, help="Prompt text.") |
| parser.add_argument("--max-new-tokens", type=int, default=128) |
| parser.add_argument("--temperature", type=float, default=0.8) |
| parser.add_argument("--top-k", type=int, default=50) |
| parser.add_argument("--model-config", required=False, help="Optional path to model config JSON.") |
    return parser


def main() -> None:
| args = build_parser().parse_args() |
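    # Logging goes to a file under outputs/generate/ via the shared setup_logger helper.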
| logger, log_path = setup_logger("sllm.generate", Path("outputs/generate"), "generate") |
| logger.info("Generation started") |
| logger.info("Log file: %s", log_path) |
| logger.info( |
| "Arguments | checkpoint=%s tokenizer_dir=%s max_new_tokens=%s temperature=%s top_k=%s model_config=%s", |
| args.checkpoint, |
| args.tokenizer_dir, |
| args.max_new_tokens, |
| args.temperature, |
| args.top_k, |
| args.model_config, |
| ) |
| device = get_device() |
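    # The tokenizer directory provides tokenizer.json plus a meta file with special-token ids.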
| tokenizer = Tokenizer.from_file(str(Path(args.tokenizer_dir) / "tokenizer.json")) |
| tokenizer_meta = load_json(Path(args.tokenizer_dir) / "tokenizer_meta.json") |
    specials = tokenizer_meta["special_tokens"]

    # Load the checkpoint; an explicit --model-config overrides the config stored in it.
    payload = load_checkpoint(args.checkpoint, map_location=device)
| if args.model_config: |
| model_config = ModelConfig.from_dict(load_json(args.model_config)) |
| else: |
        model_config = ModelConfig.from_dict(payload["model_config"])

    model = SLLMForCausalLM(model_config).to(device)
| model.load_state_dict(payload["model"]) |
    model.eval()

    # Prepend the BOS id manually; encode() adds no special tokens here.
    prompt_ids = [int(specials["bos_token_id"])] + tokenizer.encode(
| args.prompt, |
| add_special_tokens=False, |
| ).ids |
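    # Batch of one sequence: shape (1, prompt_len).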
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Sample without tracking gradients; eos_token_id lets generation stop early.
    with torch.no_grad():
| output_ids = model.generate( |
| input_ids=input_ids, |
| max_new_tokens=args.max_new_tokens, |
| temperature=args.temperature, |
| top_k=args.top_k, |
| eos_token_id=int(specials["eos_token_id"]), |
| ) |
|
|
| decoded = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=False) |
| logger.info("Generation finished | prompt_tokens=%s output_tokens=%s", len(prompt_ids), output_ids.shape[1]) |
    print(decoded)


if __name__ == "__main__":
    main()