# unionpoint's picture
# Upload folder using huggingface_hub
# 5d2fa0b verified
import argparse
import math
import os
import torch
import wandb
from omegaconf import OmegaConf
from timm.optim import create_optimizer_v2
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR
from src.dataset import get_dataloaders
from src.loss import get_criterion
from src.models import PlantDiseaseModel, get_param_groups
from src.trainer import Trainer
from src.utils import CosineAnnealingWarmupLR, load_config, set_seed
def build_optimizer(model, config):
    """Create the optimizer described by ``config.optimizer``.

    ``config.optimizer.name`` is matched case-insensitively:

    * ``"adamw"`` with ``layer_decay`` equal to 1 (the default): uses
      ``torch.optim.AdamW`` over the parameter groups returned by
      ``get_param_groups`` (called with ``backbone_lr``, ``head_lr`` and
      ``weight_decay``).
    * ``"adamw"`` with any other ``layer_decay``: uses timm's
      ``create_optimizer_v2`` with layer-wise LR decay. In this path the base
      learning rate is ``head_lr`` and ``backbone_lr`` is not used.
    * any other name: uses ``torch.optim.Adam`` over the ``get_param_groups``
      groups, and ``layer_decay`` is not applied. A warning is printed when
      the name is not ``"adam"`` or when a non-1 ``layer_decay`` is ignored.

    Args:
        model: The network whose parameters are optimized.
        config: Config exposing ``config.optimizer`` with ``name``,
            ``backbone_lr``, ``head_lr``, ``weight_decay`` and, optionally,
            ``layer_decay``.

    Returns:
        The constructed optimizer.
    """
    opt_cfg = config.optimizer
    name = opt_cfg.name.lower()

    # A missing key and an explicit null both mean "no layer-wise decay".
    # Without the None check, a null value would compare unequal to 1, route
    # to the timm path and silently drop the backbone/head learning-rate split.
    layer_decay = getattr(opt_cfg, "layer_decay", 1.0)
    if layer_decay is None:
        layer_decay = 1.0

    # Built up front (as before), even though the timm path ignores it, in
    # case get_param_groups has side effects on the model.
    param_groups = get_param_groups(
        model,
        base_lr=opt_cfg.backbone_lr,
        head_lr=opt_cfg.head_lr,
        weight_decay=opt_cfg.weight_decay,
    )

    if name == "adamw":
        if layer_decay == 1:
            return torch.optim.AdamW(param_groups)
        return create_optimizer_v2(
            model,
            opt="adamw",
            lr=opt_cfg.head_lr,
            layer_decay=layer_decay,
            weight_decay=opt_cfg.weight_decay,
        )

    # Everything that is not AdamW falls back to Adam. Say so instead of
    # silently swapping e.g. a requested SGD for Adam.
    if name != "adam":
        print(f"WARNING: Unknown optimizer '{opt_cfg.name}', falling back to Adam.")
    if layer_decay != 1:
        print("WARNING: layer_decay is only applied for AdamW; ignoring it for Adam.")
    return torch.optim.Adam(param_groups)
def build_scheduler(optimizer, config, len_loader):
    """Create the learning-rate scheduler described by ``config.scheduler``.

    ``config.scheduler.name`` is matched case-insensitively:

    * ``"cosine"``: ``CosineAnnealingLR`` with ``T_max=config.training.epochs``
      and ``eta_min=config.scheduler.min_lr``.
    * ``"step"``: ``StepLR``. ``step_size`` and ``gamma`` are read from the
      scheduler config when present, defaulting to 3 and 0.1.
    * ``"plateau"``: ``ReduceLROnPlateau`` in ``"max"`` mode. ``factor`` and
      ``patience`` are read from the scheduler config when present,
      defaulting to 0.1 and 3; the floor is ``min_lr``.
    * ``"cosine_warmup"``: ``CosineAnnealingWarmupLR`` whose warmup/total
      lengths are expressed in optimizer steps, i.e. epochs *
      ``len_loader`` / ``gradient_accumulation_steps``.
    * anything else: ``None`` (no scheduler).

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        config: Config exposing ``config.scheduler`` and ``config.training``.
        len_loader: Number of batches per training epoch.

    Returns:
        The scheduler instance, or ``None`` for an unrecognised name.
    """
    sched_cfg = config.scheduler
    name = sched_cfg.name.lower()

    def _option(key, default):
        # Optional scheduler hyper-parameter. A missing key and an explicit
        # null both fall back to the previously hard-coded default.
        value = getattr(sched_cfg, key, None)
        return default if value is None else value

    if name == "cosine":
        return CosineAnnealingLR(
            optimizer, T_max=config.training.epochs, eta_min=sched_cfg.min_lr
        )
    if name == "step":
        return StepLR(
            optimizer,
            step_size=_option("step_size", 3),
            gamma=_option("gamma", 0.1),
        )
    if name == "plateau":
        return ReduceLROnPlateau(
            optimizer,
            mode="max",
            factor=_option("factor", 0.1),
            patience=_option("patience", 3),
            min_lr=sched_cfg.min_lr,
        )
    if name == "cosine_warmup":
        accum = config.training.gradient_accumulation_steps
        # NOTE(review): true division yields float step counts when len_loader
        # is not a multiple of the accumulation steps; presumably
        # CosineAnnealingWarmupLR tolerates that. TODO confirm.
        return CosineAnnealingWarmupLR(
            optimizer,
            warmup_steps=sched_cfg.warmup_epochs * len_loader / accum,
            total_steps=config.training.epochs * len_loader / accum,
            min_lr=sched_cfg.min_lr,
        )
    return None
