# trash-classifier/src/train.py
# Repository page header (alshami-dev): "First Update to the App", commit 0b86da8 (verified), 9.15 kB
import os
import sys
from pathlib import Path
# Add project root to sys.path
sys.path.append(str(Path(__file__).parent.parent))
import matplotlib # noqa: E402
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import mlflow # noqa: E402
import numpy as np # noqa: E402
import torch # noqa: E402
import torch.nn as nn # noqa: E402
import yaml # noqa: E402
from torch.utils.data import DataLoader # noqa: E402
from torchvision import transforms # noqa: E402
from tqdm import tqdm # noqa: E402
from src.dataset import TrashDataset # noqa: E402
from src.model import DeepCNN, ResNet18Transfer, SimpleCNN # noqa: E402
def load_config(config_path="config.yaml"):
    """Load the YAML configuration file and return its parsed contents.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``config.yaml``, resolved against the current working directory.

    Returns:
        The object produced by ``yaml.safe_load`` for the document.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the parsed values do not depend on the
    # platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def get_device(config_device):
    """Resolve the device setting from the configuration.

    Args:
        config_device: Device value taken from the config. The special value
            ``"auto"`` asks for automatic selection; any other value is
            returned unchanged.

    Returns:
        ``"cuda"`` or ``"cpu"`` when ``config_device`` is ``"auto"``
        (depending on ``torch.cuda.is_available()``), otherwise
        ``config_device`` itself.
    """
    # Explicit settings are passed straight through.
    if config_device != "auto":
        return config_device

    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
class Trainer:
    """
    Handles the training and validation lifecycle of a model with MLflow tracking.

    The trainer owns the loss (class-weighted cross entropy), an Adam
    optimizer, a cosine-annealing LR schedule, per-epoch metric history,
    best-checkpoint saving and early stopping on validation accuracy.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        config,
        model_name,
        train_labels_path="data/processed/y_train.npy",
    ):
        """Set up the model, loss, optimizer and scheduler.

        Args:
            model: The ``torch.nn.Module`` to train; it is moved to the
                resolved device.
            train_loader: Iterable yielding ``(images, labels)`` training batches.
            val_loader: Iterable yielding ``(images, labels)`` validation batches.
            config: Mapping that must provide ``"device"``, ``"epochs"`` and
                ``"lr"``; ``"patience"`` (default 5) and ``"classes"`` are
                optional.
            model_name: Name used for the MLflow run, the checkpoint file and
                the history plot.
            train_labels_path: ``.npy`` file holding the integer training
                labels used to derive the class weights.
        """
        self.config = config
        self.model_name = model_name
        self.device = get_device(config["device"])
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.epochs = config["epochs"]
        self.patience = config.get("patience", 5)

        # Handle class imbalance with weights
        y_train = np.load(train_labels_path)
        # When the config lists the classes, size the weight vector to match
        # even if the highest-numbered classes have no training samples.
        num_classes = len(config.get("classes") or [])
        weights = self._inverse_frequency_weights(y_train, num_classes)
        weights = torch.FloatTensor(weights).to(self.device)
        self.criterion = nn.CrossEntropyLoss(weight=weights)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config["lr"])
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs
        )

        self.history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
        self.checkpoint_path = f"models/{model_name.lower()}_best.pth"

    @staticmethod
    def _inverse_frequency_weights(labels, num_classes=0):
        """Return one weight per class, equal to ``1 / count`` of that class.

        Args:
            labels: 1-D array of non-negative integer class labels.
            num_classes: Minimum length of the returned vector; classes up to
                this index that never occur are still given an entry.

        Returns:
            A float NumPy array. Classes with no samples get weight ``0.0``
            rather than ``inf`` (which a plain ``1 / count`` would produce).
        """
        counts = np.bincount(labels, minlength=num_classes)
        weights = np.zeros(len(counts), dtype=np.float64)
        present = counts > 0
        weights[present] = 1.0 / counts[present]
        return weights

    def train_epoch(self):
        """Run one optimization pass over ``train_loader``.

        Returns:
            Tuple ``(mean_batch_loss, accuracy)`` for the epoch.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels in pbar:
            images, labels = images.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(images)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            # Predicted class is the index of the largest output per sample.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{correct/total:.4f}"})

        return running_loss / len(self.train_loader), correct / total

    def validate(self):
        """Evaluate the model on ``val_loader`` without computing gradients.

        Returns:
            Tuple ``(mean_batch_loss, accuracy)`` over the validation data.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in self.val_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        return running_loss / len(self.val_loader), correct / total

    def plot_history(self):
        """Save the loss/accuracy curves as a PNG and log it to MLflow."""
        save_path = f"models/plots/{self.model_name.lower()}_history.png"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        epochs_range = range(1, len(self.history["train_loss"]) + 1)

        fig = plt.figure(figsize=(12, 5))
        try:
            # Plot Loss
            plt.subplot(1, 2, 1)
            plt.plot(epochs_range, self.history["train_loss"], label="Train Loss")
            plt.plot(epochs_range, self.history["val_loss"], label="Val Loss")
            plt.title(f"{self.model_name} - Loss")
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.legend()

            # Plot Accuracy
            plt.subplot(1, 2, 2)
            plt.plot(epochs_range, self.history["train_acc"], label="Train Acc")
            plt.plot(epochs_range, self.history["val_acc"], label="Val Acc")
            plt.title(f"{self.model_name} - Accuracy")
            plt.xlabel("Epochs")
            plt.ylabel("Accuracy")
            plt.legend()

            plt.tight_layout()
            plt.savefig(save_path)
        finally:
            # pyplot keeps every figure alive until it is closed; release it so
            # training several models in one process does not accumulate figures.
            plt.close(fig)

        print(f"--> Training history plot saved to {save_path}")
        mlflow.log_artifact(save_path)

    def fit(self):
        """Train for up to ``epochs`` epochs inside an MLflow run.

        After every epoch the metrics are appended to ``self.history`` and
        logged to MLflow. The model weights are saved to
        ``self.checkpoint_path`` whenever validation accuracy improves, and
        training stops early once it has failed to improve for ``patience``
        consecutive epochs.

        Returns:
            The history dict with keys ``train_loss``, ``train_acc``,
            ``val_loss`` and ``val_acc``.
        """
        mlflow.set_experiment("Trash Classifier")

        with mlflow.start_run(run_name=self.model_name):
            mlflow.log_params(self.config)
            mlflow.log_param("model_architecture", self.model_name)

            print(f"\nStarting training for {self.model_name} on {self.device}...")
            best_val_acc = 0.0
            epochs_no_improve = 0

            for epoch in range(self.epochs):
                # Read the LR before the scheduler advances, so the value
                # reported for this epoch is the one it actually trained with.
                current_lr = self.optimizer.param_groups[0]["lr"]

                train_loss, train_acc = self.train_epoch()
                val_loss, val_acc = self.validate()
                self.scheduler.step()

                self.history["train_loss"].append(train_loss)
                self.history["train_acc"].append(train_acc)
                self.history["val_loss"].append(val_loss)
                self.history["val_acc"].append(val_acc)

                mlflow.log_metric("train_loss", train_loss, step=epoch)
                mlflow.log_metric("train_acc", train_acc, step=epoch)
                mlflow.log_metric("val_loss", val_loss, step=epoch)
                mlflow.log_metric("val_acc", val_acc, step=epoch)
                mlflow.log_metric("lr", current_lr, step=epoch)

                print(
                    f"Epoch [{epoch + 1}/{self.epochs}] "
                    f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
                    f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
                    f"LR: {current_lr:.6f}"
                )

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    epochs_no_improve = 0
                    os.makedirs("models", exist_ok=True)
                    torch.save(self.model.state_dict(), self.checkpoint_path)
                    print(f"--> Saved best model for {self.model_name} with Val Acc: {val_acc:.4f}")
                    mlflow.log_artifact(self.checkpoint_path)
                else:
                    epochs_no_improve += 1

                if epochs_no_improve >= self.patience:
                    print(f"Early stopping triggered after {epoch + 1} epochs.")
                    break

            # Still inside the run so the plot is attached to it as an artifact.
            self.plot_history()

        return self.history
