# Abgabe2 / labels.py
# Uploaded by nbacchi ("Upload 6 files", commit 16c7630).
from pathlib import Path
from typing import List
# Fallback class names returned by load_labels when neither a labels file
# nor any training subfolder supplies labels.
DEFAULT_LABELS = ["cat", "dog", "rabbit"]

# Absolute directory that contains this module; load_labels anchors
# relative paths here instead of at the current working directory.
PROJECT_DIR = Path(__file__).resolve().parent
def _project_path(path: str) -> Path:
    """Return *path* as a Path, anchoring relative paths at PROJECT_DIR."""
    candidate = Path(path)
    if candidate.is_absolute():
        return candidate
    return PROJECT_DIR / candidate


def load_labels(dataset_train_dir: str = "data/custom_dataset/train", labels_file: str = "labels.txt") -> List[str]:
    """Load class labels from a labels file or infer them from train subfolders.

    Sources are tried in order; the first one that yields at least one label wins:

    1. ``labels_file``: one label per line. Surrounding whitespace and blank
       lines are ignored, and a leading UTF-8 BOM is tolerated.
    2. ``dataset_train_dir``: the names of its immediate subdirectories,
       sorted lexicographically.
    3. ``DEFAULT_LABELS``, returned as a fresh copy.

    Relative paths are resolved against ``PROJECT_DIR`` (the directory that
    contains this module), not against the current working directory.

    Args:
        dataset_train_dir: Directory whose subfolders name the classes.
        labels_file: Text file listing one label per line.

    Returns:
        A new list of label strings that the caller may mutate freely.
    """
    labels_path = _project_path(labels_file)
    # is_file() rather than exists(): a directory with this name would
    # otherwise pass the check and make read_text() raise.
    if labels_path.is_file():
        # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
        # BOM, which str.strip() would leave attached to the first label.
        text = labels_path.read_text(encoding="utf-8-sig")
        stripped = (line.strip() for line in text.splitlines())
        labels = [label for label in stripped if label]
        if labels:
            return labels

    train_path = _project_path(dataset_train_dir)
    # is_dir() is already False for a missing path; no separate exists() needed.
    if train_path.is_dir():
        # Sorted so the label order (and any index mapping built from it) is
        # deterministic regardless of filesystem listing order.
        inferred = sorted(entry.name for entry in train_path.iterdir() if entry.is_dir())
        if inferred:
            return inferred

    # Copy so a caller mutating the result cannot corrupt the module default.
    return list(DEFAULT_LABELS)