def _parse_args():
    """Parse the training CLI flags (--config, --resume, --init_weights)."""
    parser = argparse.ArgumentParser(
        description="Train Plant Disease Classification Baseline"
    )
    parser.add_argument(
        "--config", type=str, default="configs/config.yaml", help="Path to config file"
    )
    parser.add_argument(
        "--resume", type=str, default=None, help="Path to checkpoint to resume from"
    )
    parser.add_argument(
        "--init_weights", type=str, default=None, help="Path to weights for warm start"
    )
    return parser.parse_args()


def _existing_path(path, flag):
    """Return *path* if it exists on disk; otherwise warn (when one was given) and return None."""
    if path and not os.path.exists(path):
        # Previously this was ignored silently, so a mistyped path quietly
        # restarted training from scratch.
        print(f"WARNING: {flag} path does not exist, ignoring it: {path}")
        return None
    return path or None


def _restore_rng_states(rng_states, device):
    """Restore the torch CPU RNG state and, on CUDA, the state of every CUDA device."""
    torch.set_rng_state(rng_states["torch"].cpu())
    cuda_states = rng_states.get("cuda")
    if device.type == "cuda" and cuda_states is not None:
        torch.cuda.set_rng_state_all([s.cpu() for s in cuda_states])


def _restore_trainer_state(trainer, checkpoint):
    """Load EMA weights, the trainer's ``scaler`` state and early-stopping counters from *checkpoint*.

    Each entry is optional; missing or empty entries are skipped.
    """
    if trainer.use_ema and checkpoint.get("state_dict_ema"):
        trainer.model_ema.module.load_state_dict(checkpoint["state_dict_ema"])
    if checkpoint.get("scaler"):
        trainer.scaler.load_state_dict(checkpoint["scaler"])
    early_stopping = checkpoint.get("early_stopping")
    if early_stopping:
        trainer.early_stopping.best_score = early_stopping["best_score"]
        trainer.early_stopping.counter = early_stopping["counter"]
        trainer.early_stopping.early_stop = early_stopping["early_stop"]


def main():
    """Build data, model, optimizer and scheduler from the config, then train.

    Command-line flags:
        --config        Config file passed to ``load_config``
                        (default ``configs/config.yaml``).
        --init_weights  Optional warm-start file. Its ``"state_dict"`` entry
                        (or the loaded object itself, if it has no such key)
                        is loaded into the model before the optimizer is built.
        --resume        Optional checkpoint to resume from. It restores model,
                        optimizer, scheduler, RNG, EMA, scaler and
                        early-stopping state, plus the W&B run id when W&B
                        logging is enabled.

    A path given to ``--init_weights`` or ``--resume`` that does not exist is
    reported and ignored. If the resumed epoch is already past
    ``config.training.epochs``, the function returns without training.
    """
    args = _parse_args()
    config = load_config(args.config)
    set_seed(config.seed, deterministic=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Environment: Using device {device}")

    train_loader, val_loader, num_classes = get_dataloaders(config)
    if num_classes == 0:
        print(
            "WARNING: No data found. Make sure your datasets are correctly structured."
        )
        # Fallback to prevent immediate crash if no data is present yet
        num_classes = 1
    config.model.num_classes = num_classes
    model = PlantDiseaseModel(config, num_classes=num_classes)
    model.to(device)

    init_weights = _existing_path(args.init_weights, "--init_weights")
    if init_weights:
        print(f"Warm starting from weights: {init_weights}")
        # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
        init_ckpt = torch.load(init_weights, map_location=device)
        model.load_state_dict(init_ckpt.get("state_dict", init_ckpt))

    optimizer = build_optimizer(model, config)
    criterion = get_criterion(config)
    scheduler = build_scheduler(optimizer, config, len(train_loader))

    # Resume logic
    start_epoch = 1
    checkpoint = None
    run_id = None
    resume_path = _existing_path(args.resume, "--resume")
    if resume_path:
        print(f"Resuming experiment from checkpoint: {resume_path}")
        # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None and checkpoint.get("scheduler"):
            scheduler.load_state_dict(checkpoint["scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        if "rng_states" in checkpoint:
            _restore_rng_states(checkpoint["rng_states"], device)
        if config.logging.use_wandb:
            run_id = checkpoint.get("wandb_run_id")
        if start_epoch > config.training.epochs:
            print(
                f"Requested to resume at epoch {start_epoch}, but total epochs is {config.training.epochs}. Exiting."
            )
            return

    # Wandb tracking
    if config.logging.use_wandb:
        wandb_config = OmegaConf.to_container(config, resolve=True)
        wandb.init(
            project=config.logging.project_name,
            name=config.experiment_name,
            config=wandb_config,
            id=run_id,  # Use the loaded ID (or None if brand new)
            resume="allow",
        )

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
    )
    if checkpoint is not None:
        _restore_trainer_state(trainer, checkpoint)
    # NOTE(review): start_epoch is only used for the exit check above and is
    # never passed to Trainer or fit(). Confirm the trainer resumes its epoch
    # counter some other way, otherwise a resumed run restarts at epoch 1.
    trainer.fit()
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()