| import argparse |
| import math |
| import os |
|
|
| import torch |
| import wandb |
| from omegaconf import OmegaConf |
| from timm.optim import create_optimizer_v2 |
| from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR |
|
|
| from src.dataset import get_dataloaders |
| from src.loss import get_criterion |
| from src.models import PlantDiseaseModel, get_param_groups |
| from src.trainer import Trainer |
| from src.utils import CosineAnnealingWarmupLR, load_config, set_seed |
|
|
|
|
def build_optimizer(model, config):
    """Create the optimizer described by ``config.optimizer``.

    Reads ``name``, ``backbone_lr``, ``head_lr``, ``weight_decay`` and the
    optional ``layer_decay`` from ``config.optimizer``.

    * ``adamw`` with ``layer_decay == 1`` (or unset/null): ``torch.optim.AdamW``
      over the groups returned by ``get_param_groups``.
    * ``adamw`` with any other ``layer_decay``: timm's ``create_optimizer_v2``
      with ``lr=head_lr``; ``backbone_lr`` is not used on this path.
    * any other name: ``torch.optim.Adam`` over the ``get_param_groups`` groups.

    Args:
        model: The model whose parameters are optimized.
        config: Config object exposing an ``optimizer`` section.

    Returns:
        The constructed optimizer instance.
    """
    opt_cfg = config.optimizer

    # A missing *or* null ``layer_decay`` means "no layer-wise decay". Without
    # the explicit None check, ``layer_decay: null`` would compare unequal to 1
    # and be routed to the timm branch, dropping the backbone/head LR split.
    layer_decay = getattr(opt_cfg, "layer_decay", None)
    if layer_decay is None:
        layer_decay = 1.0

    # Built up front on every path, exactly as before; whether
    # ``get_param_groups`` has side effects is not visible from this file.
    param_groups = get_param_groups(
        model,
        base_lr=opt_cfg.backbone_lr,
        head_lr=opt_cfg.head_lr,
        weight_decay=opt_cfg.weight_decay,
    )

    name = opt_cfg.name.lower()
    if name == "adamw":
        if layer_decay == 1:
            return torch.optim.AdamW(param_groups)
        # timm derives its own per-layer groups from the model.
        return create_optimizer_v2(
            model,
            opt="adamw",
            lr=opt_cfg.head_lr,
            layer_decay=layer_decay,
            weight_decay=opt_cfg.weight_decay,
        )

    if name != "adam":
        # Keep the historical fallback, but make it visible so a typo in the
        # config does not silently change the optimizer.
        print(
            f"WARNING: Unknown optimizer '{opt_cfg.name}', falling back to Adam."
        )
    return torch.optim.Adam(param_groups)
|
|
|
|
def _scheduler_option(scheduler_cfg, key, default):
    """Return ``scheduler_cfg.<key>``, or *default* when it is missing or null."""
    value = getattr(scheduler_cfg, key, None)
    return default if value is None else value


def build_scheduler(optimizer, config, len_loader):
    """Create the LR scheduler selected by ``config.scheduler.name``.

    Supported names (case-insensitive):

    * ``cosine``: ``CosineAnnealingLR`` with ``T_max=config.training.epochs``
      and ``eta_min=config.scheduler.min_lr``.
    * ``step``: ``StepLR``; ``step_size`` (default 3) and ``gamma``
      (default 0.1) may be overridden in ``config.scheduler``.
    * ``plateau``: ``ReduceLROnPlateau`` in ``max`` mode; ``factor``
      (default 0.1) and ``patience`` (default 3) may be overridden.
    * ``cosine_warmup``: ``CosineAnnealingWarmupLR`` with warmup and total
      lengths expressed in optimizer steps.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        config: Config object exposing ``scheduler`` and ``training`` sections.
        len_loader: Number of batches per epoch in the training loader.

    Returns:
        The scheduler instance, or ``None`` for any other name.
    """
    sched_cfg = config.scheduler
    name = sched_cfg.name.lower()

    if name == "cosine":
        return CosineAnnealingLR(
            optimizer, T_max=config.training.epochs, eta_min=sched_cfg.min_lr
        )

    if name == "step":
        return StepLR(
            optimizer,
            step_size=_scheduler_option(sched_cfg, "step_size", 3),
            gamma=_scheduler_option(sched_cfg, "gamma", 0.1),
        )

    if name == "plateau":
        return ReduceLROnPlateau(
            optimizer,
            mode="max",
            factor=_scheduler_option(sched_cfg, "factor", 0.1),
            patience=_scheduler_option(sched_cfg, "patience", 3),
            min_lr=sched_cfg.min_lr,
        )

    if name == "cosine_warmup":
        # Step counts must be whole numbers; true division produced floats.
        # Values below 1 are treated as "no accumulation" rather than raising
        # ZeroDivisionError.
        accum_steps = max(1, int(config.training.gradient_accumulation_steps))
        # NOTE(review): assumes the trainer also steps the optimizer on a
        # trailing partial accumulation window each epoch -- TODO confirm.
        # Rounding up is the safe side: the schedule is never shorter than the
        # number of steps actually taken.
        steps_per_epoch = math.ceil(len_loader / accum_steps)
        return CosineAnnealingWarmupLR(
            optimizer,
            # round() keeps fractional warmup_epochs usable and returns an int.
            warmup_steps=round(sched_cfg.warmup_epochs * steps_per_epoch),
            total_steps=config.training.epochs * steps_per_epoch,
            min_lr=sched_cfg.min_lr,
        )

    # Unknown or disabled scheduler: the caller receives None.
    return None
