# Uploaded via huggingface_hub by Kunitomi (commit 196c526, verified).
#!/usr/bin/env python3
"""
Production-grade configuration management system for bean color analysis.
Supports environment-specific configurations, validation, and hot-reloading.
"""
import os
import logging
from pathlib import Path
from typing import Dict, Any, Optional, Union, List
from dataclasses import dataclass, field
import yaml
from functools import lru_cache
logger = logging.getLogger(__name__)
@dataclass
class ModelConfig:
    """Model configuration parameters."""
    num_classes: int      # number of output classes; validated >= 2 by BeanVisionConfig.validate()
    checkpoint_path: str  # filesystem path to the model checkpoint
    device: str           # compute device identifier (e.g. "cpu"/"cuda" — presumed; confirm against loader)
@dataclass
class ImageConfig:
    """Image processing configuration."""
    resize_width: int           # target width in pixels; validated > 0
    resize_height: int          # target height in pixels; validated > 0
    imagenet_mean: List[float]  # per-channel normalization mean; validated to have exactly 3 values
    imagenet_std: List[float]   # per-channel normalization std; validated to have exactly 3 values
@dataclass
class TrainingConfig:
    """Training configuration parameters."""
    epochs: int                    # total training epochs; validated > 0
    batch_size: int                # samples per batch; validated > 0
    learning_rate: float           # initial learning rate; validated > 0
    weight_decay: float            # weight-decay (regularization) factor
    momentum: float                # optimizer momentum term
    lr_scheduler_step: int         # step interval for the LR scheduler
    lr_scheduler_gamma: float      # multiplicative LR decay factor
    gradient_clip_max_norm: float  # maximum gradient norm for clipping
    backbone_freeze_epochs: int    # epochs to keep the backbone frozen — presumed from name; confirm in trainer
@dataclass
class AugmentationConfig:
    """Data augmentation configuration."""
    random_rotate90_prob: float      # probability of a random 90-degree rotation
    rotate_limit: int                # rotation limit (degrees presumed; confirm against augmentation pipeline)
    rotate_prob: float               # probability of applying rotation
    horizontal_flip_prob: float      # probability of a horizontal flip
    vertical_flip_prob: float        # probability of a vertical flip
    brightness_limit: float          # maximum brightness adjustment magnitude
    contrast_limit: float            # maximum contrast adjustment magnitude
    brightness_contrast_prob: float  # probability of the brightness/contrast jitter
@dataclass
class InferenceConfig:
    """Inference configuration parameters."""
    confidence_threshold: float  # detection confidence cutoff; validated to [0, 1]
    nms_threshold: float         # non-maximum-suppression threshold; validated to [0, 1]
    min_contour_area: int        # minimum contour area to keep (pixels presumed; confirm)
@dataclass
class PathsConfig:
    """File and directory paths."""
    data_dir: str    # root directory of input data; validated non-empty
    coco_json: str   # path to the COCO-format annotation JSON
    log_dir: str     # directory for log files; validated non-empty
    output_dir: str  # directory for generated outputs; validated non-empty
@dataclass
class LoggingConfig:
    """Logging configuration."""
    level: str             # logging level name, e.g. "INFO"
    format: str            # log record format string
    file_logging: bool     # enable logging to a file
    console_logging: bool  # enable logging to the console
@dataclass
class APIConfig:
    """API configuration parameters."""
    # --- Service metadata ---
    title: str
    description: str
    version: str
    # --- Server binding / process model ---
    host: str
    port: int
    debug: bool
    reload: bool
    workers: int
    # --- Request limits ---
    max_file_size_mb: int
    max_batch_size: int
    request_timeout_seconds: int
    # --- Rate limiting ---
    rate_limit_enabled: bool
    rate_limit_requests: int          # allowed requests per window
    rate_limit_window_minutes: int    # window length in minutes
    # --- CORS ---
    cors_enabled: bool
    cors_origins: List[str]
    cors_methods: List[str]
    cors_headers: List[str]
    # --- Authentication ---
    auth_enabled: bool
    api_key_header: str               # header name carrying the API key
    admin_api_key: str                # NOTE(review): secret stored in config — consider env/secret manager
    # --- Response shaping ---
    include_model_info: bool
    include_processing_time: bool
    default_return_polygons: bool
    # --- Response caching ---
    cache_enabled: bool
    cache_ttl_seconds: int
    cache_max_size: int
