Kleinpuki2 commited on
Commit
a69eabc
·
verified ·
1 Parent(s): f2f562e

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +209 -0
model.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ import json
5
+ import re
6
+
7
+ class BPETokenizer:
8
+ def __init__(self, model_type="gpt2"):
9
+ import tiktoken
10
+ self.enc = tiktoken.get_encoding(model_type)
11
+ self.vocab_size = self.enc.n_vocab
12
+
13
+ def encode(self, text: str):
14
+ return self.enc.encode(text, allowed_special={'<|endoftext|>'})
15
+
16
+ def decode(self, ids):
17
+ return self.enc.decode(ids)
18
+
19
+ def save(self, path: str):
20
+ with open(path, "w", encoding="utf-8") as f:
21
+ json.dump({"type": "bpe", "model": "gpt2"}, f)
22
+
23
+ def load(self, path: str):
24
+ pass
25
+
26
+ class WordTokenizer:
27
+ def __init__(self):
28
+ self.word2idx = {"<PAD>": 0, "<UNK>": 1}
29
+ self.idx2word = {0: "<PAD>", 1: "<UNK>"}
30
+ self.vocab_size = 2
31
+
32
+ def build(self, text: str, max_vocab: int = 10000):
33
+ tokens = re.findall(r"\w+|[^\w\s]|\n", text.lower())
34
+ from collections import Counter
35
+ counts = Counter(tokens)
36
+ most_common = counts.most_common(max_vocab - 2)
37
+ for word, _ in most_common:
38
+ idx = len(self.word2idx)
39
+ self.word2idx[word] = idx
40
+ self.idx2word[idx] = word
41
+ self.vocab_size = len(self.word2idx)
42
+
43
+ def encode(self, text: str):
44
+ tokens = re.findall(r"\w+|[^\w\s]|\n", text.lower())
45
+ return [self.word2idx.get(t, 1) for t in tokens]
46
+
47
+ def decode(self, ids):
48
+ words = [self.idx2word.get(i, "<UNK>") for i in ids]
49
+ result = ""
50
+ for w in words:
51
+ if w in ".,!?;:)]}\"'" or result == "": result += w
52
+ elif w == "\n": result += "\n"
53
+ else: result += " " + w
54
+ return result
55
+
56
+ def save(self, path: str):
57
+ with open(path, "w", encoding="utf-8") as f:
58
+ json.dump({
59
+ "word2idx": self.word2idx,
60
+ "idx2word": {str(k): v for k, v in self.idx2word.items()}
61
+ }, f, ensure_ascii=False)
62
+
63
+ def load(self, path: str):
64
+ with open(path, "r", encoding="utf-8") as f:
65
+ data = json.load(f)
66
+ self.word2idx = data["word2idx"]
67
+ self.idx2word = {int(k): v for k, v in data["idx2word"].items()}
68
+ self.vocab_size = len(self.word2idx)
69
+
70
+ class MiniTransformer(nn.Module):
71
+ def __init__(self, vocab_size, emb_dim=128, n_layers=4, n_heads=4, ctx_len=64, dropout=0.1):
72
+ super().__init__()
73
+ self.ctx_len = ctx_len
74
+ self.n_heads = n_heads
75
+ self.emb_dim = emb_dim
76
+ self.n_layers = n_layers
77
+ self.token_embedding_table = nn.Embedding(vocab_size, emb_dim)
78
+ self.position_embedding_table = nn.Embedding(ctx_len, emb_dim)
79
+ self.drop = nn.Dropout(dropout)
80
+ self.blocks = nn.ModuleList([
81
+ nn.TransformerEncoderLayer(
82
+ d_model=emb_dim,
83
+ nhead=n_heads,
84
+ dim_feedforward=emb_dim * 4,
85
+ dropout=dropout,
86
+ batch_first=True,
87
+ norm_first=True,
88
+ activation='gelu'
89
+ ) for _ in range(n_layers)
90
+ ])
91
+ self.ln_f = nn.LayerNorm(emb_dim)
92
+ self.lm_head = nn.Linear(emb_dim, vocab_size, bias=False)
93
+ self.apply(self._init_weights)
94
+
95
+ def _init_weights(self, module):
96
+ if isinstance(module, nn.Linear):
97
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
98
+ if module.bias is not None:
99
+ nn.init.zeros_(module.bias)
100
+ elif isinstance(module, nn.Embedding):
101
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
102
+
103
+ def forward(self, idx, targets=None, use_checkpointing=False):
104
+ device = idx.device
105
+ B, T = idx.shape
106
+ tok_emb = self.token_embedding_table(idx)
107
+ pos_emb = self.position_embedding_table(torch.arange(T, device=device))
108
+ x = self.drop(tok_emb + pos_emb)
109
+ mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
110
+ for block in self.blocks:
111
+ if use_checkpointing and self.training:
112
+ from torch.utils.checkpoint import checkpoint
113
+ def custom_forward(x_in, m_in):
114
+ return block(x_in, src_mask=m_in, is_causal=True)
115
+ x = checkpoint(custom_forward, x, mask, use_reentrant=False)
116
+ else:
117
+ x = block(x, src_mask=mask, is_causal=True)
118
+ x = self.ln_f(x)
119
+ logits = self.lm_head(x)
120
+ loss = None
121
+ if targets is not None:
122
+ B, T, C = logits.shape
123
+ loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
124
+ return logits, loss
125
+
126
+ def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40, repetition_penalty=1.0):
127
+ device = next(self.parameters()).device
128
+ if isinstance(idx, list):
129
+ idx = torch.tensor([idx], dtype=torch.long)
130
+ idx = idx.to(device)
131
+ self.eval()
132
+ with torch.no_grad():
133
+ for _ in range(max_new_tokens):
134
+ idx_cond = idx[:, -self.ctx_len:]
135
+ logits, _ = self(idx_cond)
136
+ logits = logits[:, -1, :] / temperature
137
+ if repetition_penalty != 1.0:
138
+ for i in range(idx.shape[1]):
139
+ token_id = idx[0, i].item()
140
+ if logits[0, token_id] > 0:
141
+ logits[0, token_id] /= repetition_penalty
142
+ else:
143
+ logits[0, token_id] *= repetition_penalty
144
+ if top_k > 0:
145
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
146
+ logits[logits < v[:, [-1]]] = float('-inf')
147
+ probs = F.softmax(logits, dim=-1)
148
+ idx_next = torch.multinomial(probs, num_samples=1)
149
+ idx = torch.cat((idx, idx_next), dim=1)
150
+ return idx
151
+
152
+ def generate_stream(self, idx, max_new_tokens, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2):
153
+ device = next(self.parameters()).device
154
+ if isinstance(idx, list):
155
+ idx = torch.tensor([idx], dtype=torch.long)
156
+ idx = idx.to(device)
157
+ self.eval()
158
+ with torch.no_grad():
159
+ for _ in range(max_new_tokens):
160
+ idx_cond = idx[:, -self.ctx_len:]
161
+ logits, _ = self(idx_cond)
162
+ logits = logits[:, -1, :] / temperature
163
+ if repetition_penalty != 1.0:
164
+ for i in range(idx.shape[1]):
165
+ token_id = idx[0, i].item()
166
+ if logits[0, token_id] > 0:
167
+ logits[0, token_id] /= repetition_penalty
168
+ else:
169
+ logits[0, token_id] *= repetition_penalty
170
+ if top_k > 0:
171
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
172
+ logits[logits < v[:, [-1]]] = float('-inf')
173
+ if top_p > 0.0 and top_p < 1.0:
174
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
175
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
176
+ sorted_indices_to_remove = cumulative_probs > top_p
177
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
178
+ sorted_indices_to_remove[..., 0] = 0
179
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
180
+ logits[indices_to_remove] = float('-inf')
181
+ probs = F.softmax(logits, dim=-1)
182
+ idx_next = torch.multinomial(probs, num_samples=1)
183
+ yield idx_next.item(), torch.max(probs).item()
184
+ idx = torch.cat((idx, idx_next), dim=1)
185
+
186
+ def save(self, path: str):
187
+ torch.save({
188
+ 'model_state': self.state_dict(),
189
+ 'config': {
190
+ 'vocab_size': self.token_embedding_table.num_embeddings,
191
+ 'emb_dim': self.emb_dim,
192
+ 'n_layers': self.n_layers,
193
+ 'n_heads': self.n_heads,
194
+ 'ctx_len': self.ctx_len,
195
+ }
196
+ }, path)
197
+ print(f"Modell gespeichert: {path}")
198
+
199
+ @classmethod
200
+ def load(cls, path: str, device='cpu'):
201
+ if not torch.cuda.is_available():
202
+ device = 'cpu'
203
+ ckpt = torch.load(path, map_location=device, weights_only=False)
204
+ cfg = ckpt['config']
205
+ m = cls(cfg['vocab_size'], cfg['emb_dim'], cfg['n_layers'], cfg['n_heads'], cfg['ctx_len'])
206
+ m.load_state_dict(ckpt['model_state'])
207
+ m.to(device)
208
+ print(f"Modell geladen: {path}")
209
+ return m