import argparse
import os

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

from model.gpt_model import GPTModel
from data import utils


def main():
    parser = argparse.ArgumentParser(description="Generate text using a trained OpenGPT model.")
    parser.add_argument("--model", type=str, required=True, help="Path to the model checkpoint (.pt file).")
    parser.add_argument("--config", type=str, required=True, help="Path to the model config file (YAML/JSON).")
    parser.add_argument("--tokenizer", type=str, required=True, help="Path to the tokenizer directory or tokenizer.json file.")
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt text to start generation.")
    parser.add_argument("--max_length", type=int, default=50, help="Maximum number of tokens to generate.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature (higher = more random).")
    parser.add_argument("--top_k", type=int, default=0, help="Top-k sampling (0 for no top-k filtering).")
    parser.add_argument("--greedy", action="store_true", help="Use greedy decoding instead of sampling.")
    args = parser.parse_args()

    # Read the model hyperparameters from the config file.
    config = utils.load_config(args.config)
    model_conf = config.get("model", {})
    vocab_size = model_conf["vocab_size"]
    max_pos = model_conf.get("max_position_embeddings", 512)
    hidden_dim = model_conf.get("embedding_dim", 768)
    n_layers = model_conf.get("n_layers", 12)
    n_heads = model_conf.get("n_heads", 12)
    dropout = model_conf.get("dropout", 0.0)
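
    # The config's "model" section is expected to look roughly like the sketch
    # below. This is inferred from the keys read above, not a file shipped with
    # the repo; vocab_size is illustrative, the other values are the defaults
    # used in this script.
    #
    #   model:
    #     vocab_size: 32000
    #     max_position_embeddings: 512
    #     embedding_dim: 768
    #     n_layers: 12
    #     n_heads: 12
    #     dropout: 0.0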

    # Build the model, load the trained weights, and switch to eval mode.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPTModel(vocab_size=vocab_size, max_position_embeddings=max_pos,
                     n_layers=n_layers, n_heads=n_heads, hidden_dim=hidden_dim,
                     dropout=dropout).to(device)
    utils.load_checkpoint(model, optimizer=None, filepath=args.model, device=device)
    model.eval()

    # Load the tokenizer; accept either a directory or a direct path to tokenizer.json.
    tk_path = args.tokenizer
    if os.path.isdir(tk_path):
        tk_path = os.path.join(tk_path, "tokenizer.json")
    tokenizer = Tokenizer.from_file(tk_path)

    # Encode the prompt, keeping only the last max_pos tokens if it is too long.
    input_ids = tokenizer.encode(args.prompt).ids
    if len(input_ids) > max_pos:
        input_ids = input_ids[-max_pos:]
    generated_ids = input_ids[:]

    # Autoregressive decoding: pick one token per iteration and append it.
    for _ in range(args.max_length):
        # Condition on at most the last max_pos tokens so the position
        # embeddings are never indexed out of range as the sequence grows.
        inp = torch.tensor([generated_ids[-max_pos:]], dtype=torch.long, device=device)
        with torch.no_grad():
            outputs = model(inp)
        logits = outputs[0, -1, :]  # logits for the next-token position
        if args.greedy:
            next_token_id = int(torch.argmax(logits))
        else:
            # Temperature scaling: values below 1 sharpen the distribution,
            # values above 1 flatten it.
            if args.temperature != 1.0:
                logits = logits / args.temperature
            if args.top_k and args.top_k > 0:
                # Restrict sampling to the k most likely tokens, clamping k to
                # the vocabulary size so torch.topk is never asked for more
                # entries than exist.
                k = min(args.top_k, logits.size(-1))
                top_values, top_indices = torch.topk(logits, k=k)
                probabilities = F.softmax(top_values, dim=-1)
                next_token_index = int(torch.multinomial(probabilities, num_samples=1))
                next_token_id = int(top_indices[next_token_index])
            else:
                probabilities = F.softmax(logits, dim=-1)
                next_token_id = int(torch.multinomial(probabilities, num_samples=1))
        generated_ids.append(next_token_id)

    # Decode the full sequence (prompt plus continuation) and print it.
    output_text = tokenizer.decode(generated_ids)
    print(output_text)
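

# Example invocation (a sketch; it assumes this script is saved as generate.py,
# and the checkpoint, config, and tokenizer paths are illustrative, not files
# shipped with the repo):
#
#   python generate.py \
#       --model checkpoints/model_final.pt \
#       --config configs/model_config.yaml \
#       --tokenizer tokenizer/ \
#       --prompt "Once upon a time" \
#       --max_length 100 --temperature 0.8 --top_k 40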


if __name__ == "__main__":
    main()