@dataclass
class BeanVisionConfig:
    """Main configuration class for Bean Vision.

    Aggregates all sub-configurations; the ``api`` section is optional.
    """
    model: ModelConfig
    image: ImageConfig
    training: TrainingConfig
    augmentation: AugmentationConfig
    inference: InferenceConfig
    paths: PathsConfig
    logging: LoggingConfig
    api: Optional[APIConfig] = None

    @classmethod
    def from_yaml(cls, config_path: str) -> 'BeanVisionConfig':
        """Load configuration from YAML file.

        Args:
            config_path: Path to the YAML configuration file.

        Returns:
            Fully populated BeanVisionConfig instance.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the YAML cannot be parsed, is empty or not a
                mapping, or is missing required sections/keys.
        """
        config_path = Path(config_path)
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")
        try:
            with open(config_path, 'r') as f:
                config_dict = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ValueError(f"Invalid YAML configuration: {e}")
        # safe_load returns None for an empty document (and a scalar/list for
        # malformed configs); fail fast with a clear error before subscripting.
        if not isinstance(config_dict, dict):
            raise ValueError(f"Configuration must be a non-empty YAML mapping: {config_path}")
        try:
            # API config is optional
            api_config = None
            if 'api' in config_dict:
                api_config = APIConfig(**config_dict['api'])
            return cls(
                model=ModelConfig(**config_dict['model']),
                image=ImageConfig(**config_dict['image']),
                training=TrainingConfig(**config_dict['training']),
                augmentation=AugmentationConfig(**config_dict['augmentation']),
                inference=InferenceConfig(**config_dict['inference']),
                paths=PathsConfig(**config_dict['paths']),
                logging=LoggingConfig(**config_dict['logging']),
                api=api_config
            )
        except (KeyError, TypeError) as e:
            raise ValueError(f"Invalid configuration structure: {e}")

    def validate(self) -> None:
        """Validate configuration values.

        Raises:
            ValueError: On the first constraint violation found.
        """
        # Validate model config
        if self.model.num_classes < 2:
            raise ValueError("num_classes must be >= 2")
        # Validate image config
        if self.image.resize_width <= 0 or self.image.resize_height <= 0:
            raise ValueError("Image dimensions must be positive")
        if len(self.image.imagenet_mean) != 3 or len(self.image.imagenet_std) != 3:
            raise ValueError("ImageNet mean and std must have 3 values")
        # Validate training config
        if self.training.epochs <= 0:
            raise ValueError("epochs must be positive")
        if self.training.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if self.training.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")
        # Validate inference config
        if not 0 <= self.inference.confidence_threshold <= 1:
            raise ValueError("confidence_threshold must be between 0 and 1")
        if not 0 <= self.inference.nms_threshold <= 1:
            raise ValueError("nms_threshold must be between 0 and 1")
        # Validate paths (coco_json intentionally not checked here — may be unused at inference)
        for path_attr in ['data_dir', 'log_dir', 'output_dir']:
            path = getattr(self.paths, path_attr)
            if not path or not isinstance(path, str):
                raise ValueError(f"{path_attr} must be a non-empty string")
@dataclass
class AnalysisThresholds:
    """Analysis threshold configuration.

    Delta-E values appear to be CIEDE2000 color differences and L the CIELAB
    lightness channel — presumed from field names; confirm against the
    analysis pipeline.
    """
    median_delta_e2000_max: float = 3.0        # max acceptable median delta-E
    p95_delta_e2000_max: float = 5.0           # max acceptable 95th-percentile delta-E
    std_L_max: float = 2.0                     # max std-dev of L
    center_edge_L_max: float = 2.0             # max center-vs-edge L difference
    cv_L_max: float = 8.0                      # max coefficient of variation of L
    p95_pairwise_delta_e2000_max: float = 6.0  # max p95 pairwise delta-E
    tolerance_percent_min: float = 90.0        # min percent within tolerance; validated 0-100
    defect_delta_e2000_max: float = 10.0       # delta-E above which a bean counts as defective
    defect_std_L_max: float = 5.0              # L std-dev above which a bean counts as defective
