Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from data_prep import train_loader, val_loader, device | |
| from models.model import PlantCNN | |
| from utils.config import load_config | |
| from clearml import Task | |
| from pathlib import Path | |
| from tqdm.auto import tqdm | |
def train_step(model, loader, optimizer, loss_fn, device):
    """Run one training epoch over ``loader``; return its mean loss and accuracy.

    Each batch must be a mapping with a ``"pixel_values"`` entry (model input)
    and a ``"labels"`` entry (integer class targets compared against the argmax
    of the model output along dim 1).

    Args:
        model: Classifier; its output is reduced with ``argmax(dim=1)``, so it
            presumably returns per-class scores along dim 1.
        loader: Iterable of batches; wrapped in a ``tqdm`` progress bar.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar tensor.
            Assumed to be a per-batch mean (as ``nn.CrossEntropyLoss`` is by
            default), since it is re-weighted by batch size below.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correctly classified samples over the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples (previously surfaced as an
            opaque ``ZeroDivisionError``).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch in tqdm(loader, desc="Train", leave=False):
        images = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Undo the per-batch mean so a short final batch is weighted correctly
        # in the epoch average.
        running_loss += loss.item() * batch_size
        preds = output.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_step: loader yielded no samples")
    return running_loss / total, correct / total
def test_step(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without gradients; return mean loss and accuracy.

    Each batch must be a mapping with a ``"pixel_values"`` entry (model input)
    and a ``"labels"`` entry (integer class targets compared against the argmax
    of the model output along dim 1).

    Args:
        model: Classifier; put into ``eval()`` mode before the pass.
        loader: Iterable of batches; wrapped in a ``tqdm`` progress bar.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar tensor.
            Assumed to be a per-batch mean (as ``nn.CrossEntropyLoss`` is by
            default), since it is re-weighted by batch size below.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correctly classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples (previously surfaced as an
            opaque ``ZeroDivisionError``).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            images = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            output = model(images)
            loss = loss_fn(output, labels)

            batch_size = labels.size(0)
            # Undo the per-batch mean so a short final batch is weighted
            # correctly in the overall average.
            running_loss += loss.item() * batch_size
            preds = output.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("test_step: loader yielded no samples")
    return running_loss / total, correct / total
def main():
    """Train PlantCNN with early stopping, log to ClearML, and save the best weights.

    Hyperparameters come from ``load_config()`` after the dict has been
    connected to a ClearML task. Each epoch runs ``train_step`` and then
    ``test_step``, logs loss and accuracy scalars, and keeps a snapshot of the
    weights with the highest validation accuracy. Training stops early after
    ``early_stopping_patience`` consecutive epochs without improvement. The
    best weights are then written to ``saved_models/plant_cnn.pt`` next to this
    file and registered as the task's output model.
    """
    import copy

    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"

    config = load_config()
    # Connect the config to ClearML *before* reading any hyperparameter. When
    # the task is executed remotely (agent / HPO), ClearML overrides the
    # connected values, and anything read beforehand would ignore those overrides.
    task = Task.init(project_name=project_name, task_name=f"{model_name}_training")
    config = task.connect(config)
    task.add_tags([model_name, "train"])
    logger = task.get_logger()

    num_classes = config["num_classes"]
    channels = config["channels"]
    dropout = config["dropout"]
    lr = config["lr"]
    weight_decay = config["weight_decay"]
    num_epochs = config["num_epochs"]
    patience = config["early_stopping_patience"]

    model = PlantCNN(num_classes=num_classes, channels=channels, dropout=dropout).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # -inf guarantees the first epoch is recorded as the best so far, even if
    # its validation accuracy is exactly 0.
    best_val_acc = float("-inf")
    best_state_dict = None
    patience_cnt = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch+1}/{num_epochs}")
        train_loss, train_acc = train_step(
            model, train_loader, optimizer, loss_fn, device
        )
        val_loss, val_acc = test_step(
            model, val_loader, loss_fn, device
        )
        print(f"Train loss: {train_loss:.3f} | Train accuracy: {train_acc:.3f}")
        print(f"Validation loss: {val_loss:.3f} | Validation accuracy: {val_acc:.3f}")

        logger.report_scalar("loss", "train", train_loss, epoch)
        logger.report_scalar("loss", "val", val_loss, epoch)
        logger.report_scalar("accuracy", "train", train_acc, epoch)
        logger.report_scalar("accuracy", "val", val_acc, epoch)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() holds references to the live parameter tensors, and
            # optimizer.step() updates them in place. Without a deep copy this
            # "best" snapshot would silently become the final weights.
            best_state_dict = copy.deepcopy(model.state_dict())
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs.")
                break

    # Restore the best snapshot (if any epoch ran) before saving.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    project_rt = Path(__file__).resolve().parent
    model_dir = project_rt / "saved_models"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "plant_cnn.pt"
    torch.save(model.state_dict(), model_path)
    print(f"Saved best model to {model_path}")
    task.update_output_model(model_path=str(model_path), name="plant_cnn_best")
# Run training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()