github-actions[bot]
Deploy from c78f0ab7f068a2e34f263a03499fd51c9fc81c53
e2e75ee
raw
history blame
4.95 kB
"""Application configuration using Pydantic Settings.

Configuration is loaded from environment variables with the FACE_AGE_ prefix;
unset variables fall back to sensible defaults for local development.

Environment variables:
    FACE_AGE_DEVICE: PyTorch device (default: auto-detected -- cuda if
        available, then mps, else cpu)
    FACE_AGE_DETECTOR_MODEL_ID: HuggingFace model ID for face/person detector
    FACE_AGE_MIVOLO_MODEL_ID: HuggingFace model ID for MiVOLO v2 age estimator
    FACE_AGE_CONFIDENCE_THRESHOLD: Detection confidence threshold
        (0.0-1.0, default: 0.15)
    FACE_AGE_IOU_THRESHOLD: Detection IoU threshold (0.0-1.0, default: 0.4)
    FACE_AGE_ANNOTATION_FORMAT: Output image format (.jpg, .jpeg or .png;
        default: .jpg)
    FACE_AGE_MIVOLO_BATCH_SIZE: Max batch size for MiVOLO forward passes
        (>= 1, default: 8)
"""
import torch
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Type aliases
type DeviceSpec = str
type Probability = float
# Default HuggingFace model IDs
_DEFAULT_DETECTOR_MODEL_ID: str = "iitolstykh/YOLO-Face-Person-Detector"
_DEFAULT_MIVOLO_MODEL_ID: str = "iitolstykh/mivolo_v2"
def _detect_best_device() -> str:
    """Pick the most capable inference backend present on this machine.

    Backends are probed in priority order: CUDA first, then Apple's MPS.
    MPS is only queried when CUDA is unavailable.

    Returns:
        ``"cuda"``, ``"mps"`` or ``"cpu"`` (the fallback when neither
        accelerator is available).
    """
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return chosen
# Inclusive bounds shared by the probability-valued settings
# (confidence_threshold and iou_threshold).
MAX_PROBABILITY: Probability = 1.0
MIN_PROBABILITY: Probability = 0.0
class Settings(BaseSettings):
    """Runtime configuration for face-age-inference.

    Values are read from environment variables prefixed with ``FACE_AGE_``
    (e.g. ``FACE_AGE_DEVICE``); unset fields use the defaults declared below.

    Attributes:
        device: PyTorch device specification (e.g., 'cpu', 'cuda:0', 'mps').
            Defaults to the best available backend (cuda, then mps, then cpu).
        detector_model_id: HuggingFace model ID for YOLO face/person detector.
        mivolo_model_id: HuggingFace model ID for MiVOLO v2 age estimator.
        confidence_threshold: Minimum confidence for face detection (0-1).
        iou_threshold: IoU threshold for non-maximum suppression (0-1).
        annotation_format: Image format for annotated outputs ('.jpg' or '.png').
        mivolo_batch_size: Maximum batch size for MiVOLO forward passes (>= 1).
    """

    model_config = SettingsConfigDict(
        env_prefix="FACE_AGE_",
    )

    device: DeviceSpec = Field(
        default_factory=_detect_best_device,
        description="PyTorch device identifier for inference (e.g., 'cpu', 'cuda:0', 'mps')",
    )
    detector_model_id: str = Field(
        default=_DEFAULT_DETECTOR_MODEL_ID,
        description="HuggingFace model ID for YOLO face/person detector",
    )
    mivolo_model_id: str = Field(
        default=_DEFAULT_MIVOLO_MODEL_ID,
        description="HuggingFace model ID for MiVOLO v2 age estimator",
    )
    confidence_threshold: Probability = Field(
        default=0.15,
        ge=MIN_PROBABILITY,
        le=MAX_PROBABILITY,
        description="Minimum confidence score for face detection (0.0 to 1.0)",
    )
    iou_threshold: Probability = Field(
        default=0.4,
        ge=MIN_PROBABILITY,
        le=MAX_PROBABILITY,
        description="IoU threshold for non-maximum suppression (0.0 to 1.0)",
    )
    annotation_format: str = Field(
        default=".jpg",
        description="Image format for annotated outputs ('.jpg' or '.png')",
    )
    mivolo_batch_size: int = Field(
        default=8,
        ge=1,
        description="Max batch size for MiVOLO forward passes (reduce if you hit OOM)",
    )

    @field_validator("annotation_format")
    @classmethod
    def validate_annotation_format(cls, value: str) -> str:
        """Normalize and validate the annotated-image format.

        Surrounding whitespace is ignored and matching is case-insensitive.
        A missing leading dot is tolerated (``"png"`` -> ``".png"``), which
        makes values supplied via environment variables more forgiving.
        ``.jpeg`` is normalized to ``.jpg``.

        Args:
            value: The format string to validate.

        Returns:
            Either ``".jpg"`` or ``".png"``.

        Raises:
            ValueError: If the format is not supported.
        """
        normalized = value.strip().lower()
        if normalized and not normalized.startswith("."):
            normalized = f".{normalized}"
        if normalized not in {".jpg", ".jpeg", ".png"}:
            raise ValueError(
                f"Unsupported annotation format: {value}. "
                "Must be one of: .jpg, .jpeg, .png"
            )
        # Normalize .jpeg to .jpg for consistency
        return ".jpg" if normalized == ".jpeg" else normalized

    @field_validator("device")
    @classmethod
    def validate_device(cls, value: str) -> str:
        """Validate the device specification.

        The string must be parseable by ``torch.device`` and must name one of
        the supported backends (``cpu``, ``cuda``, ``mps``), optionally with
        an index such as ``cuda:0``. Surrounding whitespace is ignored.

        Args:
            value: Device specification string.

        Returns:
            The validated device string with surrounding whitespace removed.

        Raises:
            ValueError: If the string is not a well-formed torch device spec,
                or names an unsupported backend.
        """
        supported_types = ("cpu", "cuda", "mps")
        hint = (
            f"Must be one of: {', '.join(supported_types)}, "
            "optionally followed by ':<index>' (e.g. 'cuda:0')"
        )
        candidate = value.strip()
        try:
            # A plain prefix check would accept malformed strings such as
            # "cpu0" or "cuda-1"; parsing here rejects them up front instead
            # of at model-load time.
            parsed = torch.device(candidate)
        except (RuntimeError, ValueError) as err:
            # Pydantic only converts ValueError/AssertionError raised in
            # validators into ValidationError, so re-raise torch's error.
            raise ValueError(
                f"Invalid device specification: {value}. {hint}"
            ) from err
        if parsed.type not in supported_types:
            raise ValueError(f"Invalid device specification: {value}. {hint}")
        return candidate
# Module-wide singleton, built from the environment when this module is
# first imported.
settings = Settings()

__all__ = ["Settings", "settings"]