"""Train PlantCNN, log metrics to ClearML, and save the best weights to disk."""

import copy
from pathlib import Path

import torch
import torch.nn as nn
from clearml import Task
from tqdm.auto import tqdm

from data_prep import train_loader, val_loader, device
from models.model import PlantCNN
from utils.config import load_config


def train_step(model, loader, optimizer, loss_fn, device):
    """Run one training epoch over ``loader`` and update ``model`` in place.

    Each batch is indexed with the keys ``"pixel_values"`` and ``"labels"``;
    both tensors are moved to ``device`` before the forward pass.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable yielding batches with the two keys above.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar loss.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted average loss
        and the fraction of correctly predicted labels.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch in tqdm(loader, desc="Train", leave=False):
        images = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a smaller final batch does not skew the
        # epoch average (assumes loss_fn returns a per-batch mean, as
        # nn.CrossEntropyLoss does by default).
        running_loss += loss.item() * batch_size
        preds = output.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


def test_step(model, loader, loss_fn, device):
    """Evaluate ``model`` over ``loader`` without computing gradients.

    Each batch is indexed with the keys ``"pixel_values"`` and ``"labels"``;
    both tensors are moved to ``device`` before the forward pass.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding batches with the two keys above.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar loss.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted average loss
        and the fraction of correctly predicted labels.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            output = model(images)
            loss = loss_fn(output, labels)

            batch_size = labels.size(0)
            # Same sample-weighted accumulation as in train_step.
            running_loss += loss.item() * batch_size
            preds = output.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


def main():
    """Train PlantCNN with early stopping and save the best checkpoint.

    Reads hyperparameters from ``load_config()``, trains on ``train_loader``
    and validates on ``val_loader`` each epoch, reports loss and accuracy to
    a ClearML task, and stops once validation accuracy has failed to improve
    for ``early_stopping_patience`` consecutive epochs. The weights from the
    epoch with the highest validation accuracy are written to
    ``saved_models/plant_cnn.pt`` next to this file and registered as the
    task's output model.
    """
    config = load_config()

    num_classes = config["num_classes"]
    channels = config["channels"]
    dropout = config["dropout"]
    lr = config["lr"]
    weight_decay = config["weight_decay"]
    num_epochs = config["num_epochs"]
    patience = config["early_stopping_patience"]

    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"

    model = PlantCNN(
        num_classes=num_classes, channels=channels, dropout=dropout
    ).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    task = Task.init(
        project_name=project_name,
        task_name=f"{model_name}_training",
    )
    task.connect(config)
    task.add_tags([model_name, "train"])
    logger = task.get_logger()

    best_val_acc = 0.0
    best_state_dict = None
    patience_cnt = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_step(
            model, train_loader, optimizer, loss_fn, device
        )
        val_loss, val_acc = test_step(model, val_loader, loss_fn, device)

        print(f"Train loss: {train_loss:.3f} | Train accuracy: {train_acc:.3f}")
        print(f"Validation loss: {val_loss:.3f} | Validation accuracy: {val_acc:.3f}")

        logger.report_scalar("loss", "train", train_loss, epoch)
        logger.report_scalar("loss", "val", val_loss, epoch)
        logger.report_scalar("accuracy", "train", train_acc, epoch)
        logger.report_scalar("accuracy", "val", val_acc, epoch)

        # Strict improvement only: an epoch that merely ties the best
        # accuracy counts against patience.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps overwrite in place. Deep-copy so
            # the snapshot really holds this epoch's weights.
            best_state_dict = copy.deepcopy(model.state_dict())
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs.")
                break

    # best_state_dict stays None when no epoch beat the initial 0.0 accuracy
    # (or num_epochs is 0); the current weights are saved in that case.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    project_rt = Path(__file__).resolve().parent
    model_dir = project_rt / "saved_models"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "plant_cnn.pt"

    torch.save(model.state_dict(), model_path)
    print(f"Saved best model to {model_path}")

    task.update_output_model(model_path=str(model_path), name="plant_cnn_best")


if __name__ == "__main__":
    main()