"""
Training and Evaluation Pipeline

Trains every model listed in config.MODEL_NAMES with mixed precision (AMP),
a OneCycleLR schedule, gradient clipping, and early stopping, then saves the
best checkpoint and a JSON summary for each model.
"""

import json
import time
from pathlib import Path
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
# Note: recent PyTorch versions prefer torch.amp (e.g. torch.amp.autocast("cuda")),
# but the torch.cuda.amp aliases below still work.
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import OneCycleLR
from tqdm import tqdm

import config
from dataset import create_data_loaders
from models import get_model


class EarlyStopping:
    """Stop training when validation loss has not improved by at least
    ``min_delta`` for ``patience`` consecutive epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            # No sufficient improvement: count this epoch against patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improved by at least min_delta: record it and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
        return self.early_stop
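

# A minimal usage sketch (illustrative only; `epochs` and `validate` below are
# placeholders, not names defined in this module):
#
#     stopper = EarlyStopping(patience=3, min_delta=0.01)
#     for epoch in range(epochs):
#         val_loss = validate()
#         if stopper(val_loss):
#             break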


class Trainer:
    """Model trainer with mixed precision (AMP) and assorted optimizations:
    label smoothing, OneCycleLR scheduling, gradient clipping, and early
    stopping."""

    def __init__(
        self,
        model: nn.Module,
        model_name: str,
        train_loader,
        val_loader,
        num_classes: int,
        device: str = config.DEVICE
    ):
        self.model = model.to(device)
        self.model_name = model_name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes

        # Cross-entropy with label smoothing to reduce overconfident logits.
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # AdamW decouples weight decay from the gradient update.
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY
        )

        # OneCycleLR warms up to 10x the base LR over the first 10% of steps,
        # then anneals back down with a cosine curve. It is stepped once per
        # batch, hence steps_per_epoch.
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.LEARNING_RATE * 10,
            epochs=config.EPOCHS,
            steps_per_epoch=len(train_loader),
            pct_start=0.1,
            anneal_strategy='cos'
        )

        # Loss scaler for mixed-precision training.
        self.scaler = GradScaler()

        self.early_stopping = EarlyStopping(patience=config.EARLY_STOPPING_PATIENCE)

        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        self.best_val_acc = 0.0
        self.best_model_state = None

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch; returns (mean loss, accuracy %)."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)

        for images, labels in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Forward pass in reduced precision.
            with autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

            # Backward pass on the scaled loss to avoid fp16 underflow.
            self.scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping; otherwise the norm
            # threshold would apply to the scaled values.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()  # OneCycleLR steps once per batch

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc
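
    # For reference, the full-precision (no AMP) equivalent of the update
    # sequence above would be, as a sketch:
    #
    #     loss.backward()
    #     torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
    #     self.optimizer.step()
    #     self.scheduler.step()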

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Evaluate on the validation set; returns (mean loss, accuracy %)."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
            images = images.to(self.device)
            labels = labels.to(self.device)

            with autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def train(self, epochs: int = config.EPOCHS) -> Dict:
        """Full training loop with checkpointing of the best epoch."""
        print(f"\n{'='*60}")
        print(f"Training {self.model_name}")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Epochs: {epochs}")
        print(f"Batch size: {config.BATCH_SIZE}")
        print(f"Learning rate: {config.LEARNING_RATE}")

        start_time = time.time()

        for epoch in range(epochs):
            print(f"\nEpoch [{epoch+1}/{epochs}]")

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()
            current_lr = self.optimizer.param_groups[0]['lr']

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(current_lr)

            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
            print(f"  LR: {current_lr:.6f}")

            # Snapshot the best weights. Tensors are cloned explicitly because
            # state_dict().copy() is shallow: it would keep references to the
            # live parameters, which later epochs would overwrite.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_model_state = {
                    k: v.detach().clone() for k, v in self.model.state_dict().items()
                }
                print("  *** New best model! ***")

            if self.early_stopping(val_loss):
                print(f"\nEarly stopping triggered at epoch {epoch+1}")
                break

        training_time = time.time() - start_time

        # Restore the best weights before returning.
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        print(f"\nTraining completed in {training_time/60:.2f} minutes")
        print(f"Best validation accuracy: {self.best_val_acc:.2f}%")

        self.history['training_time'] = training_time
        self.history['best_val_acc'] = self.best_val_acc

        return self.history

    def save_model(self, path: Optional[Path] = None):
        """Save a checkpoint with weights, metadata, and training history."""
        if path is None:
            path = config.MODELS_DIR / f"{self.model_name.lower().replace(' ', '_')}.pth"

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'num_classes': self.num_classes,
            'best_val_acc': self.best_val_acc,
            'history': self.history
        }, path)

        print(f"Model saved to {path}")
        return path
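
    # Reloading the checkpoint later might look like this (a sketch; it
    # assumes get_model() accepts pretrained=False, mirroring the call in
    # train_all_models below):
    #
    #     ckpt = torch.load(path, map_location="cpu")
    #     model = get_model(ckpt['model_name'], ckpt['num_classes'], pretrained=False)
    #     model.load_state_dict(ckpt['model_state_dict'])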


def train_all_models():
    """Train every model in config.MODEL_NAMES and return the results."""
    print("\n" + "="*70)
    print("TRAINING 5 MODELS FOR INDONESIAN HERBAL PLANTS CLASSIFICATION")
    print("="*70)

    train_loader, val_loader, test_loader, class_names = create_data_loaders()
    num_classes = len(class_names)

    # Persist the class index -> name mapping for later inference.
    with open(config.OUTPUT_DIR / "class_names.json", 'w') as f:
        json.dump(class_names, f, indent=2)

    results = {}

    for model_name in config.MODEL_NAMES:
        print(f"\n{'#'*70}")
        print(f"# Model: {model_name.upper()}")
        print(f"{'#'*70}")

        model = get_model(model_name, num_classes, pretrained=True)

        params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total parameters: {params:,}")
        print(f"Trainable parameters: {trainable_params:,}")

        trainer = Trainer(
            model=model,
            model_name=model_name,
            train_loader=train_loader,
            val_loader=val_loader,
            num_classes=num_classes
        )

        history = trainer.train(epochs=config.EPOCHS)
        model_path = trainer.save_model()

        results[model_name] = {
            'history': history,
            'model_path': str(model_path),
            'params': params,
            'trainable_params': trainable_params
        }

    # Write a JSON-serializable summary; the full per-epoch histories stay
    # in the returned results dict.
    with open(config.OUTPUT_DIR / "training_results.json", 'w') as f:
        serializable_results = {}
        for name, data in results.items():
            serializable_results[name] = {
                'best_val_acc': data['history']['best_val_acc'],
                'training_time': data['history']['training_time'],
                'params': data['params'],
                'model_path': data['model_path']
            }
        json.dump(serializable_results, f, indent=2)

    return results, test_loader, class_names


if __name__ == "__main__":
    results, test_loader, class_names = train_all_models()