| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from sklearn.metrics import accuracy_score, f1_score |
| from torch import nn |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from .augmentations import build_eval_transform, build_train_transform |
| from .config import load_config |
| from .data_discovery import prepare_data |
| from .dataset import EggImageDataset, create_balanced_sampler |
| from .dl_models import MODEL_REGISTRY, checkpoint_payload, create_model, freeze_backbone_except_head |
| from .paths import ensure_dir |
| from .reporting import plot_training_curves |
| from .seeds import set_seed |
| from .utils import get_logger, save_json |
|
|
|
|
# Module-wide logger, obtained from the project's ``get_logger`` helper using
# this module's ``__name__``; the functions below log through it.
LOGGER = get_logger(__name__)
|
|
|
|
def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame:
    """Return the split table, reusing the CSV on disk when it is present.

    Args:
        config: Project configuration; ``config["paths"]["split_csv"]`` is the
            location of the CSV file that is checked.

    Returns:
        The frame read from the CSV when that file exists, otherwise whatever
        ``prepare_data(config)`` returns.
    """
    cached_csv = Path(config["paths"]["split_csv"])
    if not cached_csv.exists():
        # Nothing on disk yet: hand over to the data-preparation step.
        return prepare_data(config)
    return pd.read_csv(cached_csv)
|
|
|
|
def make_loaders(
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> tuple[DataLoader, DataLoader]:
    """Create the training and validation loaders from the split table.

    Rows whose ``split`` column equals ``"train"`` / ``"val"`` feed the two
    datasets. A balanced sampler replaces shuffling for the training loader
    when balancing is enabled, exactly two ``label_id`` values are present and
    the larger class outnumbers the smaller one by more than
    ``config["data"]["imbalance_threshold"]`` (default 1.2).

    Args:
        splits_df: Table with at least ``split`` and ``label_id`` columns.
        config: Project configuration; reads ``training``, ``seed`` and the
            optional ``balance`` / ``data`` sections.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: If either the train or the validation split has no rows.
    """
    split_column = splits_df["split"]
    train_rows = splits_df[split_column == "train"].reset_index(drop=True)
    val_rows = splits_df[split_column == "val"].reset_index(drop=True)
    if train_rows.empty or val_rows.empty:
        raise ValueError("Deep learning training needs non-empty train and validation splits.")

    train_dataset = EggImageDataset(train_rows, transform=build_train_transform(config))
    val_dataset = EggImageDataset(val_rows, transform=build_eval_transform(config))

    # Default: plain shuffling. Both values are swapped together below because
    # the loader is given either a sampler or shuffle=True, never both.
    train_sampler = None
    shuffle_train = True
    class_counts = train_rows["label_id"].value_counts()
    balancing_enabled = config.get("balance", {}).get("enabled", True)
    if balancing_enabled and len(class_counts) == 2:
        # max(..., 1) keeps the division safe.
        imbalance = class_counts.max() / max(class_counts.min(), 1)
        threshold = float(config.get("data", {}).get("imbalance_threshold", 1.2))
        if imbalance > threshold:
            train_sampler = create_balanced_sampler(train_rows, int(config["seed"]))
            shuffle_train = False
            LOGGER.info("Using balanced sampler for deep learning train loader.")

    training_cfg = config["training"]
    batch_size = int(training_cfg["batch_size"])
    num_workers = int(training_cfg.get("num_workers", 0))
    # Pinned memory is only requested when CUDA is available.
    pin_memory = bool(training_cfg.get("pin_memory", True) and torch.cuda.is_available())
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }

    return (
        DataLoader(train_dataset, sampler=train_sampler, shuffle=shuffle_train, **loader_kwargs),
        DataLoader(val_dataset, shuffle=False, **loader_kwargs),
    )
|
|
|
|
def class_weight_tensor(train_df: pd.DataFrame, device: torch.device) -> torch.Tensor | None:
    """Compute inverse-frequency class weights for a two-class label column.

    Each weight is ``total / (num_classes * class_count)``, ordered by
    ascending ``label_id``.

    Args:
        train_df: Table with a ``label_id`` column.
        device: Device on which the resulting tensor is created.

    Returns:
        A float32 tensor with one weight per class, or ``None`` when the
        column does not contain exactly two distinct labels.
    """
    per_class = train_df["label_id"].value_counts().sort_index()
    num_classes = len(per_class)
    if num_classes != 2:
        return None
    balanced = per_class.sum() / (num_classes * per_class)
    as_float32 = balanced.to_numpy(dtype=np.float32)
    return torch.tensor(as_float32, device=device)
