# GAP_SMALL_PROJECT2 / train.py
# (Repository page header captured along with the file; kept as comments.)
# fatimaxa's picture
# Upload 46 files
# 83be575 verified
import torch
import torch.nn as nn
from data_prep import train_loader, val_loader, device
from models.model import PlantCNN
from utils.config import load_config
from clearml import Task
from pathlib import Path
from tqdm.auto import tqdm
def train_step(model, loader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``loader``.

    Each batch is read as a mapping with ``"pixel_values"`` and ``"labels"``
    entries, moved to ``device``, and used for a single forward/backward/
    optimizer step.

    Returns:
        A ``(loss, accuracy)`` tuple. The loss is the per-batch loss weighted
        by batch size and divided by the number of samples seen (presumably
        ``loss_fn`` returns a batch mean -- confirm against the caller); the
        accuracy is the fraction of samples whose highest-scoring class along
        dim 1 equals the label.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in tqdm(loader, desc="Train", leave=False):
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)
        batch_size = targets.size(0)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_sum += batch_loss.item() * batch_size
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen
def test_step(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    The model is switched to eval mode and every batch (a mapping with
    ``"pixel_values"`` and ``"labels"`` entries) is scored with gradient
    tracking disabled.

    Returns:
        A ``(loss, accuracy)`` tuple: the batch-size-weighted loss divided by
        the number of samples seen, and the fraction of samples whose
        highest-scoring class along dim 1 equals the label.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Val", leave=False):
            inputs = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)
            batch_size = targets.size(0)

            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)

            # Weight by batch size so a smaller final batch does not skew the mean.
            loss_sum += batch_loss.item() * batch_size
            predicted = logits.max(dim=1).indices
            n_correct += (predicted == targets).sum().item()
            n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen
def main():
    """Train ``PlantCNN`` with early stopping, then save and register the best weights.

    Reads ``num_classes``, ``channels``, ``dropout``, ``lr``, ``weight_decay``,
    ``num_epochs`` and ``early_stopping_patience`` from ``load_config()``.
    Each epoch runs ``train_step`` and ``test_step`` and reports loss and
    accuracy to a ClearML task. Training stops early once validation accuracy
    has failed to improve for ``early_stopping_patience`` consecutive epochs.
    The weights from the best-validation-accuracy epoch are restored, written
    to ``saved_models/plant_cnn.pt`` next to this file, and attached to the
    ClearML task as its output model.
    """
    config = load_config()
    num_classes = config["num_classes"]
    channels = config["channels"]
    dropout = config["dropout"]
    lr = config["lr"]
    weight_decay = config["weight_decay"]
    num_epochs = config["num_epochs"]
    patience = config["early_stopping_patience"]

    project_name = "GAP_plant_disease_classification"
    model_name = "PlantCNN"

    model = PlantCNN(
        num_classes=num_classes, channels=channels, dropout=dropout
    ).to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, weight_decay=weight_decay
    )

    task = Task.init(project_name=project_name, task_name=f"{model_name}_training")
    task.connect(config)
    task.add_tags([model_name, "train"])
    logger = task.get_logger()

    best_val_acc = 0.0
    best_state_dict = None
    patience_cnt = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch: {epoch+1}/{num_epochs}")
        train_loss, train_acc = train_step(
            model, train_loader, optimizer, loss_fn, device
        )
        val_loss, val_acc = test_step(model, val_loader, loss_fn, device)

        print(f"Train loss: {train_loss:.3f} | Train accuracy: {train_acc:.3f}")
        print(f"Validation loss: {val_loss:.3f} | Validation accuracy: {val_acc:.3f}")

        logger.report_scalar("loss", "train", train_loss, epoch)
        logger.report_scalar("loss", "val", val_loss, epoch)
        logger.report_scalar("accuracy", "train", train_acc, epoch)
        logger.report_scalar("accuracy", "val", val_acc, epoch)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() hands back tensors that share storage with the live
            # parameters, so later optimizer steps would silently overwrite the
            # snapshot. Clone every tensor to freeze this epoch's weights.
            best_state_dict = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_cnt = 0
        else:
            patience_cnt += 1
            if patience_cnt >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs.")
                break

    # best_state_dict stays None if validation accuracy never exceeded 0.0;
    # in that case the final-epoch weights are what gets saved below.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    project_rt = Path(__file__).resolve().parent
    model_dir = project_rt / "saved_models"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "plant_cnn.pt"

    torch.save(model.state_dict(), model_path)
    print(f"Saved best model to {model_path}")
    task.update_output_model(model_path=str(model_path), name="plant_cnn_best")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()