import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

class SmallGemmaModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=256, num_heads=4, num_layers=4):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True so each layer accepts (batch, seq, embedding_dim)
        # tensors, matching the (batch, seq) ids produced by the tokenizer.
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        # Additive causal mask so each position attends only to itself and
        # earlier positions, as next-token prediction requires.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
            diagonal=1,
        )
        hidden = self.token_embeddings(input_ids)
        for layer in self.transformer_layers:
            hidden = layer(hidden, src_mask=causal_mask)
        return self.output_layer(hidden)
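
# Shape walkthrough (illustrative, using the defaults above): a batch of
# (8, 127) token ids embeds to (8, 127, 256); each encoder layer preserves
# that shape; the output layer projects to (8, 127, vocab_size) logits.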

class KnowledgeDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        with open(file_path, 'r') as f:
            # Skip blank lines, which would otherwise become all-padding examples.
            self.data = [line for line in f.read().splitlines() if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        encoding = self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        input_ids = encoding['input_ids'].squeeze(0)
        # Shift by one position: the model reads tokens [0..n-2] and is
        # trained to predict tokens [1..n-1].
        return input_ids[:-1], input_ids[1:]
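
# Example of the shifted pair (illustrative): with max_length=128 the
# encoding holds 128 ids, so __getitem__ yields an input of ids[:127] and a
# target of ids[1:], and the label at position i is the token that follows
# the first i+1 input tokens.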

def train_model(model, dataset, epochs=5, batch_size=8, learning_rate=1e-4):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Ignore padding positions so the model is not trained to predict [PAD].
    loss_fn = nn.CrossEntropyLoss(ignore_index=dataset.tokenizer.pad_token_id)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()
            outputs = model(input_ids)
            # Flatten (batch, seq, vocab) logits and (batch, seq) targets
            # for cross-entropy.
            loss = loss_fn(outputs.view(-1, outputs.size(-1)), target_ids.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
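
# A minimal greedy-decoding sketch, not part of the training script above:
# it shows how the trained model could produce text; `generate`, `prompt`,
# and `max_new_tokens` are illustrative names introduced here.
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=20):
    model.eval()
    input_ids = tokenizer(prompt, return_tensors='pt')['input_ids']
    for _ in range(max_new_tokens):
        logits = model(input_ids)
        # Greedily append the highest-probability token at the last position.
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)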

if __name__ == "__main__":
    # A quarter of the 262,208-entry Gemma vocabulary; any value at least as
    # large as the ~30K-id vocabulary of bert-base-uncased works here.
    vocab_size = 262208 // 4
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SmallGemmaModel(vocab_size=vocab_size)
    dataset = KnowledgeDataset('default.txt', tokenizer)
    train_model(model, dataset)