"""Train the trash-classification models and track each run with MLflow."""

import os
import sys
from pathlib import Path

# Add project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))

import matplotlib  # noqa: E402

# Select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import mlflow  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import yaml  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402
from torchvision import transforms  # noqa: E402
from tqdm import tqdm  # noqa: E402

from src.dataset import TrashDataset  # noqa: E402
from src.model import DeepCNN, ResNet18Transfer, SimpleCNN  # noqa: E402


def load_config(config_path="config.yaml"):
    """Load the YAML configuration file at ``config_path`` and return its content."""
    # Explicit encoding so parsing does not depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def get_device(config_device):
    """Resolve the configured device name.

    ``"auto"`` picks ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise;
    any other value is returned unchanged.
    """
    if config_device == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return config_device


class Trainer:
    """
    Handles the training and validation lifecycle of a model with MLflow tracking.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        config,
        model_name,
        *,
        train_labels_path="data/processed/y_train.npy",
    ):
        """Set up loss, optimizer, scheduler and bookkeeping for one model.

        Args:
            model: Module to train; it is moved to the resolved device.
            train_loader: Iterable of ``(images, labels)`` training batches.
            val_loader: Iterable of ``(images, labels)`` validation batches.
            config: Mapping that provides ``"device"``, ``"epochs"`` and ``"lr"``;
                ``"patience"`` (default 5) and ``"classes"`` are optional.
            model_name: Used for the MLflow run name, plot titles and file names.
            train_labels_path: ``.npy`` file with the training labels used to
                derive the class weights.
        """
        self.config = config
        self.model_name = model_name
        self.device = get_device(config["device"])
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = config["epochs"]
        self.patience = config.get("patience", 5)

        # Handle class imbalance with inverse-frequency weights.
        y_train = np.load(train_labels_path)
        # minlength keeps one weight per configured class even when the
        # highest-indexed classes never occur in the training labels.
        num_classes = len(config.get("classes") or ())
        class_counts = np.bincount(y_train, minlength=num_classes)
        # Clamp to 1 so an absent class gets a finite weight instead of inf,
        # which could otherwise turn the loss into NaN.
        weights = 1.0 / np.maximum(class_counts, 1)
        weights = torch.tensor(weights, dtype=torch.float32, device=self.device)

        self.criterion = nn.CrossEntropyLoss(weight=weights)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config["lr"])
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs
        )

        self.history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
        self.checkpoint_path = f"models/{model_name.lower()}_best.pth"

    def train_epoch(self):
        """Run one optimization pass over ``train_loader``.

        Returns:
            Tuple ``(loss, accuracy)``: the loss averaged over batches and the
            fraction of correctly classified samples.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels in pbar:
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{correct/total:.4f}"})

        return running_loss / len(self.train_loader), correct / total

    def validate(self):
        """Evaluate the model on ``val_loader`` without computing gradients.

        Returns:
            Tuple ``(loss, accuracy)``: the loss averaged over batches and the
            fraction of correctly classified samples.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return running_loss / len(self.val_loader), correct / total

    def plot_history(self):
        """Save the loss/accuracy curves as a PNG and log it as an MLflow artifact."""
        save_path = f"models/plots/{self.model_name.lower()}_history.png"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        epochs_range = range(1, len(self.history["train_loss"]) + 1)

        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
        try:
            # Plot Loss
            ax_loss.plot(epochs_range, self.history["train_loss"], label="Train Loss")
            ax_loss.plot(epochs_range, self.history["val_loss"], label="Val Loss")
            ax_loss.set_title(f"{self.model_name} - Loss")
            ax_loss.set_xlabel("Epochs")
            ax_loss.set_ylabel("Loss")
            ax_loss.legend()

            # Plot Accuracy
            ax_acc.plot(epochs_range, self.history["train_acc"], label="Train Acc")
            ax_acc.plot(epochs_range, self.history["val_acc"], label="Val Acc")
            ax_acc.set_title(f"{self.model_name} - Accuracy")
            ax_acc.set_xlabel("Epochs")
            ax_acc.set_ylabel("Accuracy")
            ax_acc.legend()

            fig.tight_layout()
            fig.savefig(save_path)
        finally:
            # pyplot keeps every figure alive until it is closed; release it so
            # training several models in one process does not accumulate figures.
            plt.close(fig)

        print(f"--> Training history plot saved to {save_path}")
        mlflow.log_artifact(save_path)

    def fit(self):
        """Train with early stopping inside a single MLflow run.

        Each epoch trains, validates, steps the LR scheduler and logs metrics.
        The model state is checkpointed whenever validation accuracy improves;
        training stops after ``patience`` consecutive epochs without improvement.

        Returns:
            The history dict with per-epoch ``train_loss``, ``train_acc``,
            ``val_loss`` and ``val_acc`` lists.
        """
        mlflow.set_experiment("Trash Classifier")

        with mlflow.start_run(run_name=self.model_name):
            mlflow.log_params(self.config)
            mlflow.log_param("model_architecture", self.model_name)

            print(f"\nStarting training for {self.model_name} on {self.device}...")
            best_val_acc = 0.0
            epochs_no_improve = 0

            for epoch in range(self.epochs):
                # Read the LR before stepping the scheduler so the value that is
                # logged is the one this epoch actually trained with.
                current_lr = self.optimizer.param_groups[0]["lr"]

                train_loss, train_acc = self.train_epoch()
                val_loss, val_acc = self.validate()
                self.scheduler.step()

                self.history["train_loss"].append(train_loss)
                self.history["train_acc"].append(train_acc)
                self.history["val_loss"].append(val_loss)
                self.history["val_acc"].append(val_acc)

                mlflow.log_metric("train_loss", train_loss, step=epoch)
                mlflow.log_metric("train_acc", train_acc, step=epoch)
                mlflow.log_metric("val_loss", val_loss, step=epoch)
                mlflow.log_metric("val_acc", val_acc, step=epoch)
                mlflow.log_metric("lr", current_lr, step=epoch)

                print(
                    f"Epoch [{epoch + 1}/{self.epochs}] "
                    f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
                    f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
                    f"LR: {current_lr:.6f}"
                )

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    epochs_no_improve = 0
                    # Create whatever directory the checkpoint path points into.
                    os.makedirs(os.path.dirname(self.checkpoint_path) or ".", exist_ok=True)
                    torch.save(self.model.state_dict(), self.checkpoint_path)
                    print(f"--> Saved best model for {self.model_name} with Val Acc: {val_acc:.4f}")
                    mlflow.log_artifact(self.checkpoint_path)
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= self.patience:
                        print(f"Early stopping triggered after {epoch + 1} epochs.")
                        break

            # Still inside the run so the plot artifact is attached to it.
            self.plot_history()

        return self.history


def main():
    """Train every configured architecture on the processed data, then compare them."""
    config = load_config()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            normalize,
        ]
    )

    val_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    processed_dir = Path("data/processed")
    if not (processed_dir / "X_train.npy").exists():
        print("Error: Processed data not found. Please run src/dataset.py first.")
        return

    train_ds = TrashDataset(
        processed_dir / "X_train.npy", processed_dir / "y_train.npy", transform=train_transform
    )
    val_ds = TrashDataset(
        processed_dir / "X_val.npy", processed_dir / "y_val.npy", transform=val_transform
    )

    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)

    num_classes = len(config["classes"])
    models_to_train = {
        "SimpleCNN": SimpleCNN(num_classes=num_classes),
        "DeepCNN": DeepCNN(num_classes=num_classes),
        "ResNet18": ResNet18Transfer(num_classes=num_classes, pretrained=True),
    }

    for name, model in models_to_train.items():
        trainer = Trainer(model, train_loader, val_loader, config, name)
        trainer.fit()

    print("\nAll models trained. Starting comparison...")
    # Imported lazily, as in the original script, so it is only needed after training.
    from src.comparison import run_comparison

    run_comparison()


if __name__ == "__main__":
    main()