File size: 3,531 Bytes
ca7da53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
generate.py
===========
Interactive text generation with the trained MiniLM model.

Type a prompt and the model will complete it.
Type 'quit' or press Ctrl+C to exit.

Author  : André Costa
License : MIT

Usage:
    python3 generate.py
    python3 generate.py --max-tokens 100
    python3 generate.py --temperature 0.9 --top-k 50
"""

import argparse
import torch
from transformer import MiniLM, ModelConfig
from bpe_tokenizer import BPETokenizer


def load_model(checkpoint_path: str, tokenizer_path: str):
    """Load the trained model and tokenizer.

    Args:
        checkpoint_path: Path to the ``.pt`` training checkpoint. The file
            must contain ``"model_config"`` and ``"model_state"`` entries.
        tokenizer_path: Path the BPE tokenizer was saved to.

    Returns:
        ``(model, tokenizer, device)`` — the model is in eval mode and
        already moved to CUDA when available, otherwise CPU.
    """

    print("Loading tokenizer...")
    tokenizer = BPETokenizer.load(tokenizer_path)

    print("Loading model...")
    # weights_only=True restricts unpickling to tensors/containers (safe load).
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    cfg_dict = ckpt["model_config"]
    # 'd_head' is derived inside ModelConfig; drop it if an older checkpoint
    # serialized it, so the constructor does not receive an unexpected kwarg.
    cfg_dict.pop("d_head", None)
    config = ModelConfig(**cfg_dict)

    model = MiniLM(config)

    state_dict = ckpt["model_state"]
    # Checkpoints saved from a torch.compile()-wrapped model prefix every key
    # with "_orig_mod.". Use removeprefix (not str.replace) so only a leading
    # prefix is stripped — replace() would also mangle the substring if it
    # ever appeared mid-key.
    if any(k.startswith("_orig_mod.") for k in state_dict):
        state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model  = model.to(device)

    print(f"Model ready — {config.n_params / 1e6:.1f}M parameters | device: {device}")
    print(f"Vocab: {config.vocab_size} tokens | Context: {config.seq_len} tokens\n")

    return model, tokenizer, device


def generate(
    model,
    tokenizer,
    device,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> str:
    """Complete *prompt* with the model and return the decoded text.

    Args:
        model: Model exposing a ``generate(input_ids, ...)`` method.
        tokenizer: Tokenizer exposing ``encode``/``decode``.
        device: Torch device the input batch should live on.
        prompt: Seed text to continue from.
        max_new_tokens: Number of tokens to sample after the prompt.
        temperature: Softmax temperature for sampling.
        top_k: Top-k sampling cutoff.
        top_p: Nucleus (top-p) sampling cutoff.

    Returns:
        The full decoded sequence (prompt plus continuation).
    """
    prompt_ids = tokenizer.encode(prompt)
    batch = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Inference only — skip gradient bookkeeping.
    with torch.no_grad():
        generated = model.generate(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    return tokenizer.decode(generated[0].tolist())


def main():
    """Parse CLI flags, load the model, and run the interactive prompt loop."""
    parser = argparse.ArgumentParser(description="MiniLM — Interactive text generation")
    parser.add_argument("--checkpoint", type=str, default="./checkpoints/best_model.pt")
    parser.add_argument("--tokenizer", type=str, default="./tokenizer")
    parser.add_argument("--max-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.9)
    args = parser.parse_args()

    model, tokenizer, device = load_model(args.checkpoint, args.tokenizer)

    rule = "=" * 55
    for banner_line in (
        rule,
        "  MiniLM — Text Generation",
        "  Type a prompt and press Enter.",
        "  Type 'quit' to exit.",
        rule,
        "",
    ):
        print(banner_line)

    while True:
        # Ctrl+C / Ctrl+D both end the session cleanly.
        try:
            prompt = input("Prompt: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye!")
            return

        if not prompt:
            continue

        if prompt.lower() in ("quit", "exit", "q"):
            print("Goodbye!")
            return

        completion = generate(
            model, tokenizer, device,
            prompt=prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )

        print(f"\n{completion}\n")
        print("-" * 55)


if __name__ == "__main__":
    main()