from __future__ import annotations import argparse from pathlib import Path from typing import Any import numpy as np import pandas as pd import torch from sklearn.metrics import accuracy_score, f1_score from torch import nn from torch.utils.data import DataLoader from tqdm import tqdm from .augmentations import build_eval_transform, build_train_transform from .config import load_config from .data_discovery import prepare_data from .dataset import EggImageDataset, create_balanced_sampler from .dl_models import MODEL_REGISTRY, checkpoint_payload, create_model, freeze_backbone_except_head from .paths import ensure_dir from .reporting import plot_training_curves from .seeds import set_seed from .utils import get_logger, save_json LOGGER = get_logger(__name__) def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame: split_csv = Path(config["paths"]["split_csv"]) if split_csv.exists(): return pd.read_csv(split_csv) return prepare_data(config) def make_loaders( splits_df: pd.DataFrame, config: dict[str, Any], ) -> tuple[DataLoader, DataLoader]: train_df = splits_df[splits_df["split"] == "train"].reset_index(drop=True) val_df = splits_df[splits_df["split"] == "val"].reset_index(drop=True) if train_df.empty or val_df.empty: raise ValueError("Deep learning training needs non-empty train and validation splits.") train_ds = EggImageDataset(train_df, transform=build_train_transform(config)) val_ds = EggImageDataset(val_df, transform=build_eval_transform(config)) sampler = None shuffle = True counts = train_df["label_id"].value_counts() if config.get("balance", {}).get("enabled", True) and len(counts) == 2: ratio = counts.max() / max(counts.min(), 1) if ratio > float(config.get("data", {}).get("imbalance_threshold", 1.2)): sampler = create_balanced_sampler(train_df, int(config["seed"])) shuffle = False LOGGER.info("Using balanced sampler for deep learning train loader.") train_cfg = config["training"] common = { "batch_size": int(train_cfg["batch_size"]), "num_workers": int(train_cfg.get("num_workers", 0)), "pin_memory": bool(train_cfg.get("pin_memory", True) and torch.cuda.is_available()), } train_loader = DataLoader(train_ds, sampler=sampler, shuffle=shuffle, **common) val_loader = DataLoader(val_ds, shuffle=False, **common) return train_loader, val_loader def class_weight_tensor(train_df: pd.DataFrame, device: torch.device) -> torch.Tensor | None: counts = train_df["label_id"].value_counts().sort_index() if len(counts) != 2: return None total = counts.sum() weights = total / (len(counts) * counts) return torch.tensor(weights.to_numpy(dtype=np.float32), device=device) def epoch_pass( model: nn.Module, loader: DataLoader, criterion: nn.Module, device: torch.device, optimizer: torch.optim.Optimizer | None = None, scaler: torch.cuda.amp.GradScaler | None = None, use_amp: bool = False, max_grad_norm: float | None = None, ) -> dict[str, float]: is_train = optimizer is not None model.train(is_train) losses: list[float] = [] all_true: list[int] = [] all_pred: list[int] = [] iterator = tqdm(loader, desc="train" if is_train else "val", leave=False) for images, labels, _ in iterator: images = images.to(device, non_blocking=True) labels = labels.to(device, non_blocking=True) if is_train: optimizer.zero_grad(set_to_none=True) with torch.cuda.amp.autocast(enabled=use_amp): logits = model(images) loss = criterion(logits, labels) if is_train: assert optimizer is not None if scaler is not None and use_amp: scaler.scale(loss).backward() if max_grad_norm: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) scaler.step(optimizer) scaler.update() else: loss.backward() if max_grad_norm: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() probs = torch.softmax(logits.detach(), dim=1) pred = torch.argmax(probs, dim=1) losses.append(float(loss.detach().cpu().item()) * labels.size(0)) all_true.extend(labels.detach().cpu().numpy().astype(int).tolist()) all_pred.extend(pred.detach().cpu().numpy().astype(int).tolist()) loss_mean = float(np.sum(losses) / max(len(all_true), 1)) return { "loss": loss_mean, "accuracy": float(accuracy_score(all_true, all_pred)) if all_true else 0.0, "f1": float(f1_score(all_true, all_pred, zero_division=0)) if all_true else 0.0, } def train_one_dl_model(model_key: str, splits_df: pd.DataFrame, config: dict[str, Any]) -> Path: set_seed(int(config["seed"])) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_dir = ensure_dir(config["paths"]["model_dir"]) output_dir = ensure_dir(config["paths"]["output_dir"]) train_loader, val_loader = make_loaders(splits_df, config) train_df = splits_df[splits_df["split"] == "train"].reset_index(drop=True) model = create_model(model_key, config).to(device) if bool(config["training"].get("freeze_backbone", True)): if bool(getattr(model, "_egg_pretrained_loaded", False)): freeze_backbone_except_head(model) else: LOGGER.warning( "%s is not using pretrained weights; leaving the full model trainable instead of freezing random features.", model_key, ) trainable_params = [p for p in model.parameters() if p.requires_grad] if not trainable_params: raise RuntimeError(f"{model_key} has no trainable parameters.") optimizer_name = str(config["training"].get("optimizer", "adamw")).lower() opt_cls = torch.optim.AdamW if optimizer_name == "adamw" else torch.optim.Adam optimizer = opt_cls( trainable_params, lr=float(config["training"]["learning_rate"]), weight_decay=float(config["training"].get("weight_decay", 0.0)), ) criterion = nn.CrossEntropyLoss() scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=0.5, patience=int(config["training"].get("scheduler_patience", 1)), ) use_amp = bool(config["training"].get("mixed_precision", True) and device.type == "cuda") scaler = torch.cuda.amp.GradScaler(enabled=use_amp) best_f1 = -1.0 best_path = model_dir / f"{model_key}.pt" history: list[dict[str, Any]] = [] patience = int(config["training"].get("early_stopping_patience", 2)) stale_epochs = 0 for epoch in range(1, int(config["training"]["epochs"]) + 1): LOGGER.info("Training %s epoch %d/%d", model_key, epoch, int(config["training"]["epochs"])) train_stats = epoch_pass( model, train_loader, criterion, device, optimizer=optimizer, scaler=scaler, use_amp=use_amp, max_grad_norm=float(config["training"].get("max_grad_norm", 0.0)) or None, ) with torch.no_grad(): val_stats = epoch_pass(model, val_loader, criterion, device, use_amp=use_amp) scheduler.step(val_stats["loss"]) row = { "epoch": epoch, "train_loss": train_stats["loss"], "val_loss": val_stats["loss"], "train_accuracy": train_stats["accuracy"], "val_accuracy": val_stats["accuracy"], "train_f1": train_stats["f1"], "val_f1": val_stats["f1"], "learning_rate": float(optimizer.param_groups[0]["lr"]), } history.append(row) LOGGER.info( "%s epoch %d: train_loss=%.4f val_loss=%.4f val_f1=%.4f", model_key, epoch, row["train_loss"], row["val_loss"], row["val_f1"], ) if row["val_f1"] > best_f1: best_f1 = row["val_f1"] stale_epochs = 0 payload = checkpoint_payload(model, model_key, config, history, best_f1) torch.save(payload, best_path) LOGGER.info("Saved best checkpoint for %s: %s", model_key, best_path) else: stale_epochs += 1 if stale_epochs >= patience: LOGGER.info("Early stopping %s after %d stale epoch(s).", model_key, stale_epochs) break history_df = pd.DataFrame(history) history_path = output_dir / "histories" / f"{model_key}_history.csv" history_path.parent.mkdir(parents=True, exist_ok=True) history_df.to_csv(history_path, index=False) save_json(history, output_dir / "histories" / f"{model_key}_history.json") plot_training_curves(history_df, output_dir / "plots" / f"training_curves_{model_key}.png", model_key) return best_path def train_dl_models(config: dict[str, Any], model_keys: list[str] | None = None) -> list[Path]: splits_df = load_or_prepare_splits(config) enabled = config["models"]["enabled"] requested = model_keys or [key for key in MODEL_REGISTRY if enabled.get(key, False)] paths: list[Path] = [] for key in requested: if key not in MODEL_REGISTRY: continue paths.append(train_one_dl_model(key, splits_df, config)) return paths def main() -> None: parser = argparse.ArgumentParser(description="Train deep transfer-learning egg classifiers.") parser.add_argument("--config", default="configs/default.yaml") parser.add_argument("--models", nargs="*", default=None, choices=list(MODEL_REGISTRY)) args = parser.parse_args() config = load_config(args.config) train_dl_models(config, args.models) if __name__ == "__main__": main()