# QED-75M_artifacts/scripts/generate.py
# levossadtchi — added via upload-large-folder tool (commit 9847679, verified).
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import torch
from tokenizers import Tokenizer
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))
from sllm.checkpoint import load_checkpoint
from sllm.config import ModelConfig, load_json
from sllm.model import SLLMForCausalLM
from sllm.utils import get_device, setup_logger
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for generating text from a trained checkpoint.

    Required flags identify the checkpoint, tokenizer directory, and prompt;
    sampling knobs (max tokens, temperature, top-k) carry sensible defaults.
    """
    parser = argparse.ArgumentParser(description="Generate text from a trained checkpoint.")
    # Mandatory inputs: where the model lives, where the tokenizer lives, what to generate from.
    for flag, help_text in (
        ("--checkpoint", "Path to model checkpoint."),
        ("--tokenizer-dir", "Directory with tokenizer.json."),
        ("--prompt", "Prompt text."),
    ):
        parser.add_argument(flag, required=True, help=help_text)
    # Sampling hyperparameters with defaults.
    parser.add_argument("--max-new-tokens", type=int, default=128)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    # Optional override for the model config stored inside the checkpoint.
    parser.add_argument("--model-config", required=False, help="Optional path to model config JSON.")
    return parser
def main() -> None:
    """CLI entry point: load a checkpoint and print sampled text for a prompt.

    Pipeline: parse args -> set up logging -> load tokenizer and checkpoint ->
    rebuild the model -> sample with `model.generate` -> decode and print.
    """
    args = build_parser().parse_args()
    logger, log_path = setup_logger("sllm.generate", Path("outputs/generate"), "generate")
    logger.info("Generation started")
    logger.info("Log file: %s", log_path)
    logger.info(
        "Arguments | checkpoint=%s tokenizer_dir=%s max_new_tokens=%s temperature=%s top_k=%s model_config=%s",
        args.checkpoint,
        args.tokenizer_dir,
        args.max_new_tokens,
        args.temperature,
        args.top_k,
        args.model_config,
    )
    device = get_device()

    # Tokenizer file and its special-token metadata sit side by side in tokenizer_dir.
    tok_dir = Path(args.tokenizer_dir)
    tokenizer = Tokenizer.from_file(str(tok_dir / "tokenizer.json"))
    specials = load_json(tok_dir / "tokenizer_meta.json")["special_tokens"]

    # An explicit --model-config wins; otherwise use the config embedded in the checkpoint.
    payload = load_checkpoint(args.checkpoint, map_location=device)
    config_dict = load_json(args.model_config) if args.model_config else payload["model_config"]
    model_config = ModelConfig.from_dict(config_dict)

    model = SLLMForCausalLM(model_config).to(device)
    model.load_state_dict(payload["model"])
    model.eval()

    # BOS is prepended by hand; the tokenizer itself is told not to add specials.
    encoding = tokenizer.encode(
        args.prompt,
        add_special_tokens=False,
    )
    prompt_ids = [int(specials["bos_token_id"]), *encoding.ids]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            eos_token_id=int(specials["eos_token_id"]),
        )

    # Keep special tokens visible in the output so the user sees BOS/EOS boundaries.
    decoded = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=False)
    logger.info("Generation finished | prompt_tokens=%s output_tokens=%s", len(prompt_ids), output_ids.shape[1])
    print(decoded)


if __name__ == "__main__":
    main()