# OpenGPT / generate.py
# Uploaded by VolodymyrPugachov ("Upload 17 files", commit 6810eb1, verified).
import argparse
import os
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from model.gpt_model import GPTModel
from data import utils
def _parse_args():
    """Parse and validate the command-line options for text generation.

    Returns:
        argparse.Namespace with the checkpoint/config/tokenizer paths, the
        prompt, and the decoding options.

    Exits (via ``parser.error``) when ``--temperature`` is negative.
    """
    parser = argparse.ArgumentParser(description="Generate text using a trained OpenGPT model.")
    parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint (.pt file).")
    parser.add_argument("--config", type=str, required=True, help="Path to the model config file (YAML/JSON).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to the tokenizer directory or tokenizer.json file.")
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt text to start generation.")
    parser.add_argument("--max_length", type=int, default=50, help="Maximum number of tokens to generate.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature (higher = more random).")
    parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling (0 for no top-k filtering).")
    parser.add_argument("--greedy", action="store_true", help="Use greedy decoding instead of sampling.")
    args = parser.parse_args()
    # A negative temperature would invert the distribution (favouring the
    # least likely tokens); reject it rather than silently emit garbage.
    if args.temperature < 0:
        parser.error("--temperature must be >= 0 (0 means greedy decoding).")
    return args


def _load_tokenizer(path):
    """Load a ``tokenizers.Tokenizer`` from a tokenizer.json file or a directory holding one."""
    if os.path.isdir(path):
        path = os.path.join(path, "tokenizer.json")
    return Tokenizer.from_file(path)


def _build_model(model_conf, max_pos, checkpoint_path, device):
    """Instantiate GPTModel from the ``model`` config section and load its weights.

    Args:
        model_conf: The ``model`` mapping read from the config file.
            ``vocab_size`` is required; the other hyper-parameters fall back
            to the defaults used below.
        max_pos: Context length passed as ``max_position_embeddings``.
        checkpoint_path: Checkpoint file handed to ``utils.load_checkpoint``.
        device: torch.device the model is moved to.

    Returns:
        The model, switched to eval mode.

    Raises:
        KeyError: if ``vocab_size`` is missing from ``model_conf``.
    """
    model = GPTModel(
        vocab_size=model_conf["vocab_size"],
        max_position_embeddings=max_pos,
        n_layers=model_conf.get("n_layers", 12),
        n_heads=model_conf.get("n_heads", 12),
        hidden_dim=model_conf.get("embedding_dim", 768),
        dropout=model_conf.get("dropout", 0.0),
    ).to(device)
    utils.load_checkpoint(model, optimizer=None, filepath=checkpoint_path, device=device)
    model.eval()
    return model


def _sample_next_token(logits, *, greedy, temperature, top_k):
    """Choose the next token id from a 1-D logits vector.

    Args:
        logits: Scores for every vocabulary entry at the last position.
        greedy: If true, always pick the arg-max token.
        temperature: Divisor applied to the logits before sampling; 0 is
            treated as greedy decoding (the limit of temperature -> 0).
        top_k: If > 0, sample only among the ``top_k`` highest-scoring
            tokens; values <= 0 disable the filter.

    Returns:
        The chosen token id as a Python int.
    """
    # Dividing by a zero temperature would produce inf/NaN logits and make
    # torch.multinomial fail, so fall back to arg-max.
    if greedy or temperature == 0:
        return int(torch.argmax(logits))
    if temperature != 1.0:
        logits = logits / temperature
    if top_k > 0:
        # torch.topk raises if k exceeds the vocabulary size, so clamp it.
        k = min(top_k, logits.size(-1))
        top_values, top_indices = torch.topk(logits, k=k)
        probabilities = F.softmax(top_values, dim=-1)
        choice = int(torch.multinomial(probabilities, num_samples=1))
        return int(top_indices[choice])
    probabilities = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probabilities, num_samples=1))


def _generate(model, prompt_ids, *, max_new_tokens, context_len, device, greedy, temperature, top_k):
    """Autoregressively extend ``prompt_ids`` by ``max_new_tokens`` tokens.

    At every step only the most recent ``context_len`` tokens are fed to the
    model. The growing sequence therefore never exceeds the model's context
    length, even when the prompt plus generated text is longer than it.

    Returns:
        A new list: the prompt ids followed by the generated ids.
    """
    generated_ids = list(prompt_ids)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            window = generated_ids[-context_len:]
            inp = torch.tensor([window], dtype=torch.long, device=device)
            outputs = model(inp)
            # Logits for the last token of the single sequence in the batch;
            # assumes the model returns (batch, seq, vocab) logits — TODO confirm.
            logits = outputs[0, -1, :]
            generated_ids.append(
                _sample_next_token(logits, greedy=greedy, temperature=temperature, top_k=top_k)
            )
            # NOTE: no EOS handling; generation always runs max_new_tokens steps.
    return generated_ids


def main():
    """Command-line entry point: generate a continuation of --prompt and print it.

    Reads the model hyper-parameters from --config, loads the checkpoint from
    --model and the tokenizer from --tokenizer. It then decodes --max_length
    new tokens (greedy, or temperature / top-k sampling) and prints the prompt
    plus the generated text to stdout.
    """
    args = _parse_args()

    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    max_pos = model_conf.get("max_position_embeddings", 512)

    # Load and encode with the tokenizer first, so a bad tokenizer path or an
    # empty prompt fails before the (potentially expensive) checkpoint load.
    tokenizer = _load_tokenizer(args.tokenizer)
    prompt_ids = tokenizer.encode(args.prompt).ids
    if not prompt_ids:
        raise SystemExit("error: --prompt encodes to zero tokens; nothing to condition generation on.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _build_model(model_conf, max_pos, args.model, device)

    # The full prompt is kept, and echoed in the output. _generate windows
    # the model input to the last max_pos tokens, so a long prompt does not
    # need truncating here.
    generated_ids = _generate(
        model,
        prompt_ids,
        max_new_tokens=args.max_length,
        context_len=max_pos,
        device=device,
        greedy=args.greedy,
        temperature=args.temperature,
        top_k=args.top_k,
    )
    print(tokenizer.decode(generated_ids))


if __name__ == "__main__":
    main()