# Uploaded by budijuarto as src/egg_damage/train_dl.py (revision 5a7bdcb, verified).
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from .augmentations import build_eval_transform, build_train_transform
from .config import load_config
from .data_discovery import prepare_data
from .dataset import EggImageDataset, create_balanced_sampler
from .dl_models import MODEL_REGISTRY, checkpoint_payload, create_model, freeze_backbone_except_head
from .paths import ensure_dir
from .reporting import plot_training_curves
from .seeds import set_seed
from .utils import get_logger, save_json
# Module-level logger, obtained from the project's ``get_logger`` helper and keyed by this module's name.
LOGGER = get_logger(__name__)
def load_or_prepare_splits(config: dict[str, Any]) -> pd.DataFrame:
    """Return the split table, preferring the cached CSV when it exists.

    The CSV location is read from ``config["paths"]["split_csv"]``. When no
    file is present at that path the table is produced by calling
    ``prepare_data(config)`` instead.
    """
    cached_csv = Path(config["paths"]["split_csv"])
    if not cached_csv.exists():
        return prepare_data(config)
    return pd.read_csv(cached_csv)
def make_loaders(
    splits_df: pd.DataFrame,
    config: dict[str, Any],
) -> tuple[DataLoader, DataLoader]:
    """Build the ``(train_loader, val_loader)`` pair from the split table.

    Rows are selected by the ``split`` column (``"train"`` / ``"val"``).
    The training loader shuffles by default; when balancing is enabled,
    exactly two labels are present, and the majority/minority count ratio
    exceeds ``config["data"]["imbalance_threshold"]`` (default 1.2), a
    balanced sampler replaces shuffling.

    Raises:
        ValueError: if either the train or the validation split is empty.
    """
    split_col = splits_df["split"]
    train_df = splits_df[split_col == "train"].reset_index(drop=True)
    val_df = splits_df[split_col == "val"].reset_index(drop=True)
    if train_df.empty or val_df.empty:
        raise ValueError("Deep learning training needs non-empty train and validation splits.")

    train_ds = EggImageDataset(train_df, transform=build_train_transform(config))
    val_ds = EggImageDataset(val_df, transform=build_eval_transform(config))

    # Default: plain shuffling, no sampler. Both are overridden together below
    # because DataLoader does not accept a sampler combined with shuffle=True.
    train_sampler = None
    shuffle_train = True
    label_counts = train_df["label_id"].value_counts()
    balancing_on = config.get("balance", {}).get("enabled", True)
    if balancing_on and len(label_counts) == 2:
        imbalance = label_counts.max() / max(label_counts.min(), 1)
        threshold = float(config.get("data", {}).get("imbalance_threshold", 1.2))
        if imbalance > threshold:
            train_sampler = create_balanced_sampler(train_df, int(config["seed"]))
            shuffle_train = False
            LOGGER.info("Using balanced sampler for deep learning train loader.")

    train_cfg = config["training"]
    wants_pinning = train_cfg.get("pin_memory", True)
    loader_kwargs = dict(
        batch_size=int(train_cfg["batch_size"]),
        num_workers=int(train_cfg.get("num_workers", 0)),
        # Pinned memory is only requested when CUDA is actually available.
        pin_memory=bool(wants_pinning and torch.cuda.is_available()),
    )
    train_loader = DataLoader(
        train_ds,
        sampler=train_sampler,
        shuffle=shuffle_train,
        **loader_kwargs,
    )
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
    return train_loader, val_loader
def class_weight_tensor(train_df: pd.DataFrame, device: torch.device) -> torch.Tensor | None:
    """Compute inverse-frequency class weights for a two-label training set.

    Labels are read from the ``label_id`` column and ordered by label value.
    Each weight is ``total / (n_classes * class_count)``, returned as a
    float32 tensor on ``device``. Returns ``None`` unless exactly two
    distinct labels are present.
    """
    per_class = train_df["label_id"].value_counts().sort_index()
    n_classes = len(per_class)
    if n_classes != 2:
        return None
    balanced = per_class.sum() / (n_classes * per_class)
    as_float32 = balanced.to_numpy(dtype=np.float32)
    return torch.tensor(as_float32, device=device)
