BICORP committed
Commit e5fbea9 · verified · 1 Parent(s): 5854bd8

Upload directory

Files changed (1): train.py +58 -0
train.py ADDED
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

class SmallGemmaModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=256, num_heads=4, num_layers=4, max_length=128):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Learned positional embeddings: without them the attention layers
        # have no notion of token order.
        self.position_embeddings = nn.Embedding(max_length, embedding_dim)
        self.transformer_layers = nn.ModuleList([
            # batch_first=True so inputs are (batch, seq, dim), matching the DataLoader output.
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        # Causal mask: each position may attend only to itself and earlier
        # positions, as next-token prediction requires.
        mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device), diagonal=1)
        for layer in self.transformer_layers:
            x = layer(x, src_mask=mask)
        return self.output_layer(x)

class KnowledgeDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):  # Reduced max_length
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(file_path, 'r', encoding='utf-8') as f:
            # Drop blank lines: an all-padding sample would contribute no valid targets.
            self.data = [line for line in f.read().splitlines() if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        encoding = self.tokenizer(text, return_tensors='pt', padding='max_length',
                                  truncation=True, max_length=self.max_length)
        input_ids = encoding['input_ids'].squeeze(0)
        # Shift by one token: the model predicts position t+1 from positions <= t.
        return input_ids[:-1], input_ids[1:]

def train_model(model, dataset, epochs=5, batch_size=8, learning_rate=1e-4, pad_token_id=0):  # Reduced batch size
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Ignore padding positions so the loss is computed only on real tokens.
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    model.train()
    for epoch in range(epochs):
        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()
            outputs = model(input_ids)
            loss = loss_fn(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

if __name__ == "__main__":
    # A quarter of Gemma's 262,208-token vocabulary. The embedding table only needs
    # to cover the tokenizer's vocabulary (BERT's 30,522), which 65,552 does.
    vocab_size = 262208 // 4
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SmallGemmaModel(vocab_size=vocab_size)
    dataset = KnowledgeDataset('default.txt', tokenizer)
    train_model(model, dataset, pad_token_id=tokenizer.pad_token_id)