File size: 3,985 Bytes
"""PyTorch Lightning module for training."""
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import Adam
import logging
logger = logging.getLogger(__name__)
class NewsClassificationModule(pl.LightningModule):
    """
    PyTorch Lightning module for news classification training.

    Wraps a classifier ``nn.Module`` and handles both title-only and
    title+snippet models. Whether snippets are used is decided once, at
    construction time, and stored in ``self.use_snippet``.
    """

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 1e-3,
        criterion: Optional[nn.Module] = None,
    ):
        """
        Initialize training module.

        Args:
            model: The neural network model to train
            learning_rate: Learning rate for optimizer
            criterion: Loss function. If None, uses CrossEntropyLoss

        Example:
            >>> model = SimpleClassifier(vocab_size=10000, embedding_dim=300, output_dim=1000)
            >>> lightning_module = NewsClassificationModule(model, learning_rate=1e-3)
        """
        # Kept local so this block does not depend on the file-level imports.
        import inspect

        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        # Explicit ``is None`` check: container modules that define ``__len__``
        # (e.g. an empty nn.Sequential) are falsy, so ``criterion or ...``
        # could silently discard a caller-supplied loss.
        self.criterion = nn.CrossEntropyLoss() if criterion is None else criterion

        # Detect if model uses snippets: an explicit ``use_snippet`` attribute
        # on the model wins; otherwise look for a ``snippet`` parameter in the
        # signature of ``model.forward``.
        if hasattr(model, 'use_snippet'):
            self.use_snippet = model.use_snippet
        else:
            sig = inspect.signature(model.forward)
            self.use_snippet = 'snippet' in sig.parameters

        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info(
            "Initialized NewsClassificationModule: lr=%s, use_snippet=%s",
            learning_rate,
            self.use_snippet,
        )

    def forward(
        self,
        title: torch.Tensor,
        snippet: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass.

        The snippet is passed to the wrapped model only when the module was
        configured with ``use_snippet`` and a snippet was actually supplied;
        otherwise the model is called with the title alone.

        Args:
            title: Title token indices
            snippet: Optional snippet token indices

        Returns:
            Model logits
        """
        if self.use_snippet and snippet is not None:
            return self.model(title, snippet)
        return self.model(title)

    def configure_optimizers(self) -> Dict[str, Any]:
        """
        Configure optimizer.

        Returns:
            Dictionary with optimizer configuration (an Adam optimizer over
            all parameters of this module, under the ``"optimizer"`` key)
        """
        optimizer = Adam(self.parameters(), lr=self.learning_rate)
        return {"optimizer": optimizer}

    def _shared_step(self, batch: tuple) -> torch.Tensor:
        """
        Unpack a batch, run the model and compute the loss.

        Args:
            batch: ``(title, snippet, target)`` when ``use_snippet`` is set,
                otherwise ``(title, target)``

        Returns:
            Loss value produced by ``self.criterion``
        """
        if self.use_snippet:
            title, snippet, target = batch
            # Call the module itself rather than ``.forward`` so that any
            # registered forward hooks are run.
            logits = self(title, snippet)
        else:
            title, target = batch
            logits = self(title)
        return self.criterion(logits, target)

    def training_step(
        self,
        train_batch: tuple,
        batch_idx: int
    ) -> torch.Tensor:
        """
        Training step.

        Args:
            train_batch: Batch of training data
            batch_idx: Batch index

        Returns:
            Loss value
        """
        loss = self._shared_step(train_batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(
        self,
        val_batch: tuple,
        batch_idx: int
    ) -> torch.Tensor:
        """
        Validation step.

        Args:
            val_batch: Batch of validation data
            batch_idx: Batch index

        Returns:
            Loss value
        """
        loss = self._shared_step(val_batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss
|