# multilabel-news-classifier / models / lightning_module_tracking.py
# Author: Solareva Taisia
# chore(release): initial public snapshot (commit 198ccb0)
"""PyTorch Lightning module with enhanced experiment tracking."""
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger, MLFlowLogger
from utils.experiment_tracking import WandBTracker, MLflowTracker, ExperimentTracker
import logging
logger = logging.getLogger(__name__)
class WandBCallback(Callback):
    """Enhanced WandB callback for PyTorch Lightning.

    Mirrors ``trainer.callback_metrics`` into the active WandB run under
    ``train/`` and ``val/`` prefixes, and optionally uploads the best
    checkpoint as a WandB artifact when training finishes.
    """

    def __init__(self, log_model: bool = True, log_artifacts: bool = True):
        """
        Initialize WandB callback.

        Args:
            log_model: Whether to log model checkpoints
            log_artifacts: Whether to log artifacts
        """
        super().__init__()
        self.log_model = log_model
        self.log_artifacts = log_artifacts

    @staticmethod
    def _prefixed_scalars(metrics: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """Return *metrics* with *prefix* prepended and tensors unwrapped.

        ``trainer.callback_metrics`` values are 0-dim tensors; ``.item()``
        keeps the WandB payload plain-Python/JSON-serializable.
        NOTE(review): callback_metrics contains ALL logged metrics, so both
        epoch hooks relabel everything under their own prefix (val metrics
        can appear as ``train/...`` and vice versa) — consider filtering by
        key if the logged names allow it.
        """
        out: Dict[str, Any] = {}
        for name, value in metrics.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            out[f"{prefix}{name}"] = value
        return out

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of training epoch."""
        # Build the payload only when a WandB run is actually attached.
        if hasattr(trainer, 'logger') and isinstance(trainer.logger, WandbLogger):
            metrics = self._prefixed_scalars(trainer.callback_metrics, "train/")
            trainer.logger.experiment.log(metrics)

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of validation epoch."""
        if hasattr(trainer, 'logger') and isinstance(trainer.logger, WandbLogger):
            metrics = self._prefixed_scalars(trainer.callback_metrics, "val/")
            trainer.logger.experiment.log(metrics)

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log artifacts at end of training."""
        if self.log_artifacts and hasattr(trainer, 'logger'):
            if isinstance(trainer.logger, WandbLogger):
                # Upload the best checkpoint produced by ModelCheckpoint, if any.
                if trainer.checkpoint_callback and trainer.checkpoint_callback.best_model_path:
                    trainer.logger.experiment.log_artifact(
                        trainer.checkpoint_callback.best_model_path,
                        name="best_model"
                    )
class MLflowCallback(Callback):
    """MLflow callback for PyTorch Lightning.

    Logs per-epoch metrics through the Lightning ``MLFlowLogger`` and, at
    the end of training, uploads the best checkpoint file as a run artifact.
    """

    def __init__(self, log_model: bool = True):
        """
        Initialize MLflow callback.

        Args:
            log_model: Whether to log model
        """
        super().__init__()
        self.log_model = log_model

    @staticmethod
    def _log_prefixed(trainer: pl.Trainer, prefix: str) -> None:
        """Log ``trainer.callback_metrics`` with *prefix* via the MLflow logger.

        Uses ``MLFlowLogger.log_metrics`` (which tracks the run_id) rather
        than ``logger.experiment.log_metrics``: ``.experiment`` is a raw
        ``MlflowClient`` whose client API requires an explicit run_id and
        does not accept a metrics dict in this form.
        """
        if hasattr(trainer, 'logger') and isinstance(trainer.logger, MLFlowLogger):
            metrics = {}
            for name, value in trainer.callback_metrics.items():
                if isinstance(value, torch.Tensor):
                    # MLflow requires plain floats, not 0-dim tensors.
                    value = value.item()
                metrics[f"{prefix}{name}"] = value
            trainer.logger.log_metrics(metrics, step=trainer.current_epoch)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of training epoch."""
        self._log_prefixed(trainer, "train_")

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of validation epoch."""
        self._log_prefixed(trainer, "val_")

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log model at end of training."""
        if self.log_model and hasattr(trainer, 'logger'):
            if isinstance(trainer.logger, MLFlowLogger):
                # ``MlflowClient`` has no ``log_model`` method; upload the
                # best checkpoint file as a run artifact instead.
                ckpt = getattr(trainer, 'checkpoint_callback', None)
                if ckpt and ckpt.best_model_path:
                    trainer.logger.experiment.log_artifact(
                        trainer.logger.run_id,
                        ckpt.best_model_path,
                        artifact_path="model"
                    )
def create_tracking_loggers(
    use_wandb: bool = True,
    use_mlflow: bool = True,
    project_name: str = "russian-news-classification",
    experiment_name: Optional[str] = None,
    **kwargs
) -> tuple[list, list]:
    """
    Create tracking loggers and callbacks.

    Args:
        use_wandb: Enable WandB
        use_mlflow: Enable MLflow
        project_name: Project name
        experiment_name: Experiment name
        **kwargs: Optional per-backend option dicts under the keys
            ``'wandb'`` and ``'mlflow'``; each dict is forwarded verbatim
            to the corresponding Lightning logger constructor.

    Returns:
        Tuple of (loggers, callbacks)
    """
    loggers = []
    callbacks = []
    if use_wandb:
        try:
            wandb_logger = WandbLogger(
                project=project_name,
                name=experiment_name,
                **kwargs.get('wandb', {})
            )
            loggers.append(wandb_logger)
            callbacks.append(WandBCallback())
            logger.info("WandB logger created")
        except Exception as e:
            # Best-effort: tracking setup must never abort training.
            logger.warning(f"Failed to create WandB logger: {e}")
    if use_mlflow:
        try:
            # Forward the full mlflow option dict (tracking_uri, tags, run
            # names, ...) — consistent with the wandb branch above, which
            # previously forwarded everything while this one silently
            # dropped all keys except 'tracking_uri'.
            mlflow_logger = MLFlowLogger(
                experiment_name=experiment_name or project_name,
                **kwargs.get('mlflow', {})
            )
            loggers.append(mlflow_logger)
            callbacks.append(MLflowCallback())
            logger.info("MLflow logger created")
        except Exception as e:
            logger.warning(f"Failed to create MLflow logger: {e}")
    return loggers, callbacks