# Author: eksemyashkina
# Commit f096e52: "Added files"
from pathlib import Path
from tqdm import tqdm
import numpy as np
import argparse
import json
import wandb
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import models
from models.resnet50 import ResNet
from models.mobilenet_v2 import MobileNetV2
from dataset import PlantsDataset
from utils import train_transform, test_transform, EMA
def _str2bool(value):
    """Parse a command-line boolean option value.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"`` and ``"0"``) to ``True``, so boolean options need an explicit
    parser.

    Args:
        value: Raw string from the command line (or an already-parsed bool).

    Returns:
        ``True`` for "true"/"1"/"yes"/"y"/"on" and ``False`` for
        "false"/"0"/"no"/"n"/"off" (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "1", "yes", "y", "on"}:
        return True
    if normalized in {"false", "0", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for training the plant classifier.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        Namespace with data, loader, optimisation, model and logging settings.
    """
    parser = argparse.ArgumentParser(description="Train a model on plant dataset")
    parser.add_argument("--train-root", type=str, default="data/plants/train", help="Path to the training data")
    parser.add_argument("--test-root", type=str, default="data/plants/test", help="Path to the testing data")
    # Boolean options go through _str2bool so that e.g. "--load-to-ram False" really disables them.
    parser.add_argument("--load-to-ram", type=_str2bool, default=False, help="Load dataset to RAM (true/false)")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    parser.add_argument("--pin-memory", type=_str2bool, default=True, help="Pin memory for DataLoader (true/false)")
    parser.add_argument("--num-workers", type=int, default=1, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument(
        "--weights-path",
        type=str,
        default="weights/mobilenet_v2-b0353104.pth",
        choices=["weights/resnet50-0676ba61.pth", "weights/mobilenet_v2-b0353104.pth"],
        help="Path to the pre-trained weights",
    )
    parser.add_argument("--project-name", type=str, default="plants_classifier", help="WandB project name")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--criterion", type=str, default="CrossEntropyLoss", help="Loss function type")
    parser.add_argument("--labels-path", type=str, default="labels.json", help="Path to the labels json file")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run the training on",
    )
    parser.add_argument("--model", type=str, default="mobilenet", choices=["resnet", "mobilenet"], help="Model class name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    # Any directory is accepted; "resnet-logs" / "mobilenet-logs" are just the conventional names.
    parser.add_argument(
        "--logs-dir",
        type=str,
        default="resnet-logs",
        help="Directory for wandb logs and model checkpoints (e.g. resnet-logs, mobilenet-logs)",
    )
    return parser.parse_args(argv)
def _build_model(model_name: str, weights_path: str, num_classes: int) -> nn.Module:
    """Load a pre-trained backbone, replace its head and freeze all but the last stages.

    Args:
        model_name: ``"resnet"`` or ``"mobilenet"``.
        weights_path: Path to the pre-trained weights passed to the model constructor.
        num_classes: Number of output classes for the new classification head.

    Returns:
        The model with a freshly initialised head. Only parameters whose names
        contain one of the model's trainable keys have ``requires_grad=True``.

    Raises:
        ValueError: If ``model_name`` is not recognised.
    """
    if model_name == "resnet":
        model = ResNet(weights_path=weights_path)
        model.fc = nn.Linear(512 * model.expansion, num_classes)
        nn.init.xavier_uniform_(model.fc.weight)
        trainable_keys = ("layer4", "fc")
    elif model_name == "mobilenet":
        model = MobileNetV2(weights_path=weights_path)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
        nn.init.xavier_uniform_(model.classifier[1].weight)
        trainable_keys = ("classifier", "features.18", "features.17")
    else:
        raise ValueError(f"Unknown model: {model_name!r}")

    for name, param in model.named_parameters():
        # Test each key against `name` separately: `"a" or "b" in name` is always truthy.
        param.requires_grad = any(key in name for key in trainable_keys)
    return model


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs, max_norm, emas) -> None:
    """Run one optimisation pass over ``loader``, logging per-step metrics to wandb.

    ``emas`` is a ``(loss, accuracy, grad_norm)`` triple of EMA callables used
    only for the progress-bar display.
    """
    loss_ema, accuracy_ema, grad_norm_ema = emas
    model.train()
    pbar = tqdm(loader, desc=f"Train epoch {epoch}/{num_epochs}")
    for images, targets in pbar:
        images = images.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        # clip_grad_norm_ returns the total norm measured before clipping.
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm).item()
        optimizer.step()
        train_loss = loss.item()
        train_accuracy = (logits.argmax(dim=1) == targets).sum().item() / logits.shape[0]
        pbar.set_postfix(
            {
                "loss": loss_ema(train_loss),
                "accuracy": accuracy_ema(train_accuracy),
                "grad_norm": grad_norm_ema(grad_norm),
            }
        )
        wandb.log(
            {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/accuracy": train_accuracy,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
                "train/grad_norm": grad_norm,
            }
        )


