| from __future__ import annotations |
|
|
| import copy |
| from pathlib import Path |
| from typing import Any |
|
|
| import yaml |
|
|
| from .paths import project_root, resolve_path |
|
|
|
|
# Built-in defaults for every run. ``load_config`` deep-merges the YAML config
# file and then any caller-supplied overrides on top of a deep copy of this
# dict, so loading never mutates it. ``save_config_snapshot`` dumps with
# ``sort_keys=False``, so the key order here is the order in snapshot files.
DEFAULT_CONFIG: dict[str, Any] = {
    "project": {"name": "egg-damage-classification"},
    # Each entry is passed through ``resolve_path`` with the project root by
    # ``normalize_paths`` and stored back as a string.
    "paths": {
        "data_dir": "../Eggs Classification",
        "output_dir": "outputs",
        "model_dir": "models",
        "split_csv": "outputs/splits.csv",
    },
    # ``download_dir`` is resolved the same way by ``normalize_paths``.
    # NOTE(review): what ``enabled`` gates is decided outside this module —
    # presumably an optional dataset download step; confirm.
    "kaggle": {
        "enabled": False,
        "dataset": "abdullahkhanuet22/eggs-images-classification-damaged-or-not",
        "download_dir": "data/raw",
    },
    "seed": 42,
    "data": {
        "image_extensions": [".jpg", ".jpeg", ".png", ".bmp", ".webp"],
        # NOTE(review): the three split fractions sum to 1.0 here; presumably
        # the splitting code relies on that — confirm.
        "train_size": 0.70,
        "val_size": 0.15,
        "test_size": 0.15,
        "class_names": ["Not Damaged", "Damaged"],
        "positive_class": "Damaged",
        # NOTE(review): looks like a class-count ratio above which the data is
        # treated as imbalanced — confirm against the consumer.
        "imbalance_threshold": 1.20,
    },
    "preprocessing": {"image_size": 224},
    "balance": {
        "enabled": True,
        "strategy": "augment_minority",
        "max_augmented_train_samples": 3000,
    },
    "features": {
        "hog": {
            "orientations": 9,
            "pixels_per_cell": [16, 16],
            "cells_per_block": [2, 2],
            "block_norm": "L2-Hys",
        },
        "lbp": {"radius": 2, "n_points": 16, "method": "uniform"},
    },
    "classical": {
        # NOTE(review): list-valued entries look like hyper-parameter search
        # candidates and ``class_weight`` a fixed setting — confirm against
        # the training code.
        "svm": {
            "kernel": ["rbf"],
            "C": [1.0, 3.0],
            "gamma": ["scale"],
            "class_weight": "balanced",
        }
    },
    "models": {
        # Per-model on/off switches; not read in the code visible here.
        "enabled": {
            "hog_svm": True,
            "lbp_svm": True,
            "mobilenet_v3": True,
            "resnet50": True,
            "efficientnet_b0": True,
            "densenet121": False,
            "xception": False,
            "vit_small": False,
        },
        "pretrained": True,
    },
    "training": {
        "batch_size": 16,
        "epochs": 3,
        "learning_rate": 3e-4,
        "weight_decay": 1e-4,
        "optimizer": "adamw",
        "scheduler_patience": 1,
        "early_stopping_patience": 2,
        "num_workers": 0,
        "freeze_backbone": True,
        "mixed_precision": True,
        "pin_memory": True,
        "max_grad_norm": 5.0,
    },
    "augmentation": {
        "enabled": True,
        "horizontal_flip": True,
        "rotation_degrees": 10,
        "translate": 0.03,
        "scale_min": 0.95,
        "scale_max": 1.05,
        "color_jitter": {
            "enabled": True,
            "brightness": 0.12,
            "contrast": 0.12,
            "saturation": 0.08,
            "hue": 0.02,
        },
    },
    "evaluation": {
        # NOTE(review): presumably the decision threshold applied to the
        # positive-class probability — confirm.
        "threshold": 0.5,
        "save_precision_recall_curve": True,
        "save_calibration_plot": False,
        "sample_grid_count": 12,
    },
    "explainability": {"enabled": True, "max_images": 8},
    "gradio": {
        "host": "127.0.0.1",
        "port": 7860,
        "share": False,
        "low_confidence_threshold": 0.65,
    },
}
|
|
|
|
def deep_update(base: dict[str, Any], override: dict[str, Any] | None) -> dict[str, Any]:
    """Return a new dict with *override* recursively merged over *base*.

    Nested dicts present in both inputs are merged key by key. Any other
    value in *override* (including a dict replacing a non-dict) replaces the
    corresponding value from *base*. Neither input is modified, and the
    result shares no mutable objects with either input, so callers may
    mutate it freely.

    Args:
        base: Mapping supplying the default values.
        override: Mapping whose values take precedence. ``None`` (or an
            empty mapping) yields a deep copy of *base*.

    Returns:
        The merged dictionary.
    """
    merged = copy.deepcopy(base)
    for key, value in (override or {}).items():
        current = merged.get(key)
        if isinstance(value, dict) and isinstance(current, dict):
            merged[key] = deep_update(current, value)
        else:
            # Copy so later in-place edits of the merged config (e.g. path
            # normalisation) cannot leak back into the caller's override.
            merged[key] = copy.deepcopy(value)
    return merged
