AGofficial commited on
Commit
a6fc25f
·
verified ·
1 Parent(s): 2a08825

Upload gptmodel4.py

Browse files
Files changed (1) hide show
  1. gptmodel4.py +296 -0
gptmodel4.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from tokenizers import Tokenizer
8
+ from tokenizers.models import BPE
9
+ from tokenizers.trainers import BpeTrainer
10
+ from tokenizers.pre_tokenizers import Whitespace
11
+ from pathlib import Path
12
+ import argparse
13
+
14
+ class LightweightGPT(nn.Module):
15
+ def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer):
16
+ super().__init__()
17
+ self.block_size = block_size
18
+ self.token_embedding = nn.Embedding(vocab_size, n_embd)
19
+ self.position_embedding = nn.Embedding(block_size, n_embd)
20
+
21
+ self.blocks = nn.ModuleList([
22
+ nn.TransformerDecoderLayer(
23
+ d_model=n_embd,
24
+ nhead=n_head,
25
+ dim_feedforward=4 * n_embd,
26
+ dropout=0.1,
27
+ activation='gelu',
28
+ batch_first=True,
29
+ norm_first=True
30
+ )
31
+ for _ in range(n_layer)
32
+ ])
33
+ self.ln_f = nn.LayerNorm(n_embd)
34
+ self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
35
+
36
+ def forward(self, idx, targets=None):
37
+ B, T = idx.shape
38
+ device = idx.device
39
+ causal_mask = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
40
+
41
+ token_emb = self.token_embedding(idx)
42
+ pos = torch.arange(0, T, dtype=torch.long, device=device)
43
+ pos_emb = self.position_embedding(pos)
44
+
45
+ x = token_emb + pos_emb
46
+
47
+ for block in self.blocks:
48
+ x = block(x, x, tgt_mask=causal_mask)
49
+
50
+ x = self.ln_f(x)
51
+ logits = self.lm_head(x)
52
+
53
+ loss = None
54
+ if targets is not None:
55
+ loss = F.cross_entropy(
56
+ logits.view(-1, logits.size(-1)),
57
+ targets.view(-1),
58
+ ignore_index=-1
59
+ )
60
+ return logits, loss
61
+
62
+ def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, stop_token=None):
63
+ for _ in range(max_new_tokens):
64
+ idx_cond = idx[:, -self.block_size:]
65
+ logits, _ = self(idx_cond)
66
+ logits = logits[:, -1, :]
67
+ logits = logits / temperature
68
+
69
+ if top_k is not None:
70
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
71
+ logits[logits < v[:, [-1]]] = -float('Inf')
72
+
73
+ probs = F.softmax(logits, dim=-1)
74
+ idx_next = torch.multinomial(probs, num_samples=1)
75
+
76
+ if stop_token is not None and idx_next.item() == stop_token:
77
+ break
78
+
79
+ idx = torch.cat((idx, idx_next), dim=1)
80
+
81
+ return idx
82
+
83
+ class ConversationDataset(Dataset):
84
+ def __init__(self, tokens, block_size, end_token_id):
85
+ self.end_token = end_token_id
86
+ self.block_size = block_size
87
+ self.segments = []
88
+ current_start = 0
89
+ for i, token in enumerate(tokens):
90
+ if token == end_token_id:
91
+ segment = tokens[current_start:i+1]
92
+ if len(segment) < block_size + 1:
93
+ padding = [end_token_id] * (block_size + 1 - len(segment))
94
+ segment.extend(padding)
95
+ self.segments.append(segment)
96
+ current_start = i + 1
97
+ print(f"Created {len(self.segments)} conversation segments.")
98
+
99
+ def __len__(self):
100
+ return len(self.segments)
101
+
102
+ def __getitem__(self, idx):
103
+ segment = self.segments[idx]
104
+ start_pos = torch.randint(0, max(1, len(segment) - self.block_size), (1,)).item()
105
+ chunk = segment[start_pos:start_pos + self.block_size + 1]
106
+
107
+ x = torch.tensor(chunk[:-1], dtype=torch.long)
108
+ y = torch.tensor(chunk[1:], dtype=torch.long)
109
+ return x, y
110
+
111
+ class AIBuilder:
112
+ def __init__(self, model_name: str):
113
+ self.model_name = model_name
114
+ self.output_folder = model_name.replace(" ", "_").lower()
115
+ self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
116
+ print(f"Using device: {self.device}")
117
+
118
+ self.model_config = {
119
+ "block_size": 128,
120
+ "n_embd": 128,
121
+ "n_head": 4,
122
+ "n_layer": 4,
123
+ "vocab_size": 8000,
124
+ "batch_size": 8,
125
+ "grad_accum": 4,
126
+ "max_epochs": 3,
127
+ }
128
+
129
+ def _build_tokenizer(self, training_data: str):
130
+ tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
131
+ tokenizer.pre_tokenizer = Whitespace()
132
+ trainer = BpeTrainer(
133
+ special_tokens=["[UNK]", "[PAD]", "user:", "ai:", "<|endoftext|>"],
134
+ vocab_size=self.model_config["vocab_size"]
135
+ )
136
+ tokenizer.train_from_iterator(self._get_text_iterator(training_data), trainer)
137
+ return tokenizer
138
+
139
+ def _get_text_iterator(self, text, chunk_size=1000):
140
+ for i in range(0, len(text), chunk_size):
141
+ yield text[i:i + chunk_size]
142
+
143
+ def _prepare_dataloader(self, tokenizer, text):
144
+ tokens = tokenizer.encode(text).ids
145
+ end_token_id = tokenizer.token_to_id("<|endoftext|>")
146
+ dataset = ConversationDataset(tokens, self.model_config["block_size"], end_token_id)
147
+
148
+ def collate_fn(batch):
149
+ xs, ys = zip(*batch)
150
+ return torch.stack(xs), torch.stack(ys)
151
+
152
+ return DataLoader(dataset, batch_size=self.model_config["batch_size"], shuffle=True, collate_fn=collate_fn)
153
+
154
+ def train(self, training_data: str):
155
+ os.makedirs(self.output_folder, exist_ok=True)
156
+
157
+ print("Building and saving tokenizer...")
158
+ tokenizer = self._build_tokenizer(training_data)
159
+ tokenizer.save(os.path.join(self.output_folder, "tokenizer.json"))
160
+
161
+ print("Saving configuration file...")
162
+ self._save_config(tokenizer) # MOVED HERE
163
+
164
+ print("Preparing data for training...")
165
+ dataloader = self._prepare_dataloader(tokenizer, training_data)
166
+
167
+ model = LightweightGPT(
168
+ vocab_size=tokenizer.get_vocab_size(),
169
+ block_size=self.model_config["block_size"],
170
+ n_embd=self.model_config["n_embd"],
171
+ n_head=self.model_config["n_head"],
172
+ n_layer=self.model_config["n_layer"]
173
+ ).to(self.device)
174
+
175
+ optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
176
+ model_path = os.path.join(self.output_folder, "model.pt")
177
+
178
+ print("\n--- Starting Model Training ---")
179
+ model.train()
180
+ best_loss = float('inf')
181
+
182
+ for epoch in range(self.model_config["max_epochs"]):
183
+ optimizer.zero_grad()
184
+ for batch_idx, (x, y) in enumerate(dataloader):
185
+ x, y = x.to(self.device), y.to(self.device)
186
+ _, loss = model(x, y)
187
+
188
+ loss = loss / self.model_config["grad_accum"]
189
+ loss.backward()
190
+
191
+ if (batch_idx + 1) % self.model_config["grad_accum"] == 0:
192
+ optimizer.step()
193
+ optimizer.zero_grad()
194
+
195
+ current_loss = loss.detach().item() * self.model_config["grad_accum"]
196
+
197
+ if batch_idx % 50 == 0:
198
+ print(f"Epoch {epoch+1} | Batch {batch_idx} | Loss: {current_loss:.4f}")
199
+
200
+ if current_loss < best_loss:
201
+ best_loss = current_loss
202
+ torch.save(model.state_dict(), model_path)
203
+ print(f"🎉 New best model saved with loss: {best_loss:.4f}")
204
+
205
+ print(f"✅ Training complete. Final best loss: {best_loss:.4f}")
206
+
207
+ def _save_config(self, tokenizer):
208
+ config = {
209
+ "model_name": self.model_name,
210
+ **self.model_config,
211
+ "vocab_size": tokenizer.get_vocab_size(),
212
+ "end_token_id": tokenizer.token_to_id("<|endoftext|>")
213
+ }
214
+ with open(os.path.join(self.output_folder, "config.json"), "w") as f:
215
+ json.dump(config, f, indent=2)
216
+ print(f"Configuration saved to {os.path.join(self.output_folder, 'config.json')}")
217
+
218
+ class ChatInterface:
219
+ def __init__(self, model_dir="aglm"):
220
+ self.model_dir = Path(model_dir)
221
+ self.device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
222
+ self.load_model()
223
+
224
+ def load_model(self):
225
+ with open(self.model_dir / "config.json", "r") as f:
226
+ self.config = json.load(f)
227
+
228
+ self.tokenizer = Tokenizer.from_file(str(self.model_dir / "tokenizer.json"))
229
+ self.end_token_id = self.config.get("end_token_id")
230
+
231
+ self.model = LightweightGPT(
232
+ vocab_size=self.config["vocab_size"],
233
+ block_size=self.config["block_size"],
234
+ n_embd=self.config["n_embd"],
235
+ n_head=self.config["n_head"],
236
+ n_layer=self.config["n_layer"]
237
+ ).to(self.device)
238
+
239
+ self.model.load_state_dict(torch.load(self.model_dir / "model.pt", map_location=self.device))
240
+ self.model.eval()
241
+ print("✅ Model loaded successfully!")
242
+
243
+ def chat(self):
244
+ print("\n===== AI Assistant Ready =====")
245
+ print("Type 'quit' or 'exit' to end the chat.\n")
246
+
247
+ while True:
248
+ user_input = input("user: ")
249
+ if user_input.lower() in ["quit", "exit"]:
250
+ break
251
+
252
+ prompt = f"user: {user_input}\nai:"
253
+ input_ids = self.tokenizer.encode(prompt).ids
254
+ input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
255
+
256
+ with torch.no_grad():
257
+ output_ids = self.model.generate(
258
+ input_tensor,
259
+ max_new_tokens=150,
260
+ temperature=0.7,
261
+ top_k=40,
262
+ stop_token=self.end_token_id
263
+ )
264
+
265
+ response_ids = output_ids[0, len(input_ids):].tolist()
266
+ response = self.tokenizer.decode(response_ids)
267
+ response = response.replace("<|endoftext|>", "").strip()
268
+
269
+ print(f"ai: {response}")
270
+
271
+ if __name__ == "__main__":
272
+ parser = argparse.ArgumentParser(description="Train or chat with an AgLM model.")
273
+ parser.add_argument('action', choices=['train', 'chat'], nargs='?', default='train', help="Choose 'train' (default) or 'chat'.")
274
+ args = parser.parse_args()
275
+
276
+ model_folder = "aglm"
277
+
278
+ if args.action == 'train':
279
+ print("--- Starting Setup for AgLM ---")
280
+ builder = AIBuilder("AgLM")
281
+ try:
282
+ with open("train.txt", "r", encoding="utf-8") as f:
283
+ data = f.read()
284
+ builder.train(data)
285
+ print("\n✅ Training finished. You can now run with the 'chat' argument.")
286
+ print(f"To chat, run: python {os.path.basename(__file__)} chat")
287
+ except FileNotFoundError:
288
+ print("\nERROR: train.txt not found. Please create train.txt with your conversational data to train the model.")
289
+
290
+ elif args.action == 'chat':
291
+ print("--- Starting Chat Interface for AgLM ---")
292
+ if os.path.exists(model_folder) and os.path.exists(os.path.join(model_folder, "model.pt")):
293
+ chat_bot = ChatInterface(model_dir=model_folder)
294
+ chat_bot.chat()
295
+ else:
296
+ print(f"\nERROR: Model directory '{model_folder}' not found. Please run training first.")