|
|
"""
|
|
|
Fine-tune EfficientNet-B0 on a recycling image dataset.
|
|
|
|
|
|
Usage:
|
|
|
pip install torch torchvision
|
|
|
python train_efficientnet.py --data-dir path/to/data --epochs 10 --batch-size 32 --lr 1e-3
|
|
|
"""
|
|
|
import os
|
|
|
import argparse
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
from torch.utils.data import DataLoader
|
|
|
from torchvision import datasets, transforms, models
|
|
|
|
|
|
|
|
|
def build_efficientnet(num_classes, device):
    """Build an EfficientNet-B0 with pretrained weights and a new classifier head.

    All parameters under ``model.features`` are frozen, so only the freshly
    created head (dropout + linear) remains trainable.

    Args:
        num_classes: Output size of the new final ``nn.Linear`` layer.
        device: Device the finished model is moved to.

    Returns:
        The modified model after ``model.to(device)``.

    Raises:
        ValueError: If the stock classifier is an ``nn.Sequential`` that
            contains no ``nn.Linear`` layer to read ``in_features`` from.
    """
    # ``pretrained=True`` is deprecated since torchvision 0.13 in favour of
    # the ``weights=`` API. IMAGENET1K_V1 is what ``pretrained=True`` maps to,
    # so the loaded weights are unchanged. Older torchvision releases that
    # lack the weights enum keep using the legacy flag.
    weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
    if weights_enum is not None:
        model = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.efficientnet_b0(pretrained=True)

    # Freeze the feature extractor; only the new head will receive gradients.
    for param in model.features.parameters():
        param.requires_grad = False

    # Discover the input width of the existing head so the replacement matches.
    classifier = model.classifier
    if isinstance(classifier, nn.Sequential):
        first_linear = next(
            (layer for layer in classifier if isinstance(layer, nn.Linear)),
            None,
        )
        if first_linear is None:
            raise ValueError("No nn.Linear layer found in model.classifier")
        in_features = first_linear.in_features
    else:
        in_features = classifier.in_features

    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_features, num_classes),
    )
    return model.to(device)
|
|
|
|
|
|
|
|
|
def train(model, dataloaders, dataset_sizes, criterion, optimizer, device,
          num_epochs, best_path='best_efficientnet_model.pth'):
    """Run the train/validation loop and checkpoint the best validation model.

    Each epoch runs a ``'train'`` phase (with gradient updates) followed by a
    ``'val'`` phase (no gradients). Whenever the validation accuracy strictly
    exceeds the best value seen so far, ``model.state_dict()`` is written to
    ``best_path``.

    Args:
        model: Module called as ``model(inputs)``; switched between
            ``train()`` and ``eval()`` mode per phase.
        dataloaders: Mapping with ``'train'`` and ``'val'`` keys; each value
            is iterated to yield ``(inputs, labels)`` batches.
        dataset_sizes: Mapping with ``'train'`` and ``'val'`` keys giving the
            number of samples used to average loss and accuracy.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        device: Device that every batch is moved to.
        num_epochs: Number of epochs to run.
        best_path: File path for the best checkpoint. Defaults to the
            previously hard-coded ``'best_efficientnet_model.pth'``.

    Returns:
        None. Progress and the final summary are printed to stdout.
    """
    best_acc = 0.0
    # Tracks whether torch.save actually ran, so the summary does not claim
    # a checkpoint exists when none was written (e.g. zero epochs, or the
    # validation accuracy never rose above 0).
    checkpoint_saved = False

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('-' * 20)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are only produced in the training phase, so
                # there is nothing to clear during validation.
                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # criterion returns a batch mean; weight by batch size so the
                # division by dataset size below yields a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += int((preds == labels).sum().item())

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), best_path)
                checkpoint_saved = True

    print(f"Best validation accuracy: {best_acc:.4f}")
    if checkpoint_saved:
        print(f"Model saved to {best_path}")
    else:
        print("No checkpoint was saved: validation accuracy never exceeded 0.")
|
|
|
|
|
|
|
|
|
def main(argv=None):
    """Parse command-line options, build the data pipeline and run training.

    Expects ``<data-dir>/train`` and ``<data-dir>/val`` folders that
    ``datasets.ImageFolder`` can read.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the script behaviour.

    Raises:
        ValueError: If a class in the validation folder maps to a different
            label index than in the training folder.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default="data",
                        help='Directory with train/ and val/ subfolders (default: ./data)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training and validation')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate for optimizer')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='Number of DataLoader worker processes')
    args = parser.parse_args(argv)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    phases = ['train', 'val']

    # Same per-channel mean/std for both phases; these are the values
    # torchvision documents for its ImageNet-pretrained models.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    transforms_dict = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    data_dirs = {x: os.path.join(args.data_dir, x) for x in phases}
    image_datasets = {x: datasets.ImageFolder(data_dirs[x], transform=transforms_dict[x])
                      for x in phases}

    # ImageFolder assigns label indices independently per folder. If a class
    # present in val/ gets a different index than in train/, validation
    # accuracy would be computed against the wrong labels without any error.
    train_idx = image_datasets['train'].class_to_idx
    val_idx = image_datasets['val'].class_to_idx
    mismatched = sorted(name for name, idx in val_idx.items()
                        if train_idx.get(name) != idx)
    if mismatched:
        raise ValueError(
            "Class/label mismatch between train/ and val/ folders for: "
            + ", ".join(mismatched)
        )

    # Only the training split is shuffled.
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=args.batch_size,
                                 shuffle=(x == 'train'),
                                 num_workers=args.num_workers)
                   for x in phases}
    dataset_sizes = {x: len(image_datasets[x]) for x in phases}

    print(f"Classes: {image_datasets['train'].classes}")

    model = build_efficientnet(len(image_datasets['train'].classes), device)
    criterion = nn.CrossEntropyLoss()
    # Hand the optimizer only the parameters that still require gradients.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

    train(model, dataloaders, dataset_sizes, criterion, optimizer, device, args.epochs)
|
|
|
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()