def epoch_pass(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    optimizer: torch.optim.Optimizer | None = None,
    scaler: torch.cuda.amp.GradScaler | None = None,
    use_amp: bool = False,
    max_grad_norm: float | None = None,
) -> dict[str, float]:
    """Run one pass over ``loader`` and return aggregate metrics.

    The pass is a training pass when ``optimizer`` is given and an
    evaluation pass otherwise. Autograd is switched on or off inside this
    function accordingly, so evaluation does not build a graph even when the
    caller forgets to wrap the call in ``torch.no_grad()``.

    Args:
        model: Network to run; put into train or eval mode to match the pass.
        loader: Yields ``(images, labels, extra)`` triples; the third item
            is ignored.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.
        optimizer: When provided, gradients are computed and a step is taken
            for every batch.
        scaler: Gradient scaler, used only when ``use_amp`` is also true.
        use_amp: Enables CUDA autocast for the forward pass.
        max_grad_norm: When truthy, gradients are clipped to this norm before
            the optimizer step.

    Returns:
        A dict with ``loss`` (sample-weighted mean), ``accuracy`` and ``f1``.
        ``f1`` uses scikit-learn's default averaging with ``zero_division=0``
        (presumably binary labels - confirm against the dataset). All three
        are ``0.0`` when the loader yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)
    losses: list[float] = []
    all_true: list[int] = []
    all_pred: list[int] = []
    iterator = tqdm(loader, desc="train" if is_train else "val", leave=False)
    for images, labels, _ in iterator:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        if optimizer is not None:
            optimizer.zero_grad(set_to_none=True)
        # Grad mode follows the pass type so an eval pass never retains a graph.
        with torch.set_grad_enabled(is_train), torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)
        if optimizer is not None:
            if scaler is not None and use_amp:
                scaler.scale(loss).backward()
                if max_grad_norm:
                    # Gradients must be unscaled before clipping so the
                    # threshold applies to their true magnitude.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if max_grad_norm:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
        probs = torch.softmax(logits.detach(), dim=1)
        pred = torch.argmax(probs, dim=1)
        # Weight each batch loss by its size so the epoch mean is per-sample.
        losses.append(float(loss.detach().cpu().item()) * labels.size(0))
        all_true.extend(labels.detach().cpu().numpy().astype(int).tolist())
        all_pred.extend(pred.detach().cpu().numpy().astype(int).tolist())
    loss_mean = float(np.sum(losses) / max(len(all_true), 1))
    return {
        "loss": loss_mean,
        "accuracy": float(accuracy_score(all_true, all_pred)) if all_true else 0.0,
        "f1": float(f1_score(all_true, all_pred, zero_division=0)) if all_true else 0.0,
    }
def train_one_dl_model(model_key: str, splits_df: pd.DataFrame, config: dict[str, Any]) -> Path:
    """Train a single registered model and return its best-checkpoint path.

    Seeds the run, builds the loaders and the model, optionally freezes the
    backbone, then trains with ``ReduceLROnPlateau`` on validation loss and
    early stopping on validation F1. A checkpoint is written every time
    validation F1 improves. After training, the per-epoch history is saved as
    CSV and JSON and a training-curve plot is produced.

    Args:
        model_key: Key passed to ``create_model`` and used in output file names.
        splits_df: Split table handed to ``make_loaders``.
        config: Configuration mapping; reads ``seed``, ``paths.model_dir``,
            ``paths.output_dir`` and the ``training`` section.

    Returns:
        ``<model_dir>/<model_key>.pt``. The file is only written once an
        epoch has completed with an improved validation F1.

    Raises:
        RuntimeError: if the model ends up with no trainable parameters.
    """
    set_seed(int(config["seed"]))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_dir = ensure_dir(config["paths"]["model_dir"])
    output_dir = ensure_dir(config["paths"]["output_dir"])
    train_loader, val_loader = make_loaders(splits_df, config)
    train_cfg = config["training"]

    model = create_model(model_key, config).to(device)
    if bool(train_cfg.get("freeze_backbone", True)):
        # Freezing only makes sense when the backbone carries pretrained weights.
        if bool(getattr(model, "_egg_pretrained_loaded", False)):
            freeze_backbone_except_head(model)
        else:
            LOGGER.warning(
                "%s is not using pretrained weights; leaving the full model trainable instead of freezing random features.",
                model_key,
            )
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        raise RuntimeError(f"{model_key} has no trainable parameters.")

    optimizer_name = str(train_cfg.get("optimizer", "adamw")).lower()
    if optimizer_name not in ("adamw", "adam"):
        # Keep the historical fallback to Adam, but make it visible.
        LOGGER.warning(
            "Unknown optimizer %r for %s; falling back to Adam.",
            optimizer_name,
            model_key,
        )
    opt_cls = torch.optim.AdamW if optimizer_name == "adamw" else torch.optim.Adam
    optimizer = opt_cls(
        trainable_params,
        lr=float(train_cfg["learning_rate"]),
        weight_decay=float(train_cfg.get("weight_decay", 0.0)),
    )
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=int(train_cfg.get("scheduler_patience", 1)),
    )
    # Mixed precision is only used on CUDA devices.
    use_amp = bool(train_cfg.get("mixed_precision", True) and device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_f1 = -1.0
    best_path = model_dir / f"{model_key}.pt"
    history: list[dict[str, Any]] = []
    patience = int(train_cfg.get("early_stopping_patience", 2))
    stale_epochs = 0
    # Loop-invariant settings, converted once.
    total_epochs = int(train_cfg["epochs"])
    max_grad_norm = float(train_cfg.get("max_grad_norm", 0.0)) or None  # 0 disables clipping

    for epoch in range(1, total_epochs + 1):
        LOGGER.info("Training %s epoch %d/%d", model_key, epoch, total_epochs)
        train_stats = epoch_pass(
            model,
            train_loader,
            criterion,
            device,
            optimizer=optimizer,
            scaler=scaler,
            use_amp=use_amp,
            max_grad_norm=max_grad_norm,
        )
        with torch.no_grad():
            val_stats = epoch_pass(model, val_loader, criterion, device, use_amp=use_amp)
        scheduler.step(val_stats["loss"])
        row = {
            "epoch": epoch,
            "train_loss": train_stats["loss"],
            "val_loss": val_stats["loss"],
            "train_accuracy": train_stats["accuracy"],
            "val_accuracy": val_stats["accuracy"],
            "train_f1": train_stats["f1"],
            "val_f1": val_stats["f1"],
            # Read after scheduler.step, so this is the rate after any reduction.
            "learning_rate": float(optimizer.param_groups[0]["lr"]),
        }
        history.append(row)
        LOGGER.info(
            "%s epoch %d: train_loss=%.4f val_loss=%.4f val_f1=%.4f",
            model_key,
            epoch,
            row["train_loss"],
            row["val_loss"],
            row["val_f1"],
        )
        if row["val_f1"] > best_f1:
            best_f1 = row["val_f1"]
            stale_epochs = 0
            payload = checkpoint_payload(model, model_key, config, history, best_f1)
            torch.save(payload, best_path)
            LOGGER.info("Saved best checkpoint for %s: %s", model_key, best_path)
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                LOGGER.info("Early stopping %s after %d stale epoch(s).", model_key, stale_epochs)
                break

    history_df = pd.DataFrame(history)
    history_path = output_dir / "histories" / f"{model_key}_history.csv"
    history_path.parent.mkdir(parents=True, exist_ok=True)
    history_df.to_csv(history_path, index=False)
    save_json(history, output_dir / "histories" / f"{model_key}_history.json")
    plot_training_curves(history_df, output_dir / "plots" / f"training_curves_{model_key}.png", model_key)
    return best_path
def train_dl_models(config: dict[str, Any], model_keys: list[str] | None = None) -> list[Path]:
    """Train every requested model and return the checkpoint paths in order.

    Args:
        config: Configuration mapping. ``config["models"]["enabled"]`` is
            consulted only when ``model_keys`` is empty or ``None``.
        model_keys: Explicit model keys to train. When empty or ``None``,
            every key in ``MODEL_REGISTRY`` that the config enables is used.

    Returns:
        One path per trained model, as returned by ``train_one_dl_model``.
        Keys that are not in ``MODEL_REGISTRY`` are skipped with a warning.
    """
    splits_df = load_or_prepare_splits(config)
    if model_keys:
        requested = list(model_keys)
    else:
        enabled = config["models"]["enabled"]
        requested = [key for key in MODEL_REGISTRY if enabled.get(key, False)]

    paths: list[Path] = []
    for key in requested:
        if key not in MODEL_REGISTRY:
            # Still best-effort, but no longer silent about what was dropped.
            LOGGER.warning(
                "Skipping unknown model key %r; known keys: %s",
                key,
                list(MODEL_REGISTRY),
            )
            continue
        paths.append(train_one_dl_model(key, splits_df, config))
    return paths
def main() -> None:
    """Command-line entry point.

    Parses ``--config`` (path to the configuration file) and ``--models``
    (zero or more keys restricted to those in ``MODEL_REGISTRY``), loads the
    configuration, and hands both to ``train_dl_models``.
    """
    cli = argparse.ArgumentParser(description="Train deep transfer-learning egg classifiers.")
    cli.add_argument("--config", default="configs/default.yaml")
    cli.add_argument(
        "--models",
        nargs="*",
        default=None,
        choices=list(MODEL_REGISTRY),
    )
    options = cli.parse_args()
    loaded_config = load_config(options.config)
    train_dl_models(loaded_config, options.models)


# Run the CLI only when this module is executed as a script.
if __name__ == "__main__":
    main()