|
|
|
|
def load_config(config_path: str | Path | None = None, overrides: dict[str, Any] | None = None) -> dict[str, Any]:
    """Build the run configuration from defaults, a YAML file and overrides.

    Precedence, lowest to highest: ``DEFAULT_CONFIG``, the YAML file, then
    *overrides*. The bookkeeping keys ``_config_path`` and ``_project_root``
    are then added, and the path entries handled by ``normalize_paths`` are
    resolved against the project root.

    Args:
        config_path: YAML file to read. When falsy, defaults to
            ``<project_root>/configs/default.yaml``. If the file does not
            exist, the defaults (plus *overrides*) are used as-is.
        overrides: Optional nested mapping merged last.

    Returns:
        A freshly built configuration dictionary.

    Raises:
        TypeError: If the YAML file's top level is not a mapping.
    """
    root = project_root()
    if config_path:
        path = Path(config_path).expanduser().resolve()
    else:
        path = root / "configs" / "default.yaml"

    user_config: dict[str, Any] = {}
    if path.exists():
        with path.open("r", encoding="utf-8") as f:
            # An empty document loads as None; treat it as "no overrides".
            user_config = yaml.safe_load(f) or {}
        # A top-level list or scalar would otherwise fail obscurely inside
        # deep_update; report it against the offending file instead.
        if not isinstance(user_config, dict):
            raise TypeError(
                f"Config file {path} must contain a mapping at the top level, "
                f"got {type(user_config).__name__}"
            )

    config = deep_update(DEFAULT_CONFIG, user_config)
    if overrides:
        config = deep_update(config, overrides)
    config["_config_path"] = str(path)
    config["_project_root"] = str(root)
    normalize_paths(config)
    return config
|
|
|
|
def normalize_paths(config: dict[str, Any]) -> None:
    """Resolve the configured filesystem paths in place against the project root.

    Each path entry (``paths.data_dir``, ``paths.output_dir``,
    ``paths.model_dir``, ``paths.split_csv`` and ``kaggle.download_dir``) is
    passed to ``resolve_path`` together with ``config["_project_root"]``.
    When that returns something other than ``None``, the entry is replaced
    with its string form. A section that is missing or is not a mapping is
    skipped; for example, a YAML key left empty loads as ``None``.

    Args:
        config: Configuration dict containing ``_project_root``; mutated in place.

    Raises:
        KeyError: If ``_project_root`` is missing from *config*.
    """
    root = Path(config["_project_root"])
    path_entries = (
        ("paths", "data_dir"),
        ("paths", "output_dir"),
        ("paths", "model_dir"),
        ("paths", "split_csv"),
        ("kaggle", "download_dir"),
    )
    for section_name, key in path_entries:
        section = config.get(section_name)
        if not isinstance(section, dict):
            continue
        # A missing key is passed through as None, as before.
        resolved = resolve_path(section.get(key), root)
        if resolved is not None:
            section[key] = str(resolved)
|
|
|
|
def _to_yaml_safe(value: Any) -> Any:
    """Return a copy of *value* that ``yaml.safe_dump`` can represent.

    ``Path`` objects become strings, and dicts, lists and tuples are rebuilt
    recursively (tuples as lists, which is how ``safe_dump`` writes them
    anyway). Every other value is returned unchanged.
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {key: _to_yaml_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_yaml_safe(item) for item in value]
    return value


def save_config_snapshot(config: dict[str, Any], output_dir: str | Path) -> Path:
    """Write *config* to ``<output_dir>/config_resolved.yaml``.

    The output directory is created if needed, and key order is preserved
    (``sort_keys=False``). ``Path`` values anywhere in *config* are written
    as plain strings, because ``yaml.safe_dump`` cannot represent them. The
    input dict is not modified.

    Args:
        config: Configuration to record.
        output_dir: Directory that receives the snapshot file.

    Returns:
        Path of the written YAML file.
    """
    path = Path(output_dir) / "config_resolved.yaml"
    path.parent.mkdir(parents=True, exist_ok=True)
    snapshot = _to_yaml_safe(config)
    with path.open("w", encoding="utf-8") as f:
        yaml.safe_dump(snapshot, f, sort_keys=False)
    return path
|
|
|
|