|
|
"""PyTorch Lightning module with enhanced experiment tracking.""" |
|
|
|
|
|
from typing import Dict, Any, Optional |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import pytorch_lightning as pl |
|
|
from pytorch_lightning.callbacks import Callback |
|
|
from pytorch_lightning.loggers import WandbLogger, MLFlowLogger |
|
|
|
|
|
from utils.experiment_tracking import WandBTracker, MLflowTracker, ExperimentTracker |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class WandBCallback(Callback):
    """Enhanced WandB callback for PyTorch Lightning.

    Mirrors ``trainer.callback_metrics`` into the active WandB run under
    ``train/`` and ``val/`` prefixes at each epoch end, and optionally
    uploads the best checkpoint as a WandB artifact when training ends.
    """

    def __init__(self, log_model: bool = True, log_artifacts: bool = True):
        """
        Initialize WandB callback.

        Args:
            log_model: Whether to log model checkpoints
            log_artifacts: Whether to log artifacts
        """
        super().__init__()
        self.log_model = log_model
        self.log_artifacts = log_artifacts

    @staticmethod
    def _collect_metrics(trainer: pl.Trainer, prefix: str) -> Dict[str, Any]:
        """Return callback metrics keyed as ``{prefix}/{name}``.

        Tensor values are unwrapped with ``.item()`` so the payload is plain
        numbers for wandb serialization. NOTE(review): ``callback_metrics``
        holds *all* metrics logged so far, so train metrics can also appear
        under the ``val/`` prefix (and vice versa) — behavior preserved from
        the original implementation; confirm whether filtering is desired.
        """
        return {
            f"{prefix}/{k}": (v.item() if hasattr(v, "item") else v)
            for k, v in trainer.callback_metrics.items()
        }

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of training epoch."""
        if hasattr(trainer, 'logger') and isinstance(trainer.logger, WandbLogger):
            trainer.logger.experiment.log(self._collect_metrics(trainer, "train"))

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of validation epoch."""
        if hasattr(trainer, 'logger') and isinstance(trainer.logger, WandbLogger):
            trainer.logger.experiment.log(self._collect_metrics(trainer, "val"))

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Upload the best checkpoint as a WandB artifact at end of training."""
        if not (self.log_artifacts and hasattr(trainer, 'logger')):
            return
        if not isinstance(trainer.logger, WandbLogger):
            return
        ckpt = trainer.checkpoint_callback
        if ckpt and ckpt.best_model_path:
            # wandb requires an artifact ``type`` when logging by path;
            # the original call omitted it and would raise ValueError.
            trainer.logger.experiment.log_artifact(
                ckpt.best_model_path,
                name="best_model",
                type="model",
            )
|
|
|
|
|
|
|
|
class MLflowCallback(Callback):
    """MLflow callback for PyTorch Lightning.

    Logs ``trainer.callback_metrics`` to the MLflow run at each epoch end
    (prefixed ``train_`` / ``val_``) and uploads the best checkpoint as a
    run artifact when training ends.
    """

    def __init__(self, log_model: bool = True):
        """
        Initialize MLflow callback.

        Args:
            log_model: Whether to log model
        """
        super().__init__()
        self.log_model = log_model

    @staticmethod
    def _log_prefixed_metrics(trainer: pl.Trainer, prefix: str) -> None:
        """Send callback metrics to MLflow under ``{prefix}_{name}`` keys.

        Uses the ``MlflowClient`` API, which requires an explicit ``run_id``.
        The original code called ``client.log_metrics(metrics, step=...)``,
        a method that does not exist on ``MlflowClient`` (it belongs to the
        fluent ``mlflow`` module) and raised AttributeError.
        """
        if not (hasattr(trainer, 'logger') and isinstance(trainer.logger, MLFlowLogger)):
            return
        client = trainer.logger.experiment  # MlflowClient instance
        run_id = trainer.logger.run_id
        step = trainer.current_epoch
        for key, value in trainer.callback_metrics.items():
            try:
                # MLflow accepts only numeric metric values; tensors are
                # unwrapped via float(), non-numeric entries are skipped.
                numeric = float(value)
            except (TypeError, ValueError):
                continue
            client.log_metric(run_id, f"{prefix}_{key}", numeric, step=step)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of training epoch."""
        self._log_prefixed_metrics(trainer, "train")

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log metrics at end of validation epoch."""
        self._log_prefixed_metrics(trainer, "val")

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Log the best checkpoint as a run artifact at end of training."""
        if not (self.log_model and hasattr(trainer, 'logger')):
            return
        if not isinstance(trainer.logger, MLFlowLogger):
            return
        # ``MlflowClient`` has no ``log_model`` method (the original call
        # raised AttributeError); upload the best checkpoint file as a run
        # artifact via the documented client API instead.
        ckpt = getattr(trainer, 'checkpoint_callback', None)
        if ckpt and ckpt.best_model_path:
            trainer.logger.experiment.log_artifact(
                trainer.logger.run_id,
                ckpt.best_model_path,
                artifact_path="model",
            )
|
|
|
|
|
|
|
|
def create_tracking_loggers(
    use_wandb: bool = True,
    use_mlflow: bool = True,
    project_name: str = "russian-news-classification",
    experiment_name: Optional[str] = None,
    **kwargs
) -> tuple[list, list]:
    """
    Create tracking loggers and callbacks.

    Each backend is constructed best-effort: a failure to create one logger
    is logged as a warning and does not prevent the other from being set up.

    Args:
        use_wandb: Enable WandB
        use_mlflow: Enable MLflow
        project_name: Project name
        experiment_name: Experiment name
        **kwargs: Backend-specific options under the ``'wandb'`` and
            ``'mlflow'`` keys; each value is a dict forwarded to the
            corresponding Lightning logger constructor.

    Returns:
        Tuple of (loggers, callbacks)
    """
    loggers = []
    callbacks = []

    if use_wandb:
        try:
            wandb_logger = WandbLogger(
                project=project_name,
                name=experiment_name,
                **kwargs.get('wandb', {})
            )
            loggers.append(wandb_logger)
            callbacks.append(WandBCallback())
            logger.info("WandB logger created")
        except Exception as e:
            logger.warning(f"Failed to create WandB logger: {e}")

    if use_mlflow:
        try:
            # Forward the full mlflow options dict for parity with the wandb
            # branch above; the original dropped everything except
            # ``tracking_uri``. An absent ``tracking_uri`` still falls back
            # to the MLFlowLogger default, so existing callers are unaffected.
            mlflow_logger = MLFlowLogger(
                experiment_name=experiment_name or project_name,
                **kwargs.get('mlflow', {})
            )
            loggers.append(mlflow_logger)
            callbacks.append(MLflowCallback())
            logger.info("MLflow logger created")
        except Exception as e:
            logger.warning(f"Failed to create MLflow logger: {e}")

    return loggers, callbacks
|
|
|
|
|
|