"""
Training and Evaluation Pipeline
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
from typing import Dict, List, Tuple, Optional
import time
import json
from pathlib import Path
import config
from models import get_model
from dataset import create_data_loaders
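
# This script assumes a local `config` module that defines the constants used
# below. A minimal sketch of the expected attributes (the values here are
# illustrative assumptions, not the project's actual settings):
#
#   DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#   LEARNING_RATE = 1e-4
#   WEIGHT_DECAY = 1e-2
#   EPOCHS = 20
#   BATCH_SIZE = 32
#   EARLY_STOPPING_PATIENCE = 5
#   MODEL_NAMES = [...]          # the 5 architectures to train
#   MODELS_DIR = Path("models")
#   OUTPUT_DIR = Path("outputs")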


class EarlyStopping:
    """Stop training early when validation loss stops improving."""

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss: float) -> bool:
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            # No improvement of at least min_delta over the best loss so far.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.early_stop
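
# Minimal usage sketch for EarlyStopping (hypothetical loss values, kept as a
# comment so the module stays import-safe):
#
#   stopper = EarlyStopping(patience=2, min_delta=0.01)
#   for val_loss in [0.90, 0.85, 0.86, 0.87, 0.88]:
#       if stopper(val_loss):
#           break  # fires after two consecutive epochs without improvement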


class Trainer:
    """Model trainer with mixed precision and various optimizations"""

    def __init__(
        self,
        model: nn.Module,
        model_name: str,
        train_loader,
        val_loader,
        num_classes: int,
        device: str = config.DEVICE
    ):
        self.model = model.to(device)
        self.model_name = model_name
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes

        # Loss function with label smoothing
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Optimizer - AdamW with weight decay
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY
        )

        # Learning rate scheduler: OneCycleLR ramps up to 10x the base LR over
        # the first 10% of steps, then cosine-anneals back down. It is stepped
        # once per batch in train_epoch, hence steps_per_epoch below.
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.LEARNING_RATE * 10,
            epochs=config.EPOCHS,
            steps_per_epoch=len(train_loader),
            pct_start=0.1,
            anneal_strategy='cos'
        )

        # Mixed precision training
        self.scaler = GradScaler()

        # Early stopping
        self.early_stopping = EarlyStopping(patience=config.EARLY_STOPPING_PATIENCE)

        # History
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }

        # Best model tracking
        self.best_val_acc = 0.0
        self.best_model_state = None
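
    # Standalone sketch of the schedule shape (toy values, assumed for
    # illustration only): with epochs=2 and steps_per_epoch=10 the schedule
    # has 20 steps in total, the first 10% of them ramping up to max_lr.
    #
    #   opt = optim.AdamW([torch.zeros(1, requires_grad=True)], lr=1e-3)
    #   sched = OneCycleLR(opt, max_lr=1e-2, epochs=2, steps_per_epoch=10,
    #                      pct_start=0.1)
    #   for _ in range(20):
    #       opt.step()
    #       sched.step()
    #   print(sched.get_last_lr())  # annealed back close to zero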

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training", leave=False)
        for images, labels in pbar:
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Mixed precision forward pass
            with autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

            # Backward pass with gradient scaling
            self.scaler.scale(loss).backward()

            # Gradient clipping (gradients must be unscaled first)
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()  # OneCycleLR steps per batch, not per epoch

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Validate the model"""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
            images = images.to(self.device)
            labels = labels.to(self.device)

            with autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    def train(self, epochs: int = config.EPOCHS) -> Dict:
        """Full training loop"""
        print(f"\n{'='*60}")
        print(f"Training {self.model_name}")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Epochs: {epochs}")
        print(f"Batch size: {config.BATCH_SIZE}")
        print(f"Learning rate: {config.LEARNING_RATE}")

        start_time = time.time()

        for epoch in range(epochs):
            print(f"\nEpoch [{epoch+1}/{epochs}]")

            # Train
            train_loss, train_acc = self.train_epoch()

            # Validate
            val_loss, val_acc = self.validate()

            # Get current learning rate
            current_lr = self.optimizer.param_groups[0]['lr']

            # Update history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(current_lr)

            print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
            print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
            print(f" LR: {current_lr:.6f}")

            # Save best model. Clone the tensors: state_dict() returns
            # references to the live parameters, so a plain dict copy would
            # keep tracking them as training continues.
            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_model_state = {
                    k: v.detach().clone() for k, v in self.model.state_dict().items()
                }
                print(" *** New best model! ***")

            # Early stopping
            if self.early_stopping(val_loss):
                print(f"\nEarly stopping triggered at epoch {epoch+1}")
                break

        training_time = time.time() - start_time

        # Load best model
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        print(f"\nTraining completed in {training_time/60:.2f} minutes")
        print(f"Best validation accuracy: {self.best_val_acc:.2f}%")

        # Add training time to history
        self.history['training_time'] = training_time
        self.history['best_val_acc'] = self.best_val_acc

        return self.history

    def save_model(self, path: Optional[Path] = None):
        """Save the trained model"""
        if path is None:
            path = config.MODELS_DIR / f"{self.model_name.lower().replace(' ', '_')}.pth"

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'num_classes': self.num_classes,
            'best_val_acc': self.best_val_acc,
            'history': self.history
        }, path)
        print(f"Model saved to {path}")
        return path
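
# Sketch of reloading a checkpoint written by save_model (kept as a comment;
# assumes `get_model` accepts pretrained=False for weight-only reloads, as
# suggested by its pretrained=True call below):
#
#   ckpt = torch.load(path, map_location=config.DEVICE)
#   model = get_model(ckpt['model_name'], ckpt['num_classes'], pretrained=False)
#   model.load_state_dict(ckpt['model_state_dict'])
#   model.to(config.DEVICE).eval()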


def train_all_models():
    """Train all 5 models and return results"""
    print("\n" + "="*70)
    print("TRAINING 5 MODELS FOR INDONESIAN HERBAL PLANTS CLASSIFICATION")
    print("="*70)

    # Create data loaders
    train_loader, val_loader, test_loader, class_names = create_data_loaders()
    num_classes = len(class_names)

    # Save class names
    with open(config.OUTPUT_DIR / "class_names.json", 'w') as f:
        json.dump(class_names, f, indent=2)

    results = {}

    for model_name in config.MODEL_NAMES:
        print(f"\n{'#'*70}")
        print(f"# Model: {model_name.upper()}")
        print(f"{'#'*70}")

        # Create model
        model = get_model(model_name, num_classes, pretrained=True)

        # Count parameters
        params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total parameters: {params:,}")
        print(f"Trainable parameters: {trainable_params:,}")

        # Create trainer
        trainer = Trainer(
            model=model,
            model_name=model_name,
            train_loader=train_loader,
            val_loader=val_loader,
            num_classes=num_classes
        )

        # Train
        history = trainer.train(epochs=config.EPOCHS)

        # Save model
        model_path = trainer.save_model()

        # Store results
        results[model_name] = {
            'history': history,
            'model_path': str(model_path),
            'params': params,
            'trainable_params': trainable_params
        }

    # Save a JSON-serializable summary of the results
    with open(config.OUTPUT_DIR / "training_results.json", 'w') as f:
        serializable_results = {}
        for name, data in results.items():
            serializable_results[name] = {
                'best_val_acc': data['history']['best_val_acc'],
                'training_time': data['history']['training_time'],
                'params': data['params'],
                'model_path': data['model_path']
            }
        json.dump(serializable_results, f, indent=2)

    return results, test_loader, class_names


if __name__ == "__main__":
    results, test_loader, class_names = train_all_models()
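
# test_loader is returned but not consumed above. A minimal evaluation sketch
# (an assumption about how it might be used, not part of this pipeline; it
# presumes a trained Trainer instance `trainer`, whose model already holds the
# best weights after train()):
#
#   model = trainer.model.eval()
#   correct = total = 0
#   with torch.no_grad():
#       for images, labels in test_loader:
#           preds = model(images.to(config.DEVICE)).argmax(1)
#           correct += (preds == labels.to(config.DEVICE)).sum().item()
#           total += labels.size(0)
#   print(f"Test accuracy: {100. * correct / total:.2f}%")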