Austin207 committed on
Commit e24d6f1 · verified · 1 Parent(s): 6f341e4

Upload folder using huggingface_hub

Files changed (5)
  1. inference.py +108 -0
  2. miniGPT.py +30 -0
  3. multiheadattention.py +34 -0
  4. transformer.py +24 -0
  5. wordlevel.json +0 -0
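
The files above were pushed with the commit message "Upload folder using huggingface_hub". For reference, a minimal sketch of the kind of call that produces such a commit; the repo id below is a placeholder, not taken from this page, and authentication (e.g. via huggingface-cli login) is assumed:

# Hypothetical upload sketch; repo_id is a placeholder, not the actual repository name.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=".",                          # local folder with inference.py, miniGPT.py, ...
    repo_id="<username>/<repo-name>",         # placeholder
    commit_message="Upload folder using huggingface_hub",
)
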
inference.py ADDED
@@ -0,0 +1,108 @@
+ import torch
+ import time
+ from tokenizers import Tokenizer
+ from miniGPT import MiniGPT
+
+ # --- 1. Load tokenizer and model ---
+ tokenizer = Tokenizer.from_file("wordlevel.json")
+ vocab_size = tokenizer.get_vocab_size()
+
+ # Set model parameters to match your trained model
+ model = MiniGPT(
+     vocab_size=vocab_size,
+     embed_dim=128,
+     num_heads=4,
+     ff_dim=512,
+     num_layers=4,
+     max_seq_len=128
+ )
+ checkpoint_path = "model_checkpoint_step20000.pt"
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.eval()
+
+ # --- 2. Show model parameter count ---
+ num_params = sum(p.numel() for p in model.parameters())
+ print(f"Model parameters: {num_params:,}")
+
+ # --- 3. Sampling helpers ---
+
+ def top_k_logits(logits, k):
+     """Keep only the top-k tokens with the highest probability."""
+     values, _ = torch.topk(logits, k)
+     min_values = values[:, -1].unsqueeze(1)
+     logits[logits < min_values] = -float('Inf')
+     return logits
+
+ def top_p_logits(logits, p=0.9):
+     """Keep the smallest set of tokens with cumulative probability >= p."""
+     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+     # Shift the mask right so the first token that crosses the threshold is kept
+     sorted_indices_to_remove = cumulative_probs > p
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = 0
+
+     for batch in range(logits.size(0)):
+         remove_ids = sorted_indices[batch][sorted_indices_to_remove[batch]]
+         logits[batch, remove_ids] = -float('Inf')
+
+     return logits
+
+ # --- 4. Streaming generation function ---
+ def generate_stream(
+     model, tokenizer, prompt,
+     max_new_tokens=50,
+     temperature=1.0,
+     top_k=None,
+     top_p=None,
+     repetition_penalty=2.0
+ ):
+     idx = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long)
+     generated = []
+     start_time = time.time()
+
+     with torch.no_grad():
+         for _ in range(max_new_tokens):
+             # Stop once the context window is full
+             if idx.shape[1] >= model.max_seq_len:
+                 break
+
+             logits = model(idx)
+             logits = logits[:, -1, :] / temperature
+
+             # Apply repetition penalty: divide positive logits, multiply
+             # negative ones, so repeated tokens are always discouraged
+             for token_id in set(generated):
+                 if logits[0, token_id] > 0:
+                     logits[0, token_id] /= repetition_penalty
+                 else:
+                     logits[0, token_id] *= repetition_penalty
+
+             # Apply Top-K and/or Top-P filtering
+             if top_k is not None:
+                 logits = top_k_logits(logits, top_k)
+             if top_p is not None:
+                 logits = top_p_logits(logits, top_p)
+
+             probs = torch.softmax(logits, dim=-1)
+             next_id = torch.multinomial(probs, num_samples=1)
+             idx = torch.cat([idx, next_id], dim=1)
+             generated.append(next_id.item())
+             print(tokenizer.decode([next_id.item()]), end=' ', flush=True)
+
+     elapsed = time.time() - start_time
+     tps = len(generated) / elapsed if elapsed > 0 else 0
+     print(f"\n[Generated {len(generated)} tokens in {elapsed:.2f} seconds | {tps:.2f} tokens/sec]")
+     return idx
+
+ # --- 5. Main input loop ---
+ while True:
+     prompt = input("\nEnter your prompt (or type 'exit' to quit): ")
+     if prompt.lower() == 'exit':
+         break
+
+     print("\nStreaming output:")
+     generate_stream(
+         model, tokenizer, prompt,
+         max_new_tokens=90,
+         temperature=2.0,
+         top_k=100,
+         top_p=0.9,
+         repetition_penalty=1.8
+     )
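
Note that inference.py only needs the 'model_state_dict' key from the checkpoint. A minimal sketch of writing a compatible file from a training loop; the training-side model variable is an assumption, since the training script is not part of this commit:

# Sketch: save a checkpoint in the format inference.py expects.
# `model` is the trained MiniGPT instance from your training script (not in this commit).
import torch

torch.save({"model_state_dict": model.state_dict()}, "model_checkpoint_step20000.pt")
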
miniGPT.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformer import TransformerBlock
+
+ class MiniGPT(nn.Module):
+     def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, max_seq_len):
+         super().__init__()
+         self.max_seq_len = max_seq_len
+         self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+         self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)
+         # ModuleList (not Sequential) so a mask can be forwarded to every block
+         self.blocks = nn.ModuleList(
+             [TransformerBlock(embed_dim, num_heads, ff_dim) for _ in range(num_layers)]
+         )
+         self.ln_f = nn.LayerNorm(embed_dim)
+         self.head = nn.Linear(embed_dim, vocab_size, bias=False)
+
+         # Weight tying: output projection shares the token embedding matrix
+         self.head.weight = self.token_embedding.weight
+
+     def forward(self, idx, mask=None):
+         B, T = idx.shape
+         tok_emb = self.token_embedding(idx)
+         pos = torch.arange(T, device=idx.device).unsqueeze(0)
+         pos_emb = self.pos_embedding(pos)
+         x = tok_emb + pos_emb
+         for block in self.blocks:
+             x = block(x, mask=mask)
+         x = self.ln_f(x)
+         logits = self.head(x)
+         return logits
+
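
A quick shape check for the model as defined above, using the same hyperparameters as inference.py; the vocab size and batch/sequence sizes below are arbitrary test values, not from the trained checkpoint:

# Smoke test: dummy values only.
import torch
from miniGPT import MiniGPT

model = MiniGPT(vocab_size=1000, embed_dim=128, num_heads=4, ff_dim=512, num_layers=4, max_seq_len=128)
idx = torch.randint(0, 1000, (2, 16))    # (batch, seq_len) of dummy token ids
logits = model(idx)
print(logits.shape)                      # torch.Size([2, 16, 1000]) -> (B, T, vocab_size)
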
multiheadattention.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, embed_dim, num_heads):
+         super().__init__()
+         assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num heads"
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+
+         self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
+         self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+     def forward(self, x, mask=None):
+         B, T, C = x.shape
+         # Project to Q, K, V in one pass, then split into heads
+         qkv = self.qkv_proj(x)
+         qkv = qkv.reshape(B, T, self.num_heads, 3 * self.head_dim)
+         qkv = qkv.permute(0, 2, 1, 3)                  # (B, num_heads, T, 3 * head_dim)
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         # Scaled dot-product attention
+         attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
+
+         if mask is not None:
+             attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
+
+         attn_weights = F.softmax(attn_scores, dim=-1)
+         attn_output = attn_weights @ v
+
+         # Merge heads back and project out
+         attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
+         output = self.out_proj(attn_output)
+         return output
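
The mask argument is optional, and inference.py never builds one. If causal (left-to-right) masking is wanted at the attention level, a lower-triangular (T, T) mask broadcasts against the (B, num_heads, T, T) score tensor; a sketch under that assumption, with illustrative shapes:

# Sketch: causal mask usage; shapes and values are illustrative only.
import torch
from multiheadattention import MultiHeadAttention

B, T, C = 2, 16, 128
attn = MultiHeadAttention(embed_dim=C, num_heads=4)
x = torch.randn(B, T, C)
causal_mask = torch.tril(torch.ones(T, T))   # 1 = may attend, 0 = masked out
out = attn(x, mask=causal_mask)
print(out.shape)                             # torch.Size([2, 16, 128])
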
transformer.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from multiheadattention import MultiHeadAttention
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, embed_dim, num_heads, ff_dim):
+         super().__init__()
+         self.attn = MultiHeadAttention(embed_dim, num_heads)
+         self.ln1 = nn.LayerNorm(embed_dim)
+         self.ff = nn.Sequential(
+             nn.Linear(embed_dim, ff_dim),
+             nn.GELU(),
+             nn.Linear(ff_dim, embed_dim)
+         )
+
+         self.ln2 = nn.LayerNorm(embed_dim)
+
+     def forward(self, x, mask=None):
+         # Pre-norm residual connections around attention and feed-forward
+         x = x + self.attn(self.ln1(x), mask=mask)
+         x = x + self.ff(self.ln2(x))
+         return x
+
wordlevel.json ADDED
The diff for this file is too large to render. See raw diff