# src/train.py
import json
import os
import subprocess
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from model import TransformerModel
from utils import load_vocab
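
# NOTE: model.py and utils.py are assumed (they are not shown here) to provide:
#   - TransformerModel: a language model whose forward pass takes
#     (token_ids, src_mask) and returns logits of shape
#     (batch, seq_length, vocab_size), plus a
#     generate_square_subsequent_mask(sz) method for causal masking.
#   - load_vocab(path): loads a token -> id dict from JSON, including the
#     special '<PAD>' and '<UNK>' entries used below.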


class TextDataset(Dataset):
    """Fixed-length next-token-prediction dataset over pre-tokenized text."""

    def __init__(self, data_path, vocab, seq_length=50):
        # data_path points at a JSON list of token lists (one per sample).
        with open(data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        self.vocab = vocab
        self.seq_length = seq_length

    def __len__(self):
        return len(self.data)

    def numericalize(self, tokens):
        # Map tokens to ids, falling back to the <UNK> id for unknown tokens.
        return [self.vocab.get(token, self.vocab['<UNK>']) for token in tokens]

    def __getitem__(self, idx):
        tokens = self.data[idx]
        numericalized = self.numericalize(tokens)
        # Pad short samples with <PAD> (or truncate long ones) to exactly
        # seq_length + 1 ids, so input and target are both seq_length long.
        if len(numericalized) < self.seq_length + 1:
            numericalized += [self.vocab['<PAD>']] * (self.seq_length + 1 - len(numericalized))
        else:
            numericalized = numericalized[:self.seq_length + 1]
        # Standard language-modeling shift: the target is the input offset by one.
        input_seq = torch.tensor(numericalized[:-1], dtype=torch.long)
        target_seq = torch.tensor(numericalized[1:], dtype=torch.long)
        return input_seq, target_seq
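
# Worked example (hypothetical ids, seq_length=4): tokens ['the', 'cat', 'sat']
# might numericalize to [5, 9, 12]; padded to length 5 this becomes
# [5, 9, 12, 0, 0] (assuming vocab['<PAD>'] == 0), giving
# input [5, 9, 12, 0] and target [9, 12, 0, 0].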


def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = torch.stack(inputs)
    targets = torch.stack(targets)
    return inputs, targets


def get_dataloader(data_path, vocab, batch_size=64, seq_length=50):
    dataset = TextDataset(data_path, vocab, seq_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return dataloader
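
# Note: because TextDataset already pads/truncates every sample to the same
# length, PyTorch's default collation would stack these tensors identically;
# the explicit collate_fn above is equivalent but makes the batching explicit.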


def train_model(config):
    # Bootstrap the vocabulary if data_processing.py has not been run yet.
    if not os.path.exists(config['vocab_path']):
        print("vocab.json not found. Running data_processing.py...")
        # sys.executable keeps the subprocess on the same interpreter/venv.
        subprocess.run([sys.executable, 'src/data_processing.py'], check=True)

    # Load vocabulary
    vocab = load_vocab(config['vocab_path'])
    vocab_size = len(vocab)

    # Initialize model
    model = TransformerModel(
        vocab_size=vocab_size,
        embed_size=config['embed_size'],
        num_heads=config['num_heads'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        dropout=config['dropout']
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Loss and optimizer; ignore_index keeps <PAD> positions out of the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['<PAD>'])
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # DataLoader
    dataloader = get_dataloader(
        data_path=config['data_path'],
        vocab=vocab,
        batch_size=config['batch_size'],
        seq_length=config['seq_length']
    )

    # Training loop
    model.train()
    for epoch in range(1, config['epochs'] + 1):
        epoch_loss = 0.0
        progress = tqdm(dataloader, desc=f"Epoch {epoch}/{config['epochs']}")
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            # Causal mask so each position only attends to earlier positions.
            src_mask = model.generate_square_subsequent_mask(inputs.size(1)).to(device)
            outputs = model(inputs, src_mask)
            # Flatten (batch, seq, vocab) logits and (batch, seq) targets into
            # the 2-D/1-D shapes that CrossEntropyLoss expects.
            loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            progress.set_postfix(loss=loss.item())
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch} completed. Average Loss: {avg_loss:.4f}")

        # Save a checkpoint after each epoch
        os.makedirs('models', exist_ok=True)
        torch.save(model.state_dict(), f"models/3ed0k4_model_epoch{epoch}.pth")
        print(f"Model saved at models/3ed0k4_model_epoch{epoch}.pth")


if __name__ == "__main__":
    config = {
        'vocab_path': 'vocab.json',
        'data_path': 'data/processed/tokenized_data.json',
        'embed_size': 256,
        'num_heads': 8,
        'hidden_dim': 512,
        'num_layers': 4,
        'dropout': 0.1,
        'learning_rate': 0.001,
        'batch_size': 64,
        'seq_length': 50,
        'epochs': 10
    }
    train_model(config)
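
# Usage (from the repository root, assuming model.py, utils.py, and the
# processed data exist at the paths given in config):
#
#     python src/train.py
#
# Checkpoints are written to models/3ed0k4_model_epoch{N}.pth, one per epoch.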