|
|
import torch
|
|
|
import argparse
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
import os
|
|
|
import wandb
|
|
|
from tqdm.auto import tqdm
|
|
|
from model import build_model
|
|
|
from datasets import get_datasets, get_data_loaders
|
|
|
from utils import save_model, save_plots, SaveBestModel
|
|
|
|
|
|
# Reproducibility setup: fix the RNG seeds and force deterministic cuDNN
# kernel selection.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # no-op on CPU-only builds
torch.backends.cudnn.deterministic = True
# benchmark must be False when determinism is requested: with benchmark=True
# cuDNN auto-tunes a (possibly non-deterministic) algorithm per input shape,
# which defeats the seeding above.
torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
# Command-line interface for the training script. `args` is a plain dict
# keyed by each argument's dest: epochs, learning_rate, batch_size,
# fine_tune, save_name.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-e', '--epochs',
    type=int,
    default=15,
    help='Number of epochs to train our network for'
)
parser.add_argument(
    '-lr', '--learning-rate',
    type=float,
    dest='learning_rate',
    default=0.0001,
    help='Learning rate for training the model'
)
parser.add_argument(
    '-b', '--batch-size',
    dest='batch_size',
    default=16,
    type=int,
    help='Batch size for the training and validation data loaders'
)
parser.add_argument(
    '-ft', '--fine-tune',
    dest='fine_tune',
    action='store_true',
    help='pass this to fine tune all layers'
)
parser.add_argument(
    '--save-name',
    dest='save_name',
    default='model',
    help='file name of the final model to save'
)
# Parses sys.argv at import time; this module is meant to be run as a script.
args = vars(parser.parse_args())
|
|
|
|
|
|
|
|
|
def train(model, trainloader, optimizer, criterion):
    """
    Run a single training epoch over `trainloader`.

    :param model: the network to optimize (already moved to `device`).
    :param trainloader: DataLoader yielding (image_batch, label_batch) pairs.
    :param optimizer: optimizer updating `model`'s parameters.
    :param criterion: classification loss, e.g. ``nn.CrossEntropyLoss``.

    Returns a tuple ``(mean_batch_loss, accuracy_percent)``.

    NOTE(review): reads the module-level ``device`` global that is assigned
    in the ``__main__`` block.
    """
    model.train()
    print('Training')
    running_loss = 0.0
    running_correct = 0
    batch_count = 0
    for batch_idx, (images, targets) in tqdm(
        enumerate(trainloader), total=len(trainloader)
    ):
        batch_count += 1
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()

        # Forward pass and loss.
        logits = model(images)
        loss = criterion(logits, targets)
        running_loss += loss.item()

        # Batch accuracy bookkeeping: argmax over class dimension.
        predictions = logits.data.max(1)[1]
        running_correct += (predictions == targets).sum().item()

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

    # Loss is averaged per batch; accuracy is over the whole dataset.
    mean_loss = running_loss / batch_count
    accuracy = 100. * (running_correct / len(trainloader.dataset))
    return mean_loss, accuracy
|
|
|
|
|
|
|
|
|
def validate(model, testloader, criterion, class_names):
    """
    Evaluate the model on `testloader` without gradient updates.

    :param model: the network to evaluate.
    :param testloader: DataLoader yielding (image_batch, label_batch) pairs.
    :param criterion: classification loss, e.g. ``nn.CrossEntropyLoss``.
    :param class_names: accepted for interface compatibility; unused here.

    Returns a tuple ``(mean_batch_loss, accuracy_percent)``.

    NOTE(review): reads the module-level ``device`` global that is assigned
    in the ``__main__`` block.
    """
    model.eval()
    print('Validation')
    running_loss = 0.0
    running_correct = 0
    batch_count = 0
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        for batch_idx, (images, targets) in tqdm(
            enumerate(testloader), total=len(testloader)
        ):
            batch_count += 1
            images = images.to(device)
            targets = targets.to(device)

            logits = model(images)
            batch_loss = criterion(logits, targets)
            running_loss += batch_loss.item()

            # Argmax over the class dimension gives predicted labels.
            predictions = logits.data.max(1)[1]
            running_correct += (predictions == targets).sum().item()

    # Loss is averaged per batch; accuracy is over the whole dataset.
    mean_loss = running_loss / batch_count
    accuracy = 100. * (running_correct / len(testloader.dataset))
    return mean_loss, accuracy
