Test-25 / train.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
class SmallGemmaModel(nn.Module):
    """A small decoder-style Transformer language model."""
    def __init__(self, vocab_size, embedding_dim=256, num_heads=4, num_layers=4):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True so inputs/outputs are (batch, seq_len, embedding_dim).
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_ids):
        hidden = self.token_embeddings(input_ids)
        # Causal mask so each position only attends to earlier tokens (next-token prediction).
        mask = nn.Transformer.generate_square_subsequent_mask(input_ids.size(1)).to(input_ids.device)
        for layer in self.transformer_layers:
            hidden = layer(hidden, src_mask=mask)
        return self.output_layer(hidden)  # (batch, seq_len, vocab_size) logits

class KnowledgeDataset(Dataset):
    """Reads one text example per line and yields (input, target) pairs shifted by one token."""
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = f.read().splitlines()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        encoding = self.tokenizer(text, return_tensors='pt', padding='max_length',
                                  truncation=True, max_length=self.max_length)
        input_ids = encoding['input_ids'].squeeze(0)
        # Next-token prediction: inputs are tokens [0..n-2], targets are tokens [1..n-1].
        return input_ids[:-1], input_ids[1:]

def train_model(model, dataset, epochs=5, batch_size=8, learning_rate=1e-4):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Ignore padding positions in the loss (bert-base-uncased's [PAD] token id is 0).
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    model.train()
    for epoch in range(epochs):
        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()
            outputs = model(input_ids)  # (batch, seq_len, vocab_size)
            # Flatten batch and sequence dimensions for token-level cross-entropy.
            loss = loss_fn(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

if __name__ == "__main__":
    # Vocabulary size only needs to cover the tokenizer's ids (bert-base-uncased has ~30,522 tokens).
    vocab_size = 262208 // 4
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SmallGemmaModel(vocab_size=vocab_size)
    dataset = KnowledgeDataset('default.txt', tokenizer)
    train_model(model, dataset)
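

# --- Optional usage sketch (not part of the original training script) ---
# A minimal greedy-decoding example, assuming the model and tokenizer defined above;
# `generate_text` is a hypothetical helper added here purely for illustration.
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=20):
    model.eval()
    input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']  # (1, seq_len)
    for _ in range(max_new_tokens):
        logits = model(input_ids)                                    # (1, seq_len, vocab_size)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)      # pick the most likely next token
        input_ids = torch.cat([input_ids, next_id], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
# Example call after training: print(generate_text(model, tokenizer, "some prompt"))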