Spaces:
Sleeping
Sleeping
| import wandb | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import Dict, Any, Optional | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from sklearn.metrics import confusion_matrix | |
| from utils.data_loader import get_cifar10_info | |
class WandbLogger:
    """Minimal W&B integration: run setup, metric/image/artifact logging, cleanup."""

    def __init__(self, project: str = "cifar10-benchmark", entity: Optional[str] = None):
        """Store run settings; nothing touches W&B until init_experiment()."""
        self.project = project
        self.entity = entity
        self.run = None  # set by init_experiment(), cleared by finish()

    def init_experiment(self, config: Dict[str, Any], model: nn.Module, model_name: str):
        """Start a W&B run, enriching `config` with auto-detected model stats.

        Args:
            config: user hyperparameters; copied, not mutated.
            model: the network whose parameters/gradients will be watched.
            model_name: human-readable tag used in the run name.

        Returns:
            The created wandb run object.
        """
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # Use each parameter's actual element size instead of assuming
        # 4-byte float32, so fp16/fp64 models report the correct size.
        size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

        enhanced_config = {
            **config,
            'model_name': model_name,
            'total_params': total_params,
            'trainable_params': trainable_params,
            'model_size_mb': size_bytes / (1024 ** 2),
            'architecture': model.__class__.__name__
        }

        self.run = wandb.init(
            project=self.project,
            entity=self.entity,
            config=enhanced_config,
            name=f"{model_name}-{wandb.util.generate_id()}"
        )

        # Track gradients/parameters and the computation graph.
        wandb.watch(model, log_freq=100, log_graph=True)
        return self.run

    def log_metrics(self, metrics: Dict[str, float], step: int):
        """Log a dict of scalar metrics at the given step."""
        wandb.log(metrics, step=step)

    def log_confusion_matrix(self, y_true: np.ndarray, y_pred: np.ndarray, epoch: int):
        """Render a labelled confusion matrix and log it as a W&B image."""
        cifar10_info = get_cifar10_info()
        cm = confusion_matrix(y_true, y_pred)

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=cifar10_info['class_names'],
                    yticklabels=cifar10_info['class_names'])
        plt.title(f'Confusion Matrix - Epoch {epoch}')
        plt.tight_layout()

        # Pass the explicit current figure rather than the pyplot module.
        wandb.log({
            "confusion_matrix": wandb.Image(plt.gcf()),
            "epoch": epoch
        })
        plt.close()

    def log_model_checkpoint(self, model: nn.Module, optimizer, epoch: int,
                             metrics: Dict[str, float], is_best: bool = False):
        """Save a checkpoint to disk and upload it as a versioned W&B artifact.

        Raises:
            RuntimeError: if called before init_experiment() (previously this
                failed with an opaque AttributeError on `self.run.id`).
        """
        if self.run is None:
            raise RuntimeError("init_experiment() must be called before logging checkpoints")

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            **metrics
        }
        filename = f"model_epoch_{epoch}.pth"
        torch.save(checkpoint, filename)

        artifact = wandb.Artifact(
            name=f"model-{self.run.id}",
            type="model",
            metadata={"epoch": epoch, "is_best": is_best, **metrics}
        )
        artifact.add_file(filename)
        wandb.log_artifact(artifact)

    def finish(self):
        """Finish the active run, if any, and clear the stale handle.

        Clearing `self.run` (bug fix) lets the logger be reused and keeps
        `if self.run` checks truthful after the run has ended.
        """
        if self.run:
            wandb.finish()
            self.run = None
def create_hyperparameter_sweep():
    """Return a Bayesian W&B sweep configuration for CIFAR-10 training.

    Bug fix: the continuous log-scaled parameters now use
    'log_uniform_values', whose min/max are given in real value space.
    W&B's plain 'log_uniform' interprets min/max as bounds in *log*
    space (it samples exp(U(min, max))), so passing raw values such as
    1e-5 would have produced a completely wrong search range.
    """
    return {
        'method': 'bayes',
        'metric': {'name': 'val_accuracy', 'goal': 'maximize'},
        'parameters': {
            'learning_rate': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-2},
            'batch_size': {'values': [32, 64, 128]},
            'weight_decay': {'distribution': 'log_uniform_values', 'min': 1e-6, 'max': 1e-3},
            'optimizer': {'values': ['adamw', 'sgd']},
            'scheduler': {'values': ['cosine', 'step']},
            'dropout_rate': {'distribution': 'uniform', 'min': 0.1, 'max': 0.5}
        }
    }
def run_hyperparameter_sweep(train_fn, sweep_config: Dict[str, Any], count: int = 20,
                             project: str = "cifar10-benchmark"):
    """Register `sweep_config` with W&B and run `train_fn` for `count` trials.

    Args:
        train_fn: zero-argument callable executed once per sweep trial.
        sweep_config: sweep definition (see create_hyperparameter_sweep()).
        count: maximum number of trials the agent will run.
        project: W&B project to attach the sweep to (previously hard-coded;
            now a defaulted parameter, backward compatible with all callers).
    """
    sweep_id = wandb.sweep(sweep_config, project=project)
    wandb.agent(sweep_id, train_fn, count=count)
| # Integration with existing training loop | |
def enhanced_train_step(model, train_loader, val_loader, optimizer, criterion,
                        scheduler, num_epochs, device, logger: WandbLogger):
    """Train `model` for `num_epochs` with per-epoch W&B logging.

    Logs scalar metrics every epoch, a confusion matrix every 10 epochs,
    and a model-checkpoint artifact whenever validation accuracy improves.

    Returns:
        Best validation accuracy observed, as a fraction in [0, 1].
    """
    model.to(device)
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # --- training phase ---
        model.train()
        train_loss, train_correct, train_count = 0.0, 0, 0
        for inputs, targets in train_loader:  # batch index was unused
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # Bug fix: count correct samples instead of averaging per-batch
            # means — the old mean-of-means over-weighted a smaller final
            # batch when len(dataset) % batch_size != 0.
            train_correct += (outputs.argmax(1) == targets).sum().item()
            train_count += targets.size(0)

        # --- validation phase ---
        model.eval()
        val_loss, val_correct, val_count = 0.0, 0, 0
        all_preds, all_targets = [], []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == targets).sum().item()
                val_count += targets.size(0)
                all_preds.extend(outputs.argmax(1).cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        # Losses remain per-batch means (criterion already averages within a
        # batch); accuracies are true per-sample means. max(.., 1) guards
        # against empty loaders.
        train_loss /= max(len(train_loader), 1)
        val_loss /= max(len(val_loader), 1)
        train_acc = train_correct / max(train_count, 1)
        val_acc = val_correct / max(val_count, 1)

        scheduler.step()

        logger.log_metrics({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_accuracy': train_acc * 100,
            'val_loss': val_loss,
            'val_accuracy': val_acc * 100,
            'learning_rate': optimizer.param_groups[0]['lr']
        }, step=epoch)

        # Confusion matrix every 10 epochs (epochs are 0-based).
        if (epoch + 1) % 10 == 0:
            logger.log_confusion_matrix(all_targets, all_preds, epoch)

        # Checkpoint only when validation accuracy improves.
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            logger.log_model_checkpoint(
                model, optimizer, epoch,
                {'val_accuracy': val_acc, 'val_loss': val_loss},
                is_best=True
            )

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train: {train_loss:.4f}/{train_acc:.3f} | "
              f"Val: {val_loss:.4f}/{val_acc:.3f}")

    return best_val_acc