|
|
"""PyTorch Lightning module for training.""" |
|
|
|
|
|
from typing import Dict, Any, Optional |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import pytorch_lightning as pl |
|
|
from torch.optim import Adam |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class NewsClassificationModule(pl.LightningModule):
    """
    PyTorch Lightning module for news classification training.

    Handles both title-only and title+snippet models. Whether the wrapped
    model expects a snippet input is detected at construction time, either
    from a ``use_snippet`` attribute on the model or by inspecting the
    signature of ``model.forward``.
    """

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 1e-3,
        criterion: Optional[nn.Module] = None,
    ):
        """
        Initialize training module.

        Args:
            model: The neural network model to train
            learning_rate: Learning rate for optimizer
            criterion: Loss function. If None, uses CrossEntropyLoss

        Example:
            >>> model = SimpleClassifier(vocab_size=10000, embedding_dim=300, output_dim=1000)
            >>> lightning_module = NewsClassificationModule(model, learning_rate=1e-3)
        """
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        # Explicit None check: `criterion or default` would consult the
        # module's truth value, which is not a reliable "was it provided"
        # test for nn.Module instances.
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()

        # Determine whether the wrapped model consumes a snippet input.
        if hasattr(model, 'use_snippet'):
            self.use_snippet = model.use_snippet
        else:
            # Fall back to signature inspection for models that do not
            # advertise the flag explicitly.
            sig = inspect.signature(model.forward)
            self.use_snippet = 'snippet' in sig.parameters

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info(
            "Initialized NewsClassificationModule: lr=%s, use_snippet=%s",
            learning_rate,
            self.use_snippet,
        )

    def forward(
        self,
        title: torch.Tensor,
        snippet: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            title: Title token indices
            snippet: Optional snippet token indices

        Returns:
            Model logits
        """
        if self.use_snippet and snippet is not None:
            return self.model(title, snippet)
        return self.model(title)

    def configure_optimizers(self) -> Dict[str, Any]:
        """
        Configure optimizer.

        Returns:
            Dictionary with optimizer configuration
        """
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        return {"optimizer": optimizer}

    def _step(self, batch: tuple) -> torch.Tensor:
        """
        Unpack a batch, run the forward pass, and compute the loss.

        Shared by ``training_step`` and ``validation_step`` so the
        unpacking logic lives in exactly one place.

        Args:
            batch: Either ``(title, snippet, target)`` when the model uses
                snippets, or ``(title, target)`` otherwise.

        Returns:
            Loss value
        """
        if self.use_snippet:
            title, snippet, target = batch
            logits = self.forward(title, snippet)
        else:
            title, target = batch
            logits = self.forward(title)
        return self.criterion(logits, target)

    def training_step(
        self,
        train_batch: tuple,
        batch_idx: int
    ) -> torch.Tensor:
        """
        Training step.

        Args:
            train_batch: Batch of training data
            batch_idx: Batch index

        Returns:
            Loss value
        """
        loss = self._step(train_batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(
        self,
        val_batch: tuple,
        batch_idx: int
    ) -> torch.Tensor:
        """
        Validation step.

        Args:
            val_batch: Batch of validation data
            batch_idx: Batch index

        Returns:
            Loss value
        """
        loss = self._step(val_batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
|
|
|
|
|
|