def main():
    """Train every model variant on the processed data, then run the comparison.

    Reads ``config.yaml`` via ``load_config``, builds the train/validation
    transforms, datasets and loaders, trains ``SimpleCNN``, ``DeepCNN`` and
    ``ResNet18`` one after another with ``Trainer``, and finally calls
    ``src.comparison.run_comparison``.

    Exits with a non-zero status (message on stderr) when the processed
    training data is missing.
    """
    config = load_config()

    processed_dir = Path("data/processed")
    if not (processed_dir / "X_train.npy").exists():
        # sys.exit with a string prints it to stderr and exits with status 1,
        # so callers and pipelines can detect the failure.
        sys.exit("Error: Processed data not found. Please run src/dataset.py first.")

    # Both pipelines end with the same tensor conversion and normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Validation applies no augmentation, only conversion and normalization.
    val_transform = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_ds = TrashDataset(
        processed_dir / "X_train.npy", processed_dir / "y_train.npy", transform=train_transform
    )
    val_ds = TrashDataset(
        processed_dir / "X_val.npy", processed_dir / "y_val.npy", transform=val_transform
    )

    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)

    num_classes = len(config["classes"])

    # Factories instead of instances: each model is built only when its turn
    # comes, so finished models are not kept alive (and on the device) while
    # the next one trains.
    model_factories = {
        "SimpleCNN": lambda: SimpleCNN(num_classes=num_classes),
        "DeepCNN": lambda: DeepCNN(num_classes=num_classes),
        "ResNet18": lambda: ResNet18Transfer(num_classes=num_classes, pretrained=True),
    }

    for name, build_model in model_factories.items():
        trainer = Trainer(build_model(), train_loader, val_loader, config, name)
        trainer.fit()
        # Drop the last reference to the trained model before building the next.
        del trainer

    print("\nAll models trained. Starting comparison...")
    # Imported at this point, as in the original script, so the comparison
    # module is only loaded after training has finished.
    from src.comparison import run_comparison

    run_comparison()


if __name__ == "__main__":
    main()