BICORP committed
Commit 55c1883 · verified · 1 Parent(s): fd2c745

Delete train.py

Files changed (1)
  1. train.py +0 -58
train.py DELETED
@@ -1,58 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import Dataset, DataLoader
- from transformers import BertTokenizer
-
- class SmallGemmaModel(nn.Module):
-     def __init__(self, vocab_size, embedding_dim=256, num_heads=4, num_layers=4):
-         super(SmallGemmaModel, self).__init__()
-         self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
-         self.transformer_layers = nn.ModuleList([
-             nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True) for _ in range(num_layers)  # batch_first matches the (batch, seq) tensors the DataLoader yields
-         ])
-         self.output_layer = nn.Linear(embedding_dim, vocab_size)
-
-     def forward(self, input_ids):
-         text_embeddings = self.token_embeddings(input_ids)
-         for layer in self.transformer_layers:
-             text_embeddings = layer(text_embeddings)
-         return self.output_layer(text_embeddings)
-
- class KnowledgeDataset(Dataset):
-     def __init__(self, file_path, tokenizer, max_length=128):  # Reduced max_length
-         self.tokenizer = tokenizer
-         self.max_length = max_length
-         with open(file_path, 'r') as f:
-             self.data = f.read().splitlines()
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, idx):
-         text = self.data[idx]
-         encoding = self.tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
-         input_ids = encoding['input_ids'].squeeze()
-         return input_ids[:-1], input_ids[1:]  # shift by one: predict the next token at each position
-
- def train_model(model, dataset, epochs=5, batch_size=8, learning_rate=1e-4):  # Reduced batch size
-     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-     loss_fn = nn.CrossEntropyLoss(ignore_index=0)  # 0 is [PAD] for bert-base-uncased; exclude padding from the loss
-
-     model.train()
-     for epoch in range(epochs):
-         for input_ids, target_ids in dataloader:
-             optimizer.zero_grad()
-             outputs = model(input_ids)
-             loss = loss_fn(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
-             loss.backward()
-             optimizer.step()
-         print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
-
- if __name__ == "__main__":
-     vocab_size = 262208 // 4
-     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-     model = SmallGemmaModel(vocab_size=vocab_size)
-     dataset = KnowledgeDataset('default.txt', tokenizer)
-     train_model(model, dataset)
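
A note for anyone restoring this script: nn.TransformerEncoderLayer attends bidirectionally by default, so with the shifted-target setup above each position can see the very token it is asked to predict, and the loss collapses without the model learning anything useful. A minimal sketch of a causal forward pass, assuming PyTorch 2.x where nn.Transformer.generate_square_subsequent_mask is available as a static method; causal_forward is a hypothetical name, not part of the deleted file:

import torch.nn as nn

def causal_forward(self, input_ids):
    # Hypothetical replacement for SmallGemmaModel.forward: mask out future
    # positions so each token only attends to itself and earlier tokens.
    seq_len = input_ids.size(1)
    causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(input_ids.device)
    hidden = self.token_embeddings(input_ids)
    for layer in self.transformer_layers:
        hidden = layer(hidden, src_mask=causal_mask)  # additive -inf mask above the diagonal
    return self.output_layer(hidden)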
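The sizing in the __main__ block is also worth flagging: 262208 appears to match Gemma 3's configured vocabulary size, but the script tokenizes with bert-base-uncased, whose vocabulary has 30,522 entries, so 262208 // 4 = 65552 allocates an embedding table roughly twice as large as the tokenizer can ever index. A sketch of deriving the size from the tokenizer actually in use instead of hard-coding it:

# Sketch: size the embedding table from the tokenizer, not from Gemma's vocab.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = SmallGemmaModel(vocab_size=tokenizer.vocab_size)  # 30522 for bert-base-uncased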