|
|
|
|
def _restore_trainer_state(trainer, checkpoint):
    """Copy EMA weights, scaler state and early-stopping state into *trainer*.

    Every checkpoint entry is optional; a missing or empty entry is skipped.
    """
    ema_state = checkpoint.get("state_dict_ema")
    if trainer.use_ema and ema_state:
        trainer.model_ema.module.load_state_dict(ema_state)

    scaler_state = checkpoint.get("scaler")
    # Guarded in case the trainer has no scaler object -- Trainer is defined
    # outside this file, so this cannot be confirmed here.
    scaler = getattr(trainer, "scaler", None)
    if scaler_state and scaler is not None:
        scaler.load_state_dict(scaler_state)

    stopper_state = checkpoint.get("early_stopping")
    if stopper_state:
        stopper = trainer.early_stopping
        stopper.best_score = stopper_state["best_score"]
        stopper.counter = stopper_state["counter"]
        stopper.early_stop = stopper_state["early_stop"]


def main():
    """Parse CLI arguments, build the training components and run training.

    Command-line options:
        --config: Path to the config file (default ``configs/config.yaml``).
        --resume: Checkpoint to resume from; restores model, optimizer,
            scheduler, RNG, EMA, scaler and early-stopping state when present.
        --init_weights: Weights file used to warm-start the model only.

    A path given for ``--resume`` or ``--init_weights`` that does not exist is
    reported with a warning and otherwise ignored.
    """
    parser = argparse.ArgumentParser(
        description="Train Plant Disease Classification Baseline"
    )
    parser.add_argument(
        "--config", type=str, default="configs/config.yaml", help="Path to config file"
    )
    parser.add_argument(
        "--resume", type=str, default=None, help="Path to checkpoint to resume from"
    )
    parser.add_argument(
        "--init_weights", type=str, default=None, help="Path to weights for warm start"
    )
    args = parser.parse_args()

    config = load_config(args.config)
    set_seed(config.seed, deterministic=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Environment: Using device {device}")

    train_loader, val_loader, num_classes = get_dataloaders(config)
    if num_classes == 0:
        print(
            "WARNING: No data found. Make sure your datasets are correctly structured."
        )
        # Continue with a single class instead of zero.
        num_classes = 1
    config.model.num_classes = num_classes

    model = PlantDiseaseModel(config, num_classes=num_classes)
    model.to(device)

    if args.init_weights:
        if os.path.exists(args.init_weights):
            print(f"Warm starting from weights: {args.init_weights}")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            warm_start = torch.load(args.init_weights, map_location=device)
            # Accept either a full checkpoint dict or a bare state_dict.
            model.load_state_dict(warm_start.get("state_dict", warm_start))
        else:
            print(
                f"WARNING: --init_weights path not found, skipping warm start: {args.init_weights}"
            )

    optimizer = build_optimizer(model, config)
    criterion = get_criterion(config)
    scheduler = build_scheduler(optimizer, config, len(train_loader))

    start_epoch = 1
    checkpoint = None
    run_id = None
    if args.resume:
        if os.path.exists(args.resume):
            print(f"Resuming experiment from checkpoint: {args.resume}")
            # NOTE(security): same pickle caveat as the warm-start load above.
            checkpoint = torch.load(args.resume, map_location=device)
        else:
            print(
                f"WARNING: --resume path not found, starting from scratch: {args.resume}"
            )

    if checkpoint is not None:
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler_state = checkpoint.get("scheduler")
        if scheduler is not None and scheduler_state:
            scheduler.load_state_dict(scheduler_state)
        # NOTE(review): start_epoch only feeds the "already finished" check
        # below; it is not passed to Trainer in this file. Confirm how the
        # trainer learns which epoch to continue from.
        start_epoch = checkpoint["epoch"] + 1

        rng_states = checkpoint.get("rng_states")
        if rng_states:
            torch.set_rng_state(rng_states["torch"].cpu())
            cuda_states = rng_states.get("cuda")
            if device.type == "cuda" and cuda_states is not None:
                torch.cuda.set_rng_state_all([s.cpu() for s in cuda_states])

        if config.logging.use_wandb:
            # Reuse the stored run id so logging continues in the same run.
            run_id = checkpoint.get("wandb_run_id")

        if start_epoch > config.training.epochs:
            print(
                f"Requested to resume at epoch {start_epoch}, but total epochs is {config.training.epochs}. Exiting."
            )
            return

    if config.logging.use_wandb:
        wandb_config = OmegaConf.to_container(config, resolve=True)
        wandb.init(
            project=config.logging.project_name,
            name=config.experiment_name,
            config=wandb_config,
            id=run_id,
            resume="allow",
        )

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )

    # Keyed on the loaded checkpoint rather than a second filesystem check, so
    # all restore steps agree on whether a resume actually happened.
    if checkpoint is not None:
        _restore_trainer_state(trainer, checkpoint)

    trainer.fit()
|
|
|
|
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|