@dataclass
class PerformanceConfig:
    """Performance optimization configuration."""
    # Multiprocessing
    multiprocessing_enabled: bool = True   # use worker processes for analysis
    max_workers: Optional[int] = None      # worker count; None presumably means "auto" — confirm
    chunk_size: int = 10                   # items per worker chunk; validated > 0
    # Caching
    lab_conversion_cache_size: int = 1000  # cache size for Lab color conversions
    enable_lru_cache: bool = True          # toggle LRU caching
    # Vectorization
    use_numpy_vectorized: bool = True      # prefer vectorized NumPy code paths
    # KD-tree
    kdtree_enabled: bool = True            # enable KD-tree acceleration
    kdtree_leaf_size: int = 30             # KD-tree leaf size
    kdtree_sample_size: int = 1000         # KD-tree sample size — semantics presumed; confirm
@dataclass
class UIConfig:
    """User interface configuration."""
    verbosity: str = "normal"        # quiet, normal, verbose, debug
    show_progress_bars: bool = True  # display progress bars
    update_frequency: int = 10       # progress update frequency
    use_color: bool = True           # colorize terminal output
    show_warnings: bool = True       # print warnings to the terminal
    max_error_lines: int = 5         # cap on error lines shown
@dataclass
class BusinessConfig:
    """Business metrics configuration."""
    cost_analysis_enabled: bool = True     # toggle cost-analysis reporting
    cost_per_defective_bean: float = 0.05  # cost per defective bean (USD presumed — see batch_value_usd)
    cost_per_uneven_bean: float = 0.02     # cost per unevenly colored bean
    labor_cost_per_hour: float = 25.0      # hourly labor cost
    baseline_defect_rate: float = 0.15     # defect rate before intervention (fraction)
    target_defect_rate: float = 0.05       # desired defect rate (fraction)
    batch_value_usd: float = 500.0         # value of one batch in USD
@dataclass
class CalibrationProfile:
    """Color calibration profile."""
    name: str                 # human-readable profile name
    description: str          # profile description
    illuminant: str           # illuminant identifier (e.g. "D65" — presumed; confirm)
    viewing_angle: int        # observer viewing angle (degrees presumed)
    background_gray: int      # background gray level
    lab_tolerance_factor: float = 1.0      # multiplier applied to Lab tolerances
    delta_e_threshold_factor: float = 1.0  # multiplier applied to delta-E thresholds
    color_correction_matrix: Optional[List[List[float]]] = None  # optional correction matrix (3x3 presumed; confirm)
