# budijuarto's picture
# Upload src/egg_damage/config.py
# e14b469 verified
from __future__ import annotations
import copy
from pathlib import Path
from typing import Any
import yaml
from .paths import project_root, resolve_path
# Built-in defaults. load_config() merges the YAML config file and any runtime
# overrides on top of this mapping via deep_update().
DEFAULT_CONFIG: dict[str, Any] = {
    "project": {
        "name": "egg-damage-classification",
    },
    # Filesystem locations; normalize_paths() rewrites each of these entries
    # with the value returned by resolve_path().
    "paths": {
        "data_dir": "../Eggs Classification",
        "output_dir": "outputs",
        "model_dir": "models",
        "split_csv": "outputs/splits.csv",
    },
    # Kaggle dataset settings; "download_dir" is also handled by
    # normalize_paths().
    "kaggle": {
        "enabled": False,
        "dataset": "abdullahkhanuet22/eggs-images-classification-damaged-or-not",
        "download_dir": "data/raw",
    },
    "seed": 42,
    # Dataset discovery, split fractions and class labels.
    "data": {
        "image_extensions": [".jpg", ".jpeg", ".png", ".bmp", ".webp"],
        "train_size": 0.70,
        "val_size": 0.15,
        "test_size": 0.15,
        "class_names": ["Not Damaged", "Damaged"],
        "positive_class": "Damaged",
        "imbalance_threshold": 1.20,
    },
    "preprocessing": {
        "image_size": 224,
    },
    # Class-balancing settings.
    "balance": {
        "enabled": True,
        "strategy": "augment_minority",
        "max_augmented_train_samples": 3000,
    },
    # Hand-crafted feature extractor parameters (HOG and LBP).
    "features": {
        "hog": {
            "orientations": 9,
            "pixels_per_cell": [16, 16],
            "cells_per_block": [2, 2],
            "block_norm": "L2-Hys",
        },
        "lbp": {
            "radius": 2,
            "n_points": 16,
            "method": "uniform",
        },
    },
    # Classical-model settings; the list-valued SVM entries hold candidate
    # values.
    "classical": {
        "svm": {
            "kernel": ["rbf"],
            "C": [1.0, 3.0],
            "gamma": ["scale"],
            "class_weight": "balanced",
        },
    },
    # Per-model on/off switches.
    "models": {
        "enabled": {
            "hog_svm": True,
            "lbp_svm": True,
            "mobilenet_v3": True,
            "resnet50": True,
            "efficientnet_b0": True,
            "densenet121": False,
            "xception": False,
            "vit_small": False,
        },
        "pretrained": True,
    },
    # Training-loop hyperparameters.
    "training": {
        "batch_size": 16,
        "epochs": 3,
        "learning_rate": 3e-4,
        "weight_decay": 1e-4,
        "optimizer": "adamw",
        "scheduler_patience": 1,
        "early_stopping_patience": 2,
        "num_workers": 0,
        "freeze_backbone": True,
        "mixed_precision": True,
        "pin_memory": True,
        "max_grad_norm": 5.0,
    },
    # Image augmentation settings.
    "augmentation": {
        "enabled": True,
        "horizontal_flip": True,
        "rotation_degrees": 10,
        "translate": 0.03,
        "scale_min": 0.95,
        "scale_max": 1.05,
        "color_jitter": {
            "enabled": True,
            "brightness": 0.12,
            "contrast": 0.12,
            "saturation": 0.08,
            "hue": 0.02,
        },
    },
    # Evaluation outputs and decision threshold.
    "evaluation": {
        "threshold": 0.5,
        "save_precision_recall_curve": True,
        "save_calibration_plot": False,
        "sample_grid_count": 12,
    },
    "explainability": {
        "enabled": True,
        "max_images": 8,
    },
    # Gradio settings.
    "gradio": {
        "host": "127.0.0.1",
        "port": 7860,
        "share": False,
        "low_confidence_threshold": 0.65,
    },
}
def deep_update(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
merged = copy.deepcopy(base)
for key, value in (override or {}).items():
if isinstance(value, dict) and isinstance(merged.get(key), dict):
merged[key] = deep_update(merged[key], value)
else:
merged[key] = value
return merged
def load_config(
    config_path: str | Path | None = None,
    overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build the effective configuration dictionary.

    Layers are applied in increasing precedence: ``DEFAULT_CONFIG``, then the
    YAML file, then ``overrides``. The result additionally carries the
    bookkeeping keys ``_config_path`` and ``_project_root`` (both strings) and
    has been passed through ``normalize_paths``.

    Args:
        config_path: YAML file to read. When omitted (or otherwise falsy),
            ``<project root>/configs/default.yaml`` is used.
        overrides: Optional mapping merged last, on top of the file values.

    Returns:
        A new configuration dict; ``DEFAULT_CONFIG`` itself is not modified.

    Raises:
        TypeError: If the YAML file parses to a non-empty value that is not a
            mapping (for example a top-level list).
    """
    root = project_root()
    path = Path(config_path).expanduser().resolve() if config_path else root / "configs" / "default.yaml"

    user_config: dict[str, Any] = {}
    # NOTE(review): a missing file -- even one named explicitly by the caller --
    # silently falls back to the defaults. Kept as-is for compatibility.
    if path.exists():
        with path.open("r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        # safe_load may return any YAML value. An empty document (or any other
        # falsy value) is treated as "no settings"; a truthy non-mapping would
        # otherwise fail later with an opaque AttributeError on ``.items()``.
        if loaded and not isinstance(loaded, dict):
            raise TypeError(
                f"Config file {path} must contain a YAML mapping at the top level, "
                f"got {type(loaded).__name__}"
            )
        user_config = loaded or {}

    config = deep_update(DEFAULT_CONFIG, user_config)
    if overrides:
        config = deep_update(config, overrides)

    config["_config_path"] = str(path)
    config["_project_root"] = str(root)
    normalize_paths(config)
    return config
def normalize_paths(config: dict[str, Any]) -> None:
    """Rewrite the path-like entries of ``config`` in place.

    Each listed entry is handed to ``resolve_path`` together with the project
    root taken from ``config["_project_root"]``. When ``resolve_path`` returns
    something other than ``None``, its string form replaces the stored value;
    otherwise the entry is left untouched.

    Args:
        config: Configuration dict; must contain ``"_project_root"``.
    """
    base_dir = Path(config["_project_root"])
    # (section, keys) pairs, visited in this order.
    targets = (
        ("paths", ("data_dir", "output_dir", "model_dir", "split_csv")),
        ("kaggle", ("download_dir",)),
    )
    for section, keys in targets:
        for key in keys:
            raw = config.get(section, {}).get(key)
            absolute = resolve_path(raw, base_dir)
            if absolute is None:
                continue
            config[section][key] = str(absolute)
def save_config_snapshot(config: dict[str, Any], output_dir: str | Path) -> Path:
    """Write ``config`` to ``<output_dir>/config_resolved.yaml``.

    The output directory is created if needed (including parents). A deep copy
    of ``config`` is dumped with ``yaml.safe_dump``, keeping the keys in their
    existing order.

    Args:
        config: Configuration dict to serialize.
        output_dir: Directory that receives the snapshot file.

    Returns:
        Path of the written YAML file.
    """
    target_dir = Path(output_dir)
    snapshot_file = target_dir / "config_resolved.yaml"
    target_dir.mkdir(parents=True, exist_ok=True)

    payload = copy.deepcopy(config)
    with snapshot_file.open("w", encoding="utf-8") as stream:
        yaml.safe_dump(payload, stream, sort_keys=False)
    return snapshot_file