import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer


class SmallGemmaModel(nn.Module):
    """A small decoder-style language model built from TransformerEncoderLayer blocks."""

    def __init__(self, vocab_size, embedding_dim=256, num_heads=4, num_layers=4, max_seq_len=512):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Learned positional embeddings; without them the attention layers
        # have no notion of token order.
        self.position_embeddings = nn.Embedding(max_seq_len, embedding_dim)
        # batch_first=True so the layers accept (batch, seq, dim), matching the
        # embedding output; the PyTorch default layout is (seq, batch, dim).
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        hidden = self.token_embeddings(input_ids) + self.position_embeddings(positions)
        # Causal mask so each position attends only to earlier positions;
        # required for next-token prediction, otherwise targets leak.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device), diagonal=1
        )
        for layer in self.transformer_layers:
            hidden = layer(hidden, src_mask=causal_mask)
        return self.output_layer(hidden)


class KnowledgeDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):  # reduced max_length to save memory
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip blank lines: they tokenize to padding only.
            self.data = [line for line in f.read().splitlines() if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.data[idx],
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        input_ids = encoding['input_ids'].squeeze(0)
        # Shift by one: the model sees input_ids[:-1] and predicts input_ids[1:].
        return input_ids[:-1], input_ids[1:]


def train_model(model, dataset, epochs=5, batch_size=8, learning_rate=1e-4):  # reduced batch size
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # ignore_index=0 skips BERT's [PAD] token so padding doesn't dominate the loss.
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()
            outputs = model(input_ids)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss.
            loss = loss_fn(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(dataloader):.4f}")


if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # Size the embedding table from the tokenizer (30522 ids for
    # bert-base-uncased). The original 262208 // 4, a quarter of Gemma's
    # vocabulary, allocates rows the BERT tokenizer never emits.
    vocab_size = tokenizer.vocab_size
    model = SmallGemmaModel(vocab_size=vocab_size)
    dataset = KnowledgeDataset('default.txt', tokenizer)
    train_model(model, dataset)
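

# --- Optional: greedy decoding sketch --------------------------------------
# A minimal way to sample from the model once train_model() has run. This is
# an illustrative addition, not part of the original script: generate_text,
# prompt, and max_new_tokens are hypothetical names chosen for the example,
# and BERT's [SEP] token is used as an assumed stop token.
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=32):
    model.eval()
    input_ids = tokenizer(prompt, return_tensors='pt', add_special_tokens=False)['input_ids']
    for _ in range(max_new_tokens):
        logits = model(input_ids)                                # (1, seq, vocab)
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy: most likely next token
        input_ids = torch.cat([input_ids, next_id], dim=1)
        if next_id.item() == tokenizer.sep_token_id:             # stop on [SEP]
            break
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Example usage (after training):
#     print(generate_text(model, tokenizer, "The quick brown"))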