import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from utils.config import config
from utils.data_loader import get_data_loaders
from models.encoder import Encoder
from models.decoder import Decoder
from models.seq2seq import Seq2Seq


def init_weights(m):
    # Initialize all weights from N(0, 0.01) and all biases to zero
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)


def train():
    train_loader, val_loader, eng_vocab, hin_vocab = get_data_loaders()
    print(f"Final English vocab size: {len(eng_vocab)}")
    print(f"Final Hindi vocab size: {len(hin_vocab)}")

    # Model initialization
    enc = Encoder(
        len(eng_vocab),
        config.embedding_dim,
        config.hidden_size,
        config.num_layers,
        config.dropout
    ).to(config.device)
    dec = Decoder(
        len(hin_vocab),
        config.embedding_dim,
        config.hidden_size,
        config.num_layers,
        config.dropout
    ).to(config.device)
    model = Seq2Seq(enc, dec, config.device).to(config.device)
    model.apply(init_weights)

    # Optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding

    # Training loop
    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0
        for src, trg in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            src, trg = src.to(config.device), trg.to(config.device)

            optimizer.zero_grad()
            output = model(src, trg, config.teacher_forcing_ratio)

            # Drop the first (start-token) position and flatten for the loss
            output_dim = output.shape[-1]
            output = output[:, 1:].reshape(-1, output_dim)
            trg = trg[:, 1:].reshape(-1)

            loss = criterion(output, trg)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch: {epoch+1}, Loss: {avg_loss:.4f}")

        # Save a checkpoint after every epoch
        torch.save(model.state_dict(), f"seq2seq_epoch_{epoch+1}.pth")


if __name__ == "__main__":
    train()