class ConfigurationManager:
    """
    Production-grade configuration manager with validation and environment support.

    Loads ``analysis_config.yaml`` and ``calibration_profiles.yaml`` from
    ``config_dir``, applies environment-specific overrides, and exposes typed
    accessors for derived configuration objects.
    """

    def __init__(self, config_dir: Union[str, Path] = None, environment: str = None):
        """
        Initialize configuration manager.

        Args:
            config_dir: Directory containing configuration files. Defaults to
                the project-level "config" directory relative to this module.
            environment: Environment name (dev, staging, prod). Falls back to
                the BEAN_VISION_ENV variable, then "development".

        Raises:
            FileNotFoundError: If a required configuration file is missing.
        """
        self.config_dir = Path(config_dir) if config_dir else Path(__file__).parent.parent.parent / "config"
        self.environment = environment or os.getenv("BEAN_VISION_ENV", "development")
        # Parsed configuration keyed by logical section ("analysis", "calibration").
        self._config_cache: Dict[str, Any] = {}
        # Last-seen modification time per YAML filename (hot-reload support).
        self._file_mtimes: Dict[str, float] = {}
        # Parsed YAML keyed by filename, reused while the mtime is unchanged.
        # BUGFIX: the previous mtime shortcut looked up _config_cache under the
        # filename stem (e.g. "analysis_config"), which never matched the
        # logical keys actually stored, so it silently returned {}.
        self._raw_file_cache: Dict[str, Dict[str, Any]] = {}
        # Memoized derived config objects. A per-instance dict replaces the
        # previous functools.lru_cache on instance methods, which keys on
        # `self` and keeps every instance alive for the cache's lifetime
        # (ruff B019).
        self._derived_cache: Dict[str, Any] = {}
        # Load main configurations
        self._load_all_configs()

    def _load_all_configs(self) -> None:
        """Load all configuration files and apply environment overrides."""
        try:
            # Load main analysis config
            self._config_cache["analysis"] = self._load_yaml_config("analysis_config.yaml")
            # Load calibration profiles
            self._config_cache["calibration"] = self._load_yaml_config("calibration_profiles.yaml")
            # Apply environment-specific overrides
            self._apply_environment_overrides()
            logger.info(f"Loaded configuration for environment: {self.environment}")
        except Exception as e:
            logger.error(f"Failed to load configuration: {e}")
            raise

    def _load_yaml_config(self, filename: str) -> Dict[str, Any]:
        """Load a YAML configuration file, reusing the parse while unmodified.

        Args:
            filename: Name of the YAML file inside ``self.config_dir``.

        Returns:
            Parsed configuration mapping ({} for an empty file).

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        filepath = self.config_dir / filename
        if not filepath.exists():
            raise FileNotFoundError(f"Configuration file not found: {filepath}")
        # Check if file has been modified since the last parse
        mtime = filepath.stat().st_mtime
        if self._file_mtimes.get(filename) == mtime and filename in self._raw_file_cache:
            return self._raw_file_cache[filename]
        # Load the file; safe_load returns None for an empty document,
        # normalize that to {} so callers can subscript safely.
        with open(filepath, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f) or {}
        self._file_mtimes[filename] = mtime
        self._raw_file_cache[filename] = config
        logger.debug(f"Loaded configuration from {filepath}")
        return config

    def _apply_environment_overrides(self) -> None:
        """Apply environment-specific configuration overrides in place."""
        analysis_config = self._config_cache.get("analysis", {})
        env_overrides = analysis_config.get("environments", {}).get(self.environment, {})
        if env_overrides:
            self._deep_merge_dict(analysis_config, env_overrides)
            logger.info(f"Applied {len(env_overrides)} environment overrides for {self.environment}")

    def _deep_merge_dict(self, base: Dict[str, Any], override: Dict[str, Any]) -> None:
        """Deep merge override dictionary into base dictionary (mutates base)."""
        for key, value in override.items():
            if key in base and isinstance(base[key], dict) and isinstance(value, dict):
                self._deep_merge_dict(base[key], value)
            else:
                base[key] = value

    def _memoized(self, key: str, factory) -> Any:
        """Return the derived object cached under *key*, building it on first use.

        Exceptions raised by *factory* are not cached, matching lru_cache
        semantics.
        """
        if key not in self._derived_cache:
            self._derived_cache[key] = factory()
        return self._derived_cache[key]

    def get_analysis_thresholds(self, bean_type: str) -> AnalysisThresholds:
        """Get analysis thresholds for specified bean type ("roasted" or "green").

        Raises:
            ValueError: If *bean_type* is not a known type.
        """
        def build() -> AnalysisThresholds:
            config = self._config_cache["analysis"]["analysis"]
            if bean_type == "roasted":
                thresholds = config["roasted_thresholds"]
            elif bean_type == "green":
                thresholds = config["green_thresholds"]
            else:
                raise ValueError(f"Unknown bean type: {bean_type}")
            return AnalysisThresholds(**thresholds)
        return self._memoized(f"thresholds:{bean_type}", build)

    def get_performance_config(self) -> PerformanceConfig:
        """Get performance optimization configuration."""
        def build() -> PerformanceConfig:
            config = self._config_cache["analysis"]["performance"]
            return PerformanceConfig(
                multiprocessing_enabled=config["multiprocessing"]["enabled"],
                max_workers=config["multiprocessing"]["max_workers"],
                chunk_size=config["multiprocessing"]["chunk_size"],
                lab_conversion_cache_size=config["caching"]["lab_conversion_cache_size"],
                enable_lru_cache=config["caching"]["enable_lru_cache"],
                use_numpy_vectorized=config["vectorization"]["use_numpy_vectorized"],
                kdtree_enabled=config["kdtree"]["enabled"],
                kdtree_leaf_size=config["kdtree"]["leaf_size"],
                kdtree_sample_size=config["kdtree"]["sample_size"]
            )
        return self._memoized("performance", build)

    def get_ui_config(self) -> UIConfig:
        """Get user interface configuration."""
        def build() -> UIConfig:
            config = self._config_cache["analysis"]["ui"]
            return UIConfig(
                verbosity=config["verbosity"],
                show_progress_bars=config["progress"]["show_progress_bars"],
                update_frequency=config["progress"]["update_frequency"],
                use_color=config["terminal"]["use_color"],
                show_warnings=config["terminal"]["show_warnings"],
                max_error_lines=config["terminal"]["max_error_lines"]
            )
        return self._memoized("ui", build)

    def get_business_config(self) -> BusinessConfig:
        """Get business metrics configuration."""
        def build() -> BusinessConfig:
            config = self._config_cache["analysis"]["business"]
            return BusinessConfig(
                cost_analysis_enabled=config["cost_analysis"]["enabled"],
                cost_per_defective_bean=config["cost_analysis"]["cost_per_defective_bean"],
                cost_per_uneven_bean=config["cost_analysis"]["cost_per_uneven_bean"],
                labor_cost_per_hour=config["cost_analysis"]["labor_cost_per_hour"],
                baseline_defect_rate=config["roi"]["baseline_defect_rate"],
                target_defect_rate=config["roi"]["target_defect_rate"],
                batch_value_usd=config["roi"]["batch_value_usd"]
            )
        return self._memoized("business", build)

    def get_calibration_profile(self, profile_name: str) -> CalibrationProfile:
        """Get calibration profile by name.

        Raises:
            ValueError: If the profile is not defined, listing the available ones.
        """
        profiles = self._config_cache["calibration"]["profiles"]
        if profile_name not in profiles:
            available = list(profiles.keys())
            raise ValueError(f"Unknown calibration profile: {profile_name}. Available: {available}")
        profile_data = profiles[profile_name]
        return CalibrationProfile(
            name=profile_data["name"],
            description=profile_data["description"],
            illuminant=profile_data["illuminant"],
            viewing_angle=profile_data["viewing_angle"],
            background_gray=profile_data["background_gray"],
            lab_tolerance_factor=profile_data.get("analysis_adjustments", {}).get("lab_tolerance_factor", 1.0),
            delta_e_threshold_factor=profile_data.get("analysis_adjustments", {}).get("delta_e_threshold_factor", 1.0),
            color_correction_matrix=profile_data.get("color_correction", {}).get("matrix")
        )

    def get_available_calibration_profiles(self) -> List[str]:
        """Get list of available calibration profile names."""
        return list(self._config_cache["calibration"]["profiles"].keys())

    def get_colorchecker_patches(self) -> Dict[str, List[float]]:
        """Get X-Rite ColorChecker reference patch values."""
        return self._config_cache["calibration"]["colorchecker"]["reference_patches"]

    def get_config_value(self, path: str, default: Any = None) -> Any:
        """
        Get configuration value by dot-notation path.

        Args:
            path: Dot-notation path (e.g., "analysis.roasted_thresholds.std_L_max")
            default: Default value if path not found

        Returns:
            Configuration value, or *default* when the path does not resolve.
        """
        keys = path.split(".")
        current = self._config_cache
        try:
            for key in keys:
                current = current[key]
            return current
        # TypeError covers descending into a non-dict intermediate value
        # (e.g. a list or scalar), which the old KeyError-only handler missed.
        except (KeyError, TypeError):
            logger.warning(f"Configuration path not found: {path}")
            return default

    def reload_config(self) -> None:
        """Reload configuration from files, dropping all cached state."""
        logger.info("Reloading configuration...")
        # Clear caches
        self._config_cache.clear()
        self._file_mtimes.clear()
        self._raw_file_cache.clear()
        self._derived_cache.clear()
        # Reload all configs
        self._load_all_configs()

    def validate_config(self) -> List[str]:
        """
        Validate configuration and return list of issues.

        Returns:
            List of validation error messages (empty when configuration is valid).
        """
        issues = []
        try:
            # Validate analysis config
            analysis_config = self._config_cache.get("analysis", {})
            if not analysis_config:
                issues.append("Missing analysis configuration")
            # Validate required sections
            required_sections = ["analysis", "performance", "ui", "logging", "output"]
            for section in required_sections:
                if section not in analysis_config:
                    issues.append(f"Missing required section: {section}")
            # Validate calibration profiles
            calibration_config = self._config_cache.get("calibration", {})
            if not calibration_config.get("profiles"):
                issues.append("No calibration profiles defined")
            # Validate threshold values
            for bean_type in ["roasted", "green"]:
                try:
                    thresholds = self.get_analysis_thresholds(bean_type)
                    if thresholds.tolerance_percent_min < 0 or thresholds.tolerance_percent_min > 100:
                        issues.append(f"{bean_type} tolerance_percent_min must be 0-100")
                except Exception as e:
                    issues.append(f"Invalid {bean_type} thresholds: {e}")
            # Validate performance config
            try:
                perf_config = self.get_performance_config()
                if perf_config.chunk_size <= 0:
                    issues.append("Performance chunk_size must be positive")
            except Exception as e:
                issues.append(f"Invalid performance config: {e}")
        except Exception as e:
            issues.append(f"Configuration validation error: {e}")
        return issues

    def to_dict(self) -> Dict[str, Any]:
        """Export entire configuration as dictionary."""
        return {
            "environment": self.environment,
            "config_dir": str(self.config_dir),
            "analysis": self._config_cache.get("analysis", {}),
            "calibration": self._config_cache.get("calibration", {})
        }
# Global configuration instance
_config_manager: Optional[ConfigurationManager] = None


def get_config_manager() -> ConfigurationManager:
    """Return the process-wide ConfigurationManager, creating it lazily."""
    global _config_manager
    manager = _config_manager
    if manager is None:
        manager = ConfigurationManager()
        _config_manager = manager
    return manager
def initialize_config(config_dir: Union[str, Path] = None, environment: str = None) -> None:
    """Initialize (or replace) the global configuration manager instance."""
    global _config_manager
    _config_manager = ConfigurationManager(config_dir=config_dir, environment=environment)
# Convenience functions
def get_analysis_thresholds(bean_type: str) -> AnalysisThresholds:
    """Module-level shortcut: analysis thresholds for *bean_type* via the global manager."""
    manager = get_config_manager()
    return manager.get_analysis_thresholds(bean_type)
def get_performance_config() -> PerformanceConfig:
    """Module-level shortcut: performance configuration via the global manager."""
    manager = get_config_manager()
    return manager.get_performance_config()
def get_ui_config() -> UIConfig:
    """Module-level shortcut: UI configuration via the global manager."""
    manager = get_config_manager()
    return manager.get_ui_config()
def get_business_config() -> BusinessConfig:
    """Module-level shortcut: business configuration via the global manager."""
    manager = get_config_manager()
    return manager.get_business_config()
def get_calibration_profile(profile_name: str) -> CalibrationProfile:
    """Module-level shortcut: calibration profile *profile_name* via the global manager."""
    manager = get_config_manager()
    return manager.get_calibration_profile(profile_name)
def load_config(config_path: str = "config.yaml") -> BeanVisionConfig:
    """Load a BeanVisionConfig from *config_path*, validating it before returning."""
    loaded = BeanVisionConfig.from_yaml(config_path)
    loaded.validate()
    return loaded