Spaces:
Sleeping
Sleeping
| """ | |
| config.py β Pydantic settings for the autolabel pipeline. | |
| Handles: | |
| - Auto device detection: CUDA β MPS β CPU | |
| - OWLv2 model selection | |
| - Detection thresholds | |
| - Data paths derived from project root | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import List | |
| import torch | |
| from pydantic import Field, field_validator, model_validator | |
| from pydantic_settings import BaseSettings, SettingsConfigDict | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Project root β two levels up from this file (autolabel/config.py β project/) | |
| # --------------------------------------------------------------------------- | |
| PROJECT_ROOT = Path(__file__).resolve().parent.parent | |
| def _detect_device() -> str: | |
| """Return the best available torch device string.""" | |
| if torch.cuda.is_available(): | |
| device = "cuda" | |
| name = torch.cuda.get_device_name(0) | |
| logger.info("Device selected: CUDA (%s)", name) | |
| elif torch.backends.mps.is_available(): | |
| device = "mps" | |
| logger.info( | |
| "Device selected: MPS (Apple Silicon). " | |
| "Set PYTORCH_ENABLE_MPS_FALLBACK=1 for unsupported ops." | |
| ) | |
| else: | |
| device = "cpu" | |
| logger.warning("Device selected: CPU β no CUDA or MPS found. Inference will be slow.") | |
| return device | |
class Settings(BaseSettings):
    """Central configuration for the autolabel pipeline.

    All values can be overridden via environment variables prefixed with
    AUTOLABEL_ (e.g., AUTOLABEL_THRESHOLD=0.2).
    The .env file is loaded automatically from the project root.
    """

    model_config = SettingsConfigDict(
        env_prefix="AUTOLABEL_",
        env_file=str(PROJECT_ROOT / ".env"),
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    # ------------------------------------------------------------------
    # Device
    # ------------------------------------------------------------------
    device: str = Field(
        default="",
        description="Torch device override. Leave empty for auto-detection.",
    )

    # ------------------------------------------------------------------
    # OWLv2 model
    # ------------------------------------------------------------------
    model: str = Field(
        default="google/owlv2-large-patch14-finetuned",
        description="Hugging Face model identifier for OWLv2.",
    )

    # ------------------------------------------------------------------
    # Detection
    # ------------------------------------------------------------------
    threshold: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Minimum confidence score to keep a detection.",
    )
    # default_factory avoids sharing one mutable list object as the
    # class-level default.
    prompts: List[str] = Field(
        default_factory=lambda: [
            "cup",
            "bottle",
            "keyboard",
            "computer mouse",
            "cell phone",
            "remote control",
            "book",
            "plant",
            "bowl",
            "mug",
            "laptop",
            "monitor",
            "pen",
            "scissors",
            "stapler",
            "headphones",
            "wallet",
            "keys",
            "glasses",
            "candle",
            "backpack",
            "notebook",
            "water bottle",
            "coffee cup",
            "charger",
        ],
        description="Text prompts sent to OWLv2 for open-vocabulary detection.",
    )

    # ------------------------------------------------------------------
    # Paths
    # ------------------------------------------------------------------
    raw_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "raw",
        description="Input images directory.",
    )
    detections_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "detections",
        description="OWLv2 output JSON files.",
    )
    labeled_dir: Path = Field(
        default=PROJECT_ROOT / "data" / "labeled",
        description="Reviewed and accepted annotation JSON files.",
    )

    # ------------------------------------------------------------------
    # Validators
    # ------------------------------------------------------------------
    # BUGFIX: these methods had no decorators, so pydantic never invoked
    # them (and the cls-style field validators were not classmethods).
    # field_validator / model_validator are imported at the top of the
    # file, confirming they were intended to be applied here.
    @field_validator("threshold", mode="before")
    @classmethod
    def _coerce_threshold(cls, v: object) -> float:
        """Coerce env-provided strings (e.g. "0.2") to float before range checks."""
        return float(v)  # type: ignore[arg-type]

    @field_validator("prompts", mode="before")
    @classmethod
    def _parse_prompts(cls, v: object) -> List[str]:
        """Allow a comma-separated string from an env var."""
        if isinstance(v, str):
            return [p.strip() for p in v.split(",") if p.strip()]
        return list(v)  # type: ignore[arg-type]

    @model_validator(mode="after")
    def _resolve_device(self) -> "Settings":
        """Auto-detect the torch device unless explicitly overridden."""
        if not self.device:
            self.device = _detect_device()
        else:
            logger.info("Device override from env/config: %s", self.device)
        return self

    @model_validator(mode="after")
    def _ensure_dirs(self) -> "Settings":
        """Create the data directories up front so later writes cannot fail on a missing path."""
        for path in (self.raw_dir, self.detections_dir, self.labeled_dir):
            path.mkdir(parents=True, exist_ok=True)
        return self

    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------
    def torch_dtype(self) -> torch.dtype:
        """fp16 on CUDA, fp32 everywhere else (MPS doesn't support fp16 fully)."""
        return torch.float16 if self.device == "cuda" else torch.float32
# Module-level singleton — import this everywhere.
settings = Settings()