# mnist-digit-classifier/scripts/mlflow_setup.py
# Last change: commit e77a25a (faizan) — "fix: resolve all 468 ruff linting
# errors (code quality enforcement complete)"
"""
MLflow Setup and Configuration
Utilities for MLflow experiment tracking with MLOps best practices:
- Automatic experiment naming and organization
- Parameter and metric logging
- Model registry integration
- Artifact tracking
"""
import mlflow
from pathlib import Path
from typing import Optional, Dict, Any
# MLflow configuration
# Local file-based tracking store; callers can override it per-call via
# setup_mlflow(tracking_uri=...).
MLFLOW_TRACKING_URI = "file:./mlruns"
DEFAULT_EXPERIMENT_NAME = "mnist-digit-classification"
def setup_mlflow(
    experiment_name: str = DEFAULT_EXPERIMENT_NAME,
    tracking_uri: Optional[str] = None
) -> str:
    """
    Configure MLflow tracking and make sure the experiment exists.

    Args:
        experiment_name: Name of the experiment
        tracking_uri: MLflow tracking URI (default: local ./mlruns)

    Returns:
        experiment_id: MLflow experiment ID ("0" if creation failed)
    """
    # Resolve the tracking store, defaulting to the local file backend.
    uri = MLFLOW_TRACKING_URI if tracking_uri is None else tracking_uri
    mlflow.set_tracking_uri(uri)

    # Reuse the experiment when it already exists; otherwise create it
    # with project-identifying tags. Any failure falls back to the
    # default experiment rather than aborting the run.
    try:
        existing = mlflow.get_experiment_by_name(experiment_name)
        if existing is not None:
            experiment_id = existing.experiment_id
        else:
            experiment_id = mlflow.create_experiment(
                experiment_name,
                tags={
                    "project": "mnist-classification",
                    "framework": "pytorch",
                    "model_type": "cnn"
                }
            )
    except Exception as e:
        print(f"Warning: Could not create experiment: {e}")
        experiment_id = "0"  # Default experiment

    mlflow.set_experiment(experiment_name)

    print(f"MLflow tracking URI: {uri}")
    print(f"Experiment: {experiment_name} (ID: {experiment_id})")
    return experiment_id
def log_model_params(model: Any, prefix: str = "model") -> Dict[str, Any]:
    """
    Log model parameters to MLflow.

    Args:
        model: PyTorch model
        prefix: Prefix for parameter names

    Returns:
        Dictionary of logged parameters
    """
    from scripts.models import count_parameters

    # Count only parameters that participate in backprop.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    params: Dict[str, Any] = {
        f"{prefix}_name": type(model).__name__,
        f"{prefix}_total_params": count_parameters(model),
        f"{prefix}_trainable_params": trainable,
    }
    mlflow.log_params(params)
    return params
