Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from pathlib import Path | |
| # Add project root to sys.path | |
| sys.path.append(str(Path(__file__).parent.parent)) | |
| import matplotlib # noqa: E402 | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt # noqa: E402 | |
| import mlflow # noqa: E402 | |
| import numpy as np # noqa: E402 | |
| import torch # noqa: E402 | |
| import torch.nn as nn # noqa: E402 | |
| import yaml # noqa: E402 | |
| from torch.utils.data import DataLoader # noqa: E402 | |
| from torchvision import transforms # noqa: E402 | |
| from tqdm import tqdm # noqa: E402 | |
| from src.dataset import TrashDataset # noqa: E402 | |
| from src.model import DeepCNN, ResNet18Transfer, SimpleCNN # noqa: E402 | |
def load_config(config_path="config.yaml"):
    """Load the YAML training configuration.

    Args:
        config_path: Path to the YAML file (str or path-like).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (the rest of this
        script indexes it as a dict).
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale
    # default; safe_load avoids constructing arbitrary Python objects.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def get_device(config_device):
    """Resolve a configured device name to a concrete one.

    ``"auto"`` becomes ``"cuda"`` when PyTorch reports CUDA as available and
    ``"cpu"`` otherwise; any other value is passed through unchanged.
    """
    if config_device != "auto":
        return config_device
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
class Trainer:
    """
    Handles the training and validation lifecycle of a model with MLflow tracking.

    ``fit()`` opens an MLflow run (experiment "Trash Classifier") named after
    ``model_name``. Each epoch it logs the config and per-epoch metrics, and it
    saves and logs a checkpoint whenever validation accuracy improves. It stops
    early after ``patience`` epochs without improvement. Finally it saves and
    logs a loss/accuracy plot.

    Args:
        model: The ``nn.Module`` to train; it is moved to the resolved device.
        train_loader: Iterable of ``(images, labels)`` batches for training.
        val_loader: Iterable of ``(images, labels)`` batches for validation.
        config: Config dict. Reads ``device``, ``epochs`` and ``lr``, plus
            optional ``patience`` (default 5) and ``classes``. ``classes`` sizes
            the class-weight vector so classes missing from the training labels
            still receive a weight.
        model_name: Display name; lowercased, it also names the artifact files.
        train_labels_path: ``.npy`` file with the integer training labels used
            to compute inverse-frequency class weights.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        config,
        model_name,
        train_labels_path="data/processed/y_train.npy",
    ):
        self.config = config
        self.model_name = model_name
        self.device = get_device(config["device"])
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = config["epochs"]
        self.patience = config.get("patience", 5)

        # Handle class imbalance: weight each class by its inverse frequency.
        y_train = np.load(train_labels_path)
        num_classes = len(config.get("classes") or [])
        weights = self._compute_class_weights(y_train, num_classes)
        self.criterion = nn.CrossEntropyLoss(
            weight=torch.as_tensor(weights, dtype=torch.float32, device=self.device)
        )

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config["lr"])
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs
        )
        self.history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
        self.checkpoint_path = f"models/{model_name.lower()}_best.pth"

    @staticmethod
    def _compute_class_weights(labels, num_classes=0):
        """Return inverse-frequency class weights as a NumPy float array.

        Args:
            labels: 1-D array of non-negative integer class ids.
            num_classes: Minimum length of the result. Without this, a class
                absent from ``labels`` at the high end would be dropped, and the
                weight vector would be shorter than the model's output.

        Returns:
            ``1 / count`` per class. Absent classes are treated as having a count
            of 1, because a zero count would give an ``inf`` weight and NaN losses.
        """
        counts = np.bincount(labels, minlength=num_classes)
        return 1.0 / np.maximum(counts, 1)

    def train_epoch(self):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(mean of per-batch losses, accuracy over all training samples)``.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels in pbar:
            images, labels = images.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{correct/total:.4f}"})
        return running_loss / len(self.train_loader), correct / total

    def validate(self):
        """Evaluate on ``val_loader`` without gradient tracking.

        Returns:
            ``(mean of per-batch losses, accuracy over all validation samples)``.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return running_loss / len(self.val_loader), correct / total

    def plot_history(self):
        """Save loss/accuracy curves to ``models/plots/`` and log the PNG to MLflow."""
        save_path = f"models/plots/{self.model_name.lower()}_history.png"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        epochs_range = range(1, len(self.history["train_loss"]) + 1)

        fig = plt.figure(figsize=(12, 5))
        try:
            # Plot Loss
            plt.subplot(1, 2, 1)
            plt.plot(epochs_range, self.history["train_loss"], label="Train Loss")
            plt.plot(epochs_range, self.history["val_loss"], label="Val Loss")
            plt.title(f"{self.model_name} - Loss")
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.legend()
            # Plot Accuracy
            plt.subplot(1, 2, 2)
            plt.plot(epochs_range, self.history["train_acc"], label="Train Acc")
            plt.plot(epochs_range, self.history["val_acc"], label="Val Acc")
            plt.title(f"{self.model_name} - Accuracy")
            plt.xlabel("Epochs")
            plt.ylabel("Accuracy")
            plt.legend()
            plt.tight_layout()
            fig.savefig(save_path)
        finally:
            # Release the figure; otherwise one open figure per trained model
            # accumulates in pyplot's global figure registry.
            plt.close(fig)
        print(f"--> Training history plot saved to {save_path}")
        mlflow.log_artifact(save_path)

    def fit(self):
        """Train for up to ``self.epochs`` epochs inside an MLflow run.

        Returns:
            ``self.history``: per-epoch lists of train/val loss and accuracy.
        """
        mlflow.set_experiment("Trash Classifier")
        with mlflow.start_run(run_name=self.model_name):
            mlflow.log_params(self.config)
            mlflow.log_param("model_architecture", self.model_name)
            print(f"\nStarting training for {self.model_name} on {self.device}...")
            # -inf guarantees the first epoch writes a checkpoint, even at 0% accuracy.
            best_val_acc = float("-inf")
            epochs_no_improve = 0
            for epoch in range(self.epochs):
                # Capture the LR used for this epoch before the scheduler advances it.
                epoch_lr = self.optimizer.param_groups[0]["lr"]
                train_loss, train_acc = self.train_epoch()
                val_loss, val_acc = self.validate()
                self.scheduler.step()

                self.history["train_loss"].append(train_loss)
                self.history["train_acc"].append(train_acc)
                self.history["val_loss"].append(val_loss)
                self.history["val_acc"].append(val_acc)
                mlflow.log_metric("train_loss", train_loss, step=epoch)
                mlflow.log_metric("train_acc", train_acc, step=epoch)
                mlflow.log_metric("val_loss", val_loss, step=epoch)
                mlflow.log_metric("val_acc", val_acc, step=epoch)
                mlflow.log_metric("lr", epoch_lr, step=epoch)
                print(
                    f"Epoch [{epoch + 1}/{self.epochs}] "
                    f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
                    f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
                    f"LR: {epoch_lr:.6f}"
                )

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    epochs_no_improve = 0
                    os.makedirs(os.path.dirname(self.checkpoint_path), exist_ok=True)
                    torch.save(self.model.state_dict(), self.checkpoint_path)
                    print(f"--> Saved best model for {self.model_name} with Val Acc: {val_acc:.4f}")
                    mlflow.log_artifact(self.checkpoint_path)
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= self.patience:
                        print(f"Early stopping triggered after {epoch + 1} epochs.")
                        break
            self.plot_history()
        return self.history
if __name__ == "__main__":
    config = load_config()

    image_size = 224
    # ImageNet channel statistics, matching the pretrained ResNet18 backbone.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            normalize,
        ]
    )
    # RandomResizedCrop always emits image_size x image_size, so validation must
    # produce the same spatial size. Resize is effectively a no-op when the
    # stored images already have that size.
    val_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ]
    )

    processed_dir = Path("data/processed")
    required_files = ["X_train.npy", "y_train.npy", "X_val.npy", "y_val.npy"]
    if not all((processed_dir / fname).exists() for fname in required_files):
        print("Error: Processed data not found. Please run src/dataset.py first.", file=sys.stderr)
        # Non-zero exit so shells/CI notice the failure.
        sys.exit(1)

    train_ds = TrashDataset(
        processed_dir / "X_train.npy", processed_dir / "y_train.npy", transform=train_transform
    )
    val_ds = TrashDataset(
        processed_dir / "X_val.npy", processed_dir / "y_val.npy", transform=val_transform
    )
    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)

    num_classes = len(config["classes"])
    # Build each model only when it is about to be trained. Finished models
    # (already moved to the training device) are then not all held at once.
    model_factories = {
        "SimpleCNN": lambda: SimpleCNN(num_classes=num_classes),
        "DeepCNN": lambda: DeepCNN(num_classes=num_classes),
        "ResNet18": lambda: ResNet18Transfer(num_classes=num_classes, pretrained=True),
    }
    for name, build_model in model_factories.items():
        trainer = Trainer(build_model(), train_loader, val_loader, config, name)
        trainer.fit()
        # Drop the last reference to the trained model before building the next one.
        del trainer

    print("\nAll models trained. Starting comparison...")
    from src.comparison import run_comparison

    run_comparison()