|
|
""" |
|
|
Training script for NLP models. |
|
|
|
|
|
This module contains the main training loop and model training functions. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import DataLoader |
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger named after this module, per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def train_model(model, train_loader, val_loader, epochs=10, lr=0.001):
    """
    Train the NLP model and validate after every epoch.

    Args:
        model: The neural network model (must output class logits).
        train_loader: Training data loader yielding (data, target) batches.
        val_loader: Validation data loader yielding (data, target) batches.
        epochs: Number of training epochs.
        lr: Learning rate for the Adam optimizer.

    Returns:
        The trained model (the same object, trained in place).
    """
    # getLogger returns the module logger configured at import time;
    # binding it locally keeps this function self-contained.
    log = logging.getLogger(__name__)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            if batch_idx % 100 == 0:
                # Lazy %-style args avoid formatting the message when the
                # record is filtered out by the log level.
                log.info('Epoch: %d, Batch: %d, Loss: %.4f',
                         epoch, batch_idx, loss.item())

        # BUG FIX: train_loss was accumulated but never reported; log the
        # per-batch average. max(..., 1) guards an empty loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)

        model.eval()
        val_loss = 0.0
        correct = 0

        # No gradients needed during validation — saves memory and time.
        with torch.no_grad():
            for data, target in val_loader:
                output = model(data)
                val_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        val_accuracy = correct / len(val_loader.dataset)
        log.info('Epoch %d: Train Loss: %.4f, Val Loss: %.4f, Val Accuracy: %.4f',
                 epoch, avg_train_loss, val_loss, val_accuracy)

    return model
|
|
|
|
|
|
|
|
def main():
    """Parse command-line arguments and run the training entry point."""
    parser = argparse.ArgumentParser(description='Train NLP Model')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')

    args = parser.parse_args()

    # Same module logger configured at import time; bound locally so the
    # function stands on its own.
    log = logging.getLogger(__name__)

    log.info("Starting training...")
    # BUG FIX: the parsed arguments were previously never used; at minimum
    # surface them so each run is reproducible from its logs.
    log.info('Config: epochs=%d, lr=%g, batch_size=%d',
             args.epochs, args.lr, args.batch_size)
    # TODO: build the model and data loaders here and call
    # train_model(model, train_loader, val_loader, args.epochs, args.lr).
    log.info("Training completed!")
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |