File size: 3,527 Bytes
e24d6f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import time
from tokenizers import Tokenizer
from miniGPT import MiniGPT

# --- 1. Load tokenizer and model ---
# Training artifacts. Both paths are relative, so they resolve against the
# current working directory.
checkpoint_path = "model_checkpoint_step20000.pt"
tokenizer = Tokenizer.from_file("wordlevel.json")
vocab_size = tokenizer.get_vocab_size()

# Architecture hyperparameters. These must match the values the checkpoint was
# trained with; otherwise load_state_dict() (strict by default) rejects the
# stored weights.
model_config = dict(
    embed_dim=128,
    num_heads=4,
    ff_dim=512,
    num_layers=4,
    max_seq_len=128,
)
model = MiniGPT(vocab_size=vocab_size, **model_config)

# map_location='cpu' remaps tensors saved on another device (e.g. a GPU) onto
# the CPU, so the checkpoint loads on CPU-only machines.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch to inference mode for any train/eval-sensitive layers

# --- 2. Show model parameter count ---
num_params = sum(tensor.numel() for tensor in model.parameters())
print(f"Model parameters: {num_params:,}")

# --- 3. Sampling helpers ---

def top_k_logits(logits, k):
    """Mask every logit outside the ``k`` largest (per row) to ``-inf``.

    Filtering runs along the last dimension, so any number of leading
    dimensions is supported. Entries tied with the k-th largest value are
    kept, so a row may keep more than ``k`` finite entries.

    Args:
        logits: tensor of unnormalized scores.
        k: number of highest-scoring entries to keep. It is clamped to the
            size of the last dimension. ``k <= 0`` disables filtering, and
            ``logits`` is returned unchanged.

    Returns:
        A new tensor shaped like ``logits`` with the filtered entries set to
        ``-inf``. The input tensor is never modified in place.
    """
    # torch.topk raises if k exceeds the dimension size, so clamp first.
    k = min(k, logits.size(-1))
    if k <= 0:
        return logits
    # Keep the trailing dim (size 1) so the comparison broadcasts per row.
    kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, float('-inf'))

def top_p_logits(logits, p=0.9):
    """Nucleus (top-p) filtering along the last dimension.

    In each row, tokens are ranked by descending logit and kept until their
    cumulative softmax probability first exceeds ``p``. The token that
    crosses ``p`` is also kept, and the single most likely token is always
    kept. Every other entry is set to ``-inf``.

    Args:
        logits: tensor of unnormalized scores; any number of leading dims.
        p: cumulative-probability threshold.

    Returns:
        A new tensor shaped like ``logits``. The input is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Shift the mask right by one so the token that crosses p survives. The
    # top token is never removed, so a row can never become all -inf.
    sorted_remove = cumulative_probs > p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False

    # sorted_indices is a permutation, so scattering moves the mask from
    # sorted order back to the original vocabulary order in one operation.
    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, float('-inf'))

# --- 4. Streaming generation function ---
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Return ``logits`` with the entries listed in ``token_ids`` penalized in row 0.

    Uses the CTRL / Hugging Face rule: a positive logit is divided by
    ``penalty`` and a negative logit is multiplied by it, so ``penalty > 1``
    always lowers the score of an already-seen token. Dividing a negative
    logit would instead pull it toward 0 and make repeats *more* likely.
    A copy is returned when anything is penalized; otherwise ``logits`` is
    returned as-is.
    """
    if not token_ids or penalty == 1.0:
        return logits
    ids = torch.tensor(sorted(token_ids), dtype=torch.long)
    scores = logits[0, ids]
    penalized = logits.clone()
    penalized[0, ids] = torch.where(scores < 0, scores * penalty, scores / penalty)
    return penalized


def generate_stream(
    model, tokenizer, prompt,
    max_new_tokens=50,
    temperature=1.0,
    top_k=None,
    top_p=None,
    repetition_penalty=2.0,
):
    """Generate up to ``max_new_tokens`` tokens after ``prompt``, printing each as it arrives.

    Each decoded token is printed, followed by a space, as soon as it is
    chosen. A summary line with the token count and throughput is printed
    at the end.

    Args:
        model: callable mapping a ``(1, seq_len)`` LongTensor of token ids to
            logits that are indexed as ``[:, -1, :]`` (assumed to be
            ``(batch, seq, vocab)``). It must expose ``max_seq_len``.
        tokenizer: object providing ``encode(text).ids`` and ``decode(ids)``.
        prompt: text to continue.
        max_new_tokens: upper bound on generated tokens. Generation also stops
            once the sequence reaches ``model.max_seq_len`` tokens.
        temperature: logits are divided by this before sampling. A value
            ``<= 0`` selects greedy (argmax) decoding instead of sampling.
        top_k: if not None, keep only the ``top_k`` highest logits.
        top_p: if not None, apply nucleus filtering with this threshold.
        repetition_penalty: penalty (> 1 discourages repeats) for tokens
            already generated in this call. Prompt tokens are not penalized.

    Returns:
        LongTensor of shape ``(1, prompt_len + n_generated)``: the prompt ids
        followed by the generated ids.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    idx = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long)
    if idx.size(1) == 0:
        # The model needs at least one context token to predict the next one.
        raise ValueError("prompt encodes to zero tokens")

    # Dividing by temperature 0 would yield inf/NaN probabilities.
    greedy = temperature <= 0
    generated = []
    seen = set()  # distinct generated ids, maintained incrementally
    start_time = time.perf_counter()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            if idx.size(1) >= model.max_seq_len:
                break

            logits = model(idx)[:, -1, :]
            if not greedy:
                logits = logits / temperature

            logits = _apply_repetition_penalty(logits, seen, repetition_penalty)

            # Apply Top-K and/or Top-P filtering
            if top_k is not None:
                logits = top_k_logits(logits, top_k)
            if top_p is not None:
                logits = top_p_logits(logits, top_p)

            if greedy:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                next_id = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

            idx = torch.cat([idx, next_id], dim=1)
            token = next_id.item()
            generated.append(token)
            seen.add(token)
            print(tokenizer.decode([token]), end=' ', flush=True)

    elapsed = time.perf_counter() - start_time
    tps = len(generated) / elapsed if elapsed > 0 else 0
    print(f"\n[Generated {len(generated)} tokens in {elapsed:.2f} seconds | {tps:.2f} tokens/sec]")
    return idx

# --- 5. Main input loop ---
# Read prompts until the user types 'exit' (case-insensitive, surrounding
# whitespace ignored), closes stdin (Ctrl-D), or presses Ctrl-C at the
# prompt. Each non-blank prompt is streamed through generate_stream with the
# sampling settings below.
while True:
    try:
        prompt = input("\nEnter your prompt (or type 'exit' to quit): ")
    except (EOFError, KeyboardInterrupt):
        # End the session cleanly instead of exiting with a traceback.
        print()
        break
    stripped = prompt.strip()
    if stripped.lower() == 'exit':
        break
    if not stripped:
        # A blank line gives the model no context to continue from; re-prompt.
        continue

    print("\nStreaming output:")
    generate_stream(
        model, tokenizer, prompt,
        max_new_tokens=90,
        temperature=2.0,
        top_k=100,
        top_p=0.9,
        repetition_penalty=1.8,
    )