|
|
|
|
|
|
if __name__ == '__main__':
    # Start the Weights & Biases run; hyperparameters are mirrored into the
    # run config so they appear in the dashboard. NOTE(review): the
    # "architecture" tag and weight_decay here are hard-coded — keep them in
    # sync with build_model() and the optimizer below.
    wandb.init(
        project="Tulsi-classification",
        name="swin_transformers-cls",
        config={
            "epochs": args['epochs'],
            "batch_size": args['batch_size'],
            "learning_rate": args['learning_rate'],
            "architecture": "Swin-Tiny",
            "optimizer": "AdamW",
            "weight_decay": 0.002
        }
    )

    # Directory where checkpoints and plots are written.
    out_dir = os.path.join('outputs')
    os.makedirs(out_dir, exist_ok=True)

    # Project-local data helpers: get_datasets() returns the train set,
    # validation set, and the class list.
    dataset_train, dataset_valid, dataset_classes = get_datasets()
    print(f"[INFO]: Number of training images: {len(dataset_train)}")
    print(f"[INFO]: Number of validation images: {len(dataset_valid)}")
    print(f"[INFO]: Classes: {dataset_classes}")

    train_loader, valid_loader = get_data_loaders(
        dataset_train, dataset_valid, batch_size=args['batch_size']
    )

    lr = args['learning_rate']
    epochs = args['epochs']
    # NOTE: `device` is a module-level global that train() and validate()
    # read directly — do not rename it.
    device = ('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Computation device: {device}")
    print(f"Learning rate: {lr}")
    print(f"Epochs to train for: {epochs}\n")

    # Build the model sized to the detected number of classes; fine_tune
    # controls whether backbone layers are trainable (see build_model).
    model = build_model(
        fine_tune=args['fine_tune'],
        num_classes=len(dataset_classes)
    ).to(device)
    print(model)

    # Parameter counts (total vs. trainable) as a quick sanity check on the
    # fine_tune setting.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")

    # AdamW with a small weight decay (value duplicated in the wandb config
    # above).
    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=0.002
    )
    criterion = nn.CrossEntropyLoss()

    # Checkpoints the model whenever validation loss improves (see utils).
    save_best_model = SaveBestModel()

    # Per-epoch history used for the final plots.
    train_loss, valid_loss = [], []
    train_acc, valid_acc = [], []

    # Main training loop: one train + one validation pass per epoch.
    for epoch in range(epochs):
        print(f"[INFO]: Epoch {epoch+1} of {epochs}")
        train_epoch_loss, train_epoch_acc = train(model, train_loader,
                                                  optimizer, criterion)
        valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,
                                                     criterion, dataset_classes)
        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)
        train_acc.append(train_epoch_acc)
        valid_acc.append(valid_epoch_acc)

        print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
        print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")

        # Push this epoch's metrics to wandb (epoch logged 1-based).
        wandb.log({
            "train/loss": train_epoch_loss,
            "train/accuracy": train_epoch_acc,
            "val/loss": valid_epoch_loss,
            "val/accuracy": valid_epoch_acc,
            "epoch": epoch + 1
        })

        # Save a checkpoint if this epoch's validation loss is the best so far.
        save_best_model(
            valid_epoch_loss, epoch, model, out_dir, args['save_name']
        )
        print('-'*50)

    # Persist the final model state, write the accuracy/loss plots, and
    # close the wandb run.
    save_model(epochs, model, optimizer, criterion, out_dir, args['save_name'])
    save_plots(train_acc, valid_acc, train_loss, valid_loss, out_dir)
    wandb.finish()
    print('TRAINING COMPLETE')