from pathlib import Path
from typing import List

# Fallback class labels used when neither a labels file nor a train
# directory with class subfolders yields any labels.
DEFAULT_LABELS = [
    "cat",
    "dog",
    "rabbit",
]

# Directory containing this module; relative paths passed to load_labels
# are resolved against it rather than the current working directory.
PROJECT_DIR = Path(__file__).resolve().parent


def _resolve_project_path(path: str) -> Path:
    """Return *path* as a Path, anchored at PROJECT_DIR when it is relative."""
    candidate = Path(path)
    if candidate.is_absolute():
        return candidate
    return PROJECT_DIR / candidate


def load_labels(
    dataset_train_dir: str = "data/custom_dataset/train",
    labels_file: str = "labels.txt",
) -> List[str]:
    """Load class labels from labels.txt or infer them from train subfolders.

    Resolution order:

    1. ``labels_file``: one label per line; blank lines and surrounding
       whitespace are ignored, and a leading UTF-8 byte-order mark is
       tolerated.
    2. The names of the immediate subdirectories of ``dataset_train_dir``,
       sorted alphabetically. Hidden directories (names starting with
       ``.``, e.g. ``.ipynb_checkpoints``) are skipped.
    3. A copy of ``DEFAULT_LABELS``.

    Relative paths are resolved against the directory containing this module.

    Args:
        dataset_train_dir: Training directory with one subfolder per class.
        labels_file: Text file listing class labels, one per line.

    Returns:
        A new list of label strings; callers may mutate it freely.
    """
    labels_path = _resolve_project_path(labels_file)
    # is_file() rather than exists(): a directory at this path should fall
    # through to the next source instead of raising IsADirectoryError.
    if labels_path.is_file():
        # "utf-8-sig" drops a leading BOM, which str.strip() would otherwise
        # leave attached to the first label.
        text = labels_path.read_text(encoding="utf-8-sig")
        labels = [line.strip() for line in text.splitlines() if line.strip()]
        if labels:
            return labels

    train_path = _resolve_project_path(dataset_train_dir)
    if train_path.is_dir():
        inferred = sorted(
            entry.name
            for entry in train_path.iterdir()
            if entry.is_dir() and not entry.name.startswith(".")
        )
        if inferred:
            return inferred

    # Return a copy so callers cannot mutate the module-level default.
    return list(DEFAULT_LABELS)