|
|
|
|
def epoch_pass(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer | None = None,
    scaler: torch.cuda.amp.GradScaler | None = None,
    use_amp: bool = False,
    max_grad_norm: float | None = None,
) -> dict[str, float]:
    """Run one pass over ``loader`` and return aggregate loss and metrics.

    The pass trains when ``optimizer`` is given and evaluates otherwise. The
    model's train/eval mode and the autograd mode are both set from that
    choice, so an evaluation pass never records a graph even if the caller
    forgot to wrap the call in ``torch.no_grad()``.

    Args:
        model: Network called as ``model(images)`` to produce logits.
        loader: Iterable of batches unpacked as ``(images, labels, extra)``;
            the third item is ignored.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the images and labels are moved to.
        optimizer: Optimizer to step; ``None`` selects evaluation.
        scaler: Gradient scaler, used only when ``use_amp`` is true.
        use_amp: Whether autocast (and the scaler, if given) is enabled.
        max_grad_norm: Gradient-norm clipping threshold; falsy disables it.

    Returns:
        ``{"loss", "accuracy", "f1"}``: the sample-weighted mean loss plus
        accuracy and F1 (``f1_score`` with its default averaging and
        ``zero_division=0``). Metrics are ``0.0`` when no sample was seen.
    """
    is_train = optimizer is not None
    model.train(is_train)
    weighted_losses: list[float] = []
    all_true: list[int] = []
    all_pred: list[int] = []
    iterator = tqdm(loader, desc="train" if is_train else "val", leave=False)
    # Own the autograd mode here instead of relying on the caller's context.
    with torch.set_grad_enabled(is_train):
        for images, labels, _ in iterator:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            if optimizer is not None:
                optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(images)
                loss = criterion(logits, labels)
            if optimizer is not None:
                if scaler is not None and use_amp:
                    scaler.scale(loss).backward()
                    if max_grad_norm:
                        # Gradients must be unscaled before their norm is
                        # measured, otherwise the threshold is compared
                        # against scaled values.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if max_grad_norm:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
            # Softmax is monotonic, so the argmax of the logits is the
            # predicted class; no need to materialize probabilities.
            pred = torch.argmax(logits.detach(), dim=1)
            # Weight each batch loss by its size so the final mean is per
            # sample even when the last batch is smaller.
            weighted_losses.append(float(loss.item()) * labels.size(0))
            all_true.extend(labels.detach().cpu().numpy().astype(int).tolist())
            all_pred.extend(pred.cpu().tolist())
    loss_mean = float(np.sum(weighted_losses) / max(len(all_true), 1))
    return {
        "loss": loss_mean,
        "accuracy": float(accuracy_score(all_true, all_pred)) if all_true else 0.0,
        "f1": float(f1_score(all_true, all_pred, zero_division=0)) if all_true else 0.0,
    }
