# RecycloAI — src/train_model.py
"""
Fine-tune EfficientNet-B0 on a recycling image dataset.
Usage:
pip install torch torchvision
python train_efficientnet.py --data-dir path/to/data --epochs 10 --batch-size 32 --lr 1e-3
"""
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
def build_efficientnet(num_classes, device):
    """Build an EfficientNet-B0 for transfer learning.

    Loads ImageNet-pretrained weights, freezes the convolutional feature
    extractor, and replaces the classifier head with a fresh
    Dropout + Linear layer sized for ``num_classes``.

    Args:
        num_classes (int): Number of output classes.
        device (torch.device): Device the model is moved to.

    Returns:
        torch.nn.Module: The model on ``device``; only the new head is trainable.

    Raises:
        ValueError: If no ``nn.Linear`` layer is found in the classifier.
    """
    # torchvision >= 0.13 deprecates (and later removes) pretrained=...;
    # prefer the weights enum, falling back for older torchvision versions.
    try:
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
        model = models.efficientnet_b0(weights=weights)
    except AttributeError:
        model = models.efficientnet_b0(pretrained=True)
    # Freeze the feature extractor so only the new head is optimized.
    for param in model.features.parameters():
        param.requires_grad = False
    # Determine the input width expected by the existing classifier head.
    if isinstance(model.classifier, nn.Sequential):
        # Find the first nn.Linear layer in the classifier.
        for layer in model.classifier:
            if isinstance(layer, nn.Linear):
                in_features = layer.in_features
                break
        else:  # for/else: loop finished without break -> no Linear found
            raise ValueError("No nn.Linear layer found in model.classifier")
    else:
        in_features = model.classifier.in_features
    # Replace the head with a fresh Dropout + Linear sized for our classes.
    model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(in_features, num_classes),
    )
    return model.to(device)
def train(model, dataloaders, dataset_sizes, criterion, optimizer, device, num_epochs):
    """Run the train/val loop and checkpoint the best-validation weights.

    Args:
        model: Network to optimize (already on ``device``).
        dataloaders: Dict with 'train' and 'val' iterables of (inputs, labels).
        dataset_sizes: Dict with sample counts for 'train' and 'val'.
        criterion: Loss function (e.g. CrossEntropyLoss).
        optimizer: Optimizer over the trainable parameters.
        device: torch.device batches are moved to.
        num_epochs (int): Number of epochs to run.

    Side effects:
        Saves the best state_dict to 'best_efficientnet_model.pth' and
        prints per-phase loss/accuracy.
    """
    best_acc = 0.0
    best_path = 'best_efficientnet_model.pth'
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('-' * 20)
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Gradients are only accumulated during training.
                if phase == 'train':
                    optimizer.zero_grad()
                # Disable autograd in eval to save memory and time.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Weight batch loss by batch size so a partial last batch
                # contributes proportionally to the epoch average.
                running_loss += loss.item() * inputs.size(0)
                # .item() converts the tensor count to a plain int and
                # avoids the deprecated .data attribute access.
                running_corrects += torch.sum(preds == labels).item()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
            # Checkpoint whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), best_path)
    print(f"Best validation accuracy: {best_acc:.4f}")
    print(f"Model saved to {best_path}")
def main():
    """CLI entry point: parse args, build the data pipeline, train the model.

    Expects ``--data-dir`` to contain ``train/`` and ``val/`` subfolders laid
    out for ``torchvision.datasets.ImageFolder`` (one subfolder per class).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default="data",
                        help='Directory with train/ and val/ subfolders (default: ./data)')
    parser.add_argument('--epochs', type=int, default=10,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for training and validation')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='Learning rate for optimizer')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='DataLoader worker processes (default: 4)')
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Data transformations. The normalization stats are the ImageNet
    # mean/std required by the pretrained EfficientNet backbone.
    transforms_dict = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),  # light augmentation, train only
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ]),
    }
    # Datasets and loaders; only the training set is shuffled.
    data_dirs = {x: os.path.join(args.data_dir, x) for x in ['train', 'val']}
    image_datasets = {x: datasets.ImageFolder(data_dirs[x], transform=transforms_dict[x])
                      for x in ['train', 'val']}
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=args.batch_size,
                                 shuffle=(x == 'train'), num_workers=args.num_workers)
                   for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    print(f"Classes: {image_datasets['train'].classes}")
    # Build and train the model; optimize only unfrozen (head) parameters.
    model = build_efficientnet(len(image_datasets['train'].classes), device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    train(model, dataloaders, dataset_sizes, criterion, optimizer, device, args.epochs)


if __name__ == '__main__':
    main()