from __future__ import annotations import copy from pathlib import Path from typing import Any import yaml from .paths import project_root, resolve_path DEFAULT_CONFIG: dict[str, Any] = { "project": {"name": "egg-damage-classification"}, "paths": { "data_dir": "../Eggs Classification", "output_dir": "outputs", "model_dir": "models", "split_csv": "outputs/splits.csv", }, "kaggle": { "enabled": False, "dataset": "abdullahkhanuet22/eggs-images-classification-damaged-or-not", "download_dir": "data/raw", }, "seed": 42, "data": { "image_extensions": [".jpg", ".jpeg", ".png", ".bmp", ".webp"], "train_size": 0.70, "val_size": 0.15, "test_size": 0.15, "class_names": ["Not Damaged", "Damaged"], "positive_class": "Damaged", "imbalance_threshold": 1.20, }, "preprocessing": {"image_size": 224}, "balance": { "enabled": True, "strategy": "augment_minority", "max_augmented_train_samples": 3000, }, "features": { "hog": { "orientations": 9, "pixels_per_cell": [16, 16], "cells_per_block": [2, 2], "block_norm": "L2-Hys", }, "lbp": {"radius": 2, "n_points": 16, "method": "uniform"}, }, "classical": { "svm": { "kernel": ["rbf"], "C": [1.0, 3.0], "gamma": ["scale"], "class_weight": "balanced", } }, "models": { "enabled": { "hog_svm": True, "lbp_svm": True, "mobilenet_v3": True, "resnet50": True, "efficientnet_b0": True, "densenet121": False, "xception": False, "vit_small": False, }, "pretrained": True, }, "training": { "batch_size": 16, "epochs": 3, "learning_rate": 3e-4, "weight_decay": 1e-4, "optimizer": "adamw", "scheduler_patience": 1, "early_stopping_patience": 2, "num_workers": 0, "freeze_backbone": True, "mixed_precision": True, "pin_memory": True, "max_grad_norm": 5.0, }, "augmentation": { "enabled": True, "horizontal_flip": True, "rotation_degrees": 10, "translate": 0.03, "scale_min": 0.95, "scale_max": 1.05, "color_jitter": { "enabled": True, "brightness": 0.12, "contrast": 0.12, "saturation": 0.08, "hue": 0.02, }, }, "evaluation": { "threshold": 0.5, "save_precision_recall_curve": True, "save_calibration_plot": False, "sample_grid_count": 12, }, "explainability": {"enabled": True, "max_images": 8}, "gradio": { "host": "127.0.0.1", "port": 7860, "share": False, "low_confidence_threshold": 0.65, }, } def deep_update(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: merged = copy.deepcopy(base) for key, value in (override or {}).items(): if isinstance(value, dict) and isinstance(merged.get(key), dict): merged[key] = deep_update(merged[key], value) else: merged[key] = value return merged def load_config(config_path: str | Path | None = None, overrides: dict[str, Any] | None = None) -> dict[str, Any]: root = project_root() path = Path(config_path).expanduser().resolve() if config_path else root / "configs" / "default.yaml" user_config: dict[str, Any] = {} if path.exists(): with path.open("r", encoding="utf-8") as f: user_config = yaml.safe_load(f) or {} config = deep_update(DEFAULT_CONFIG, user_config) if overrides: config = deep_update(config, overrides) config["_config_path"] = str(path) config["_project_root"] = str(root) normalize_paths(config) return config def normalize_paths(config: dict[str, Any]) -> None: root = Path(config["_project_root"]) for key in ("data_dir", "output_dir", "model_dir", "split_csv"): value = config.get("paths", {}).get(key) resolved = resolve_path(value, root) if resolved is not None: config["paths"][key] = str(resolved) kaggle_dir = config.get("kaggle", {}).get("download_dir") resolved_kaggle = resolve_path(kaggle_dir, root) if resolved_kaggle is not None: config["kaggle"]["download_dir"] = str(resolved_kaggle) def save_config_snapshot(config: dict[str, Any], output_dir: str | Path) -> Path: path = Path(output_dir) / "config_resolved.yaml" path.parent.mkdir(parents=True, exist_ok=True) safe = copy.deepcopy(config) with path.open("w", encoding="utf-8") as f: yaml.safe_dump(safe, f, sort_keys=False) return path