|
|
|
|
def train_one_dl_model(model_key: str, splits_df: pd.DataFrame, config: dict[str, Any]) -> Path:
    """Train one model, keep its best checkpoint and write its history.

    Each epoch runs a training pass and a validation pass. The learning rate
    is halved by ``ReduceLROnPlateau`` on the validation loss. A checkpoint is
    written whenever the validation F1 improves, and training stops once
    ``early_stopping_patience`` consecutive epochs bring no improvement.
    Afterwards the per-epoch history is saved as CSV and JSON and passed to
    ``plot_training_curves``.

    Args:
        model_key: Key handed to ``create_model``; also used in output names.
        splits_df: Split table forwarded to ``make_loaders``.
        config: Project configuration; reads ``seed``, ``paths`` and
            ``training``.

    Returns:
        Path of the best checkpoint, ``<model_dir>/<model_key>.pt``.

    Raises:
        ValueError: If ``config["training"]["epochs"]`` is below 1. Without
            this check no checkpoint would be written and the returned path
            could point at a missing or stale file.
        RuntimeError: If the model has no trainable parameters.
    """
    train_cfg = config["training"]
    # Parse loop-invariant settings once, and reject a run that cannot produce
    # a checkpoint before any seeding, directory creation or model building.
    num_epochs = int(train_cfg["epochs"])
    if num_epochs < 1:
        raise ValueError(f"training.epochs must be >= 1 to train {model_key}; got {num_epochs}.")
    # A configured value of 0 (the default) means "no clipping".
    max_grad_norm = float(train_cfg.get("max_grad_norm", 0.0)) or None
    patience = int(train_cfg.get("early_stopping_patience", 2))

    set_seed(int(config["seed"]))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_dir = ensure_dir(config["paths"]["model_dir"])
    output_dir = ensure_dir(config["paths"]["output_dir"])

    train_loader, val_loader = make_loaders(splits_df, config)
    model = create_model(model_key, config).to(device)
    if bool(train_cfg.get("freeze_backbone", True)):
        if bool(getattr(model, "_egg_pretrained_loaded", False)):
            freeze_backbone_except_head(model)
        else:
            LOGGER.warning(
                "%s is not using pretrained weights; leaving the full model trainable instead of freezing random features.",
                model_key,
            )

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        raise RuntimeError(f"{model_key} has no trainable parameters.")
    optimizer_name = str(train_cfg.get("optimizer", "adamw")).lower()
    if optimizer_name not in {"adamw", "adam"}:
        # The fallback itself is unchanged; it is just no longer silent.
        LOGGER.warning(
            "Unknown optimizer %r for %s; falling back to Adam.",
            optimizer_name,
            model_key,
        )
    opt_cls = torch.optim.AdamW if optimizer_name == "adamw" else torch.optim.Adam
    optimizer = opt_cls(
        trainable_params,
        lr=float(train_cfg["learning_rate"]),
        weight_decay=float(train_cfg.get("weight_decay", 0.0)),
    )
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=int(train_cfg.get("scheduler_patience", 1)),
    )
    # Mixed precision is only used on CUDA devices.
    use_amp = bool(train_cfg.get("mixed_precision", True) and device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # -1.0 is below any F1 value, so the first epoch always writes a checkpoint.
    best_f1 = -1.0
    best_path = model_dir / f"{model_key}.pt"
    history: list[dict[str, Any]] = []
    stale_epochs = 0

    for epoch in range(1, num_epochs + 1):
        LOGGER.info("Training %s epoch %d/%d", model_key, epoch, num_epochs)
        train_stats = epoch_pass(
            model,
            train_loader,
            criterion,
            device,
            optimizer=optimizer,
            scaler=scaler,
            use_amp=use_amp,
            max_grad_norm=max_grad_norm,
        )
        with torch.no_grad():
            val_stats = epoch_pass(model, val_loader, criterion, device, use_amp=use_amp)
        scheduler.step(val_stats["loss"])
        row = {
            "epoch": epoch,
            "train_loss": train_stats["loss"],
            "val_loss": val_stats["loss"],
            "train_accuracy": train_stats["accuracy"],
            "val_accuracy": val_stats["accuracy"],
            "train_f1": train_stats["f1"],
            "val_f1": val_stats["f1"],
            # Read after scheduler.step, so this is the rate for the next epoch.
            "learning_rate": float(optimizer.param_groups[0]["lr"]),
        }
        history.append(row)
        LOGGER.info(
            "%s epoch %d: train_loss=%.4f val_loss=%.4f val_f1=%.4f",
            model_key,
            epoch,
            row["train_loss"],
            row["val_loss"],
            row["val_f1"],
        )
        if row["val_f1"] > best_f1:
            best_f1 = row["val_f1"]
            stale_epochs = 0
            payload = checkpoint_payload(model, model_key, config, history, best_f1)
            torch.save(payload, best_path)
            LOGGER.info("Saved best checkpoint for %s: %s", model_key, best_path)
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                LOGGER.info("Early stopping %s after %d stale epoch(s).", model_key, stale_epochs)
                break

    history_df = pd.DataFrame(history)
    history_path = output_dir / "histories" / f"{model_key}_history.csv"
    history_path.parent.mkdir(parents=True, exist_ok=True)
    history_df.to_csv(history_path, index=False)
    save_json(history, output_dir / "histories" / f"{model_key}_history.json")
    plot_training_curves(history_df, output_dir / "plots" / f"training_curves_{model_key}.png", model_key)
    return best_path
|
|
|
|
def train_dl_models(config: dict[str, Any], model_keys: list[str] | None = None) -> list[Path]:
    """Train every requested model and collect the checkpoint paths.

    Args:
        config: Project configuration; ``config["models"]["enabled"]`` maps
            model keys to on/off flags.
        model_keys: Explicit keys to train. When ``None`` or empty, every key
            in ``MODEL_REGISTRY`` that is enabled in the configuration is
            trained instead.

    Returns:
        One checkpoint path per trained model, in training order. Keys that
        are not in ``MODEL_REGISTRY`` are skipped with a warning.
    """
    splits_df = load_or_prepare_splits(config)
    enabled = config["models"]["enabled"]
    requested = model_keys or [key for key in MODEL_REGISTRY if enabled.get(key, False)]
    paths: list[Path] = []
    for key in requested:
        if key not in MODEL_REGISTRY:
            # Still skipped (not fatal), but a misspelled key is now visible
            # in the log instead of vanishing silently.
            LOGGER.warning("Skipping unknown deep learning model key: %s", key)
            continue
        paths.append(train_one_dl_model(key, splits_df, config))
    return paths
|
|
|
|
def main() -> None:
    """Parse the command line and train the selected deep learning models.

    Options:
        --config: Path handed to ``load_config`` (default
            ``configs/default.yaml``).
        --models: Zero or more keys from ``MODEL_REGISTRY``; when omitted,
            ``None`` is passed on to ``train_dl_models``.
    """
    cli = argparse.ArgumentParser(description="Train deep transfer-learning egg classifiers.")
    cli.add_argument("--config", default="configs/default.yaml")
    known_models = list(MODEL_REGISTRY)
    cli.add_argument("--models", nargs="*", default=None, choices=known_models)
    options = cli.parse_args()
    train_dl_models(load_config(options.config), options.models)
|
|
|
|
# Invoke the command-line entry point only when this module is run as the
# main program, not when it is imported.
if __name__ == "__main__":
    main()
|
|