# Uploaded by Erick via huggingface_hub (commit 47cb9bd, verified).
"""
config.py β€” Pydantic settings for the autolabel pipeline.
Handles:
- Auto device detection: CUDA β†’ MPS β†’ CPU
- OWLv2 model selection
- Detection thresholds
- Data paths derived from project root
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import List
import torch
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Module-wide logger, named after this module for hierarchical filtering.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Project root: autolabel/config.py sits one package below the project dir,
# so walk two levels up from this file's resolved location.
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).resolve().parents[1]
def _detect_device() -> str:
    """Return the best available torch device string.

    Preference order: CUDA, then Apple-Silicon MPS, then CPU. The chosen
    device is logged so the selection is visible in pipeline output.
    """
    if torch.cuda.is_available():
        logger.info("Device selected: CUDA (%s)", torch.cuda.get_device_name(0))
        return "cuda"
    if torch.backends.mps.is_available():
        logger.info(
            "Device selected: MPS (Apple Silicon). "
            "Set PYTORCH_ENABLE_MPS_FALLBACK=1 for unsupported ops."
        )
        return "mps"
    # No accelerator found; warn because inference on CPU is far slower.
    logger.warning("Device selected: CPU β€” no CUDA or MPS found. Inference will be slow.")
    return "cpu"
class Settings(BaseSettings):
    """Central configuration for the autolabel pipeline.

    All values can be overridden via environment variables prefixed with
    AUTOLABEL_ (e.g., AUTOLABEL_THRESHOLD=0.2).
    The .env file is loaded automatically from the project root.

    Note: instantiation has side effects — ``_resolve_device`` fills in the
    torch device when none is supplied, and ``_ensure_dirs`` creates the
    three data directories on disk.
    """
    model_config = SettingsConfigDict(
        env_prefix="AUTOLABEL_",  # e.g. AUTOLABEL_MODEL, AUTOLABEL_THRESHOLD
        env_file=str(PROJECT_ROOT / ".env"),
        env_file_encoding="utf-8",
        case_sensitive=False,  # AUTOLABEL_DEVICE and autolabel_device both match
        extra="ignore",  # unknown AUTOLABEL_* vars are dropped, not errors
    )
    # ------------------------------------------------------------------
    # Device
    # ------------------------------------------------------------------
    # Torch device string ("cuda", "mps", "cpu"). Empty string means
    # "auto-detect"; resolved by the _resolve_device validator below.
    device: str = Field(
        default="",
        description="Torch device override. Leave empty for auto-detection.",
    )
    # ------------------------------------------------------------------
    # OWLv2 model
    # ------------------------------------------------------------------
    # Hugging Face hub identifier passed to the model loader.
    model: str = Field(
        default="google/owlv2-large-patch14-finetuned",
        description="Hugging Face model identifier for OWLv2.",
    )
    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------
    # Confidence cutoff in [0, 1]; detections scoring below it are dropped.
    threshold: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Minimum confidence score to keep a detection.",
    )
    # Open-vocabulary text queries; each prompt is a candidate object class.
    prompts: List[str] = Field(
        default=[
            "cup",
            "bottle",
            "keyboard",
            "computer mouse",
            "cell phone",
            "remote control",
            "book",
            "plant",
            "bowl",
            "mug",
            "laptop",
            "monitor",
            "pen",
            "scissors",
            "stapler",
            "headphones",
            "wallet",
            "keys",
            "glasses",
            "candle",
            "backpack",
            "notebook",
            "water bottle",
            "coffee cup",
            "charger",
        ],
        description="Text prompts sent to OWLv2 for open-vocabulary detection.",
    )
    # ------------------------------------------------------------------
    # Paths
    # ------------------------------------------------------------------
    # All three directories default to subfolders of <project>/data and are
    # created on instantiation by the _ensure_dirs validator below.
    raw_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "raw",
        description="Input images directory.",
    )
    detections_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "detections",
        description="OWLv2 output JSON files.",
    )
    labeled_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "labeled",
        description="Reviewed and accepted annotation JSON files.",
    )
    # ------------------------------------------------------------------
    # Validators
    # ------------------------------------------------------------------
    @field_validator("threshold", mode="before")
    @classmethod
    def _coerce_threshold(cls, v: object) -> float:
        """Coerce env-var strings (e.g. "0.2") to float before range checks."""
        return float(v)  # type: ignore[arg-type]
    @field_validator("prompts", mode="before")
    @classmethod
    def _parse_prompts(cls, v: object) -> List[str]:
        """Allow comma-separated string from env var.

        Splits on commas and discards empty/whitespace-only entries; any
        other iterable is passed through as a list.

        NOTE(review): pydantic-settings attempts JSON decoding of env values
        for complex (list) fields *before* field validators run and may raise
        on non-JSON input — confirm AUTOLABEL_PROMPTS="a,b" actually reaches
        this validator instead of failing earlier.
        """
        if isinstance(v, str):
            return [p.strip() for p in v.split(",") if p.strip()]
        return list(v)  # type: ignore[arg-type]
    @model_validator(mode="after")
    def _resolve_device(self) -> "Settings":
        """Fill in the device by auto-detection when no override was given."""
        if not self.device:
            self.device = _detect_device()
        else:
            logger.info("Device override from env/config: %s", self.device)
        return self
    @model_validator(mode="after")
    def _ensure_dirs(self) -> "Settings":
        """Create the data directories (idempotent) so later I/O can't fail on a missing path."""
        for path in (self.raw_dir, self.detections_dir, self.labeled_dir):
            path.mkdir(parents=True, exist_ok=True)
        return self
    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------
    @property
    def torch_dtype(self) -> torch.dtype:
        """Preferred inference dtype: fp16 on CUDA, fp32 everywhere else
        (MPS fp16 support is incomplete)."""
        return torch.float16 if self.device == "cuda" else torch.float32
# Module-level singleton — import this everywhere.
# NOTE: constructing Settings at import time runs device auto-detection and
# creates the data directories (see the model validators on Settings), so
# importing this module has filesystem side effects.
settings = Settings()