def _evaluate(model, loader, criterion, device, desc: str):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the mean of per-batch losses
        and ``accuracy`` is correct predictions over dataset size. Both are
        0.0 for an empty loader/dataset instead of raising ZeroDivisionError.
    """
    model.eval()
    total_loss, correct = 0.0, 0
    with torch.no_grad():
        for images, targets in tqdm(loader, desc=desc):
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images)
            total_loss += criterion(logits, targets).item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
    num_batches = len(loader)
    num_samples = len(loader.dataset)
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / num_samples if num_samples else 0.0
    return mean_loss, accuracy


def main() -> None:
    """Fine-tune a plant classifier, logging to wandb and checkpointing into ``--logs-dir``.

    Saves ``checkpoint-best-<epoch>.pth`` whenever validation accuracy improves.
    Otherwise saves ``checkpoint-<epoch>.pth`` every ``--save-frequency`` epochs
    (periodic saving is disabled when the frequency is <= 0).
    """
    args = parse_args()
    with open(args.labels_path, "r", encoding="utf-8") as fp:
        labels = json.load(fp)
    num_classes = len(labels)

    logs_dir = Path(args.logs_dir)
    logs_dir.mkdir(parents=True, exist_ok=True)
    wandb.init(project=args.project_name, dir=logs_dir, config=vars(args))

    train_dataset = PlantsDataset(root=args.train_root, load_to_ram=args.load_to_ram, transform=train_transform, labels=labels)
    test_dataset = PlantsDataset(root=args.test_root, load_to_ram=args.load_to_ram, transform=test_transform, labels=labels)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=args.pin_memory, num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=args.pin_memory, num_workers=args.num_workers)

    device = torch.device(args.device)
    model = _build_model(args.model, args.weights_path, num_classes).to(device)

    optimizer_class = getattr(torch.optim, args.optimizer)
    # Only unfrozen parameters are optimised; frozen ones never receive gradients.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
    criterion_class = getattr(nn, args.criterion)
    criterion = criterion_class()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs)

    best_accuracy = 0.0
    emas = (EMA(), EMA(), EMA())
    for epoch in range(1, args.num_epochs + 1):
        _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, args.num_epochs, args.max_norm, emas)

        test_loss, test_accuracy = _evaluate(model, test_loader, criterion, device, desc=f"Val epoch {epoch}/{args.num_epochs}")
        print(f"loss: {test_loss:.3f}, accuracy: {test_accuracy:.3f}")
        wandb.log(
            {
                "val/epoch": epoch,
                "val/test_loss": test_loss,
                "val/test_accuracy": test_accuracy,
            }
        )
        scheduler.step()

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(model.state_dict(), logs_dir / f"checkpoint-best-{epoch:09}.pth")
        elif args.save_frequency > 0 and epoch % args.save_frequency == 0:
            torch.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
    wandb.finish()


if __name__ == "__main__":
    main()