def log_training_config(config: Dict[str, Any]) -> None:
    """
    Log training configuration to MLflow.

    Nested dictionaries are flattened into underscore-joined keys
    (e.g. {"optim": {"lr": 0.1}} -> {"optim_lr": 0.1}). Unlike the
    previous one-level flattening, this handles arbitrary nesting depth;
    one-level and flat configs produce identical keys as before.

    Args:
        config: Dictionary of training hyperparameters
    """
    def _flatten(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
        # Recursively flatten nested dicts into "<prefix><key>" entries.
        flat: Dict[str, Any] = {}
        for key, value in d.items():
            name = f"{prefix}{key}"
            if isinstance(value, dict):
                flat.update(_flatten(value, f"{name}_"))
            else:
                flat[name] = value
        return flat

    mlflow.log_params(_flatten(config))
def log_data_info(
    train_size: int,
    val_size: int,
    test_size: int,
    num_classes: int = 10,
    augmentation: bool = False
) -> None:
    """
    Log dataset information to MLflow.

    Args:
        train_size: Number of training samples
        val_size: Number of validation samples
        test_size: Number of test samples
        num_classes: Number of classes
        augmentation: Whether data augmentation is used
    """
    # All keys share the "data_" namespace so they group together in the UI.
    data_params = dict(
        data_train_size=train_size,
        data_val_size=val_size,
        data_test_size=test_size,
        data_num_classes=num_classes,
        data_augmentation=augmentation,
    )
    mlflow.log_params(data_params)
def log_system_info() -> Dict[str, Any]:
    """
    Log system information (platform, Python/PyTorch versions, CUDA/GPU
    details) to the active MLflow run.

    Returns:
        Dictionary of system information
    """
    import platform

    import torch

    # Query CUDA availability once and branch on the cached result.
    cuda_ok = torch.cuda.is_available()

    info: Dict[str, Any] = {
        "system_platform": platform.system(),
        "system_python_version": platform.python_version(),
        "system_pytorch_version": torch.__version__,
        "system_cuda_available": cuda_ok,
        "system_cuda_version": torch.version.cuda if cuda_ok else "N/A",
        "system_device": "cuda" if cuda_ok else "cpu",
    }
    if cuda_ok:
        info["system_gpu_name"] = torch.cuda.get_device_name(0)
        info["system_gpu_count"] = torch.cuda.device_count()

    mlflow.log_params(info)
    return info
def log_metrics_epoch(metrics: Dict[str, float], step: int) -> None:
    """
    Log a batch of metrics against a single epoch step.

    Thin wrapper over ``mlflow.log_metrics`` that fixes the step semantics
    to "epoch number" for consistency across training scripts.

    Args:
        metrics: Dictionary of metric names and values
        step: Epoch number
    """
    mlflow.log_metrics(metrics, step=step)
def log_artifact_path(path: str, artifact_path: Optional[str] = None) -> None:
    """
    Log a file or directory as an MLflow artifact.

    Missing paths are reported with a warning instead of raising, so a
    lost artifact never aborts a training run.

    Args:
        path: Path to file or directory
        artifact_path: Optional artifact path in MLflow
    """
    # Guard clause: skip logging when the path does not exist.
    if not Path(path).exists():
        print(f"Warning: Artifact not found: {path}")
        return
    mlflow.log_artifact(path, artifact_path=artifact_path)
def log_model_to_registry(
    model: Any,
    model_name: str,
    artifact_path: str = "model",
    registered_model_name: Optional[str] = None
) -> None:
    """
    Log a PyTorch model to MLflow, optionally registering it.

    Args:
        model: PyTorch model
        model_name: Name for the model artifact
            (NOTE(review): currently accepted but not forwarded to MLflow;
            kept for interface compatibility)
        artifact_path: Artifact path in MLflow
        registered_model_name: Name for model registry (optional)
    """
    # Registration happens automatically when registered_model_name is set.
    mlflow.pytorch.log_model(
        pytorch_model=model,
        artifact_path=artifact_path,
        registered_model_name=registered_model_name
    )
def get_or_create_run(
    run_name: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None
) -> mlflow.ActiveRun:
    """
    Start an MLflow run with the given name and tags.

    Despite the name, this always delegates to ``mlflow.start_run`` — if a
    run is already active, MLflow's own nesting rules apply.

    Args:
        run_name: Name for the run
        tags: Tags for the run

    Returns:
        MLflow active run context
    """
    return mlflow.start_run(run_name=run_name, tags=tags)
def end_run() -> None:
    """Terminate the currently active MLflow run, if any."""
    mlflow.end_run()
def test_mlflow_setup():
    """Smoke-test MLflow setup: create an experiment, log params/metrics,
    and record system info inside a throwaway run."""
    print("Testing MLflow Setup")
    print("=" * 50)

    setup_mlflow("test-experiment")

    with mlflow.start_run(run_name="test-run"):
        # A representative set of hyperparameters.
        mlflow.log_params({
            "learning_rate": 0.001,
            "batch_size": 64,
            "epochs": 10
        })

        # Synthetic per-epoch metrics: loss decreasing, accuracy increasing.
        for epoch in range(3):
            epoch_metrics = {
                "train_loss": 0.5 - epoch * 0.1,
                "val_loss": 0.6 - epoch * 0.1,
                "train_accuracy": 0.8 + epoch * 0.05,
                "val_accuracy": 0.75 + epoch * 0.05
            }
            mlflow.log_metrics(epoch_metrics, step=epoch)

        system_info = log_system_info()
        print("\nSystem Info:")
        for key, value in system_info.items():
            print(f" {key}: {value}")

    print("\n✓ MLflow test complete!")
    print(f"View results at: mlflow ui --backend-store-uri {MLFLOW_TRACKING_URI}")
# Allow running this module directly as a smoke test.
if __name__ == "__main__":
    test_mlflow_setup()