| from pathlib import Path |
| from typing import List |
|
|
|
|
# Last-resort class labels: load_labels falls back to these when neither the
# labels file nor the training directory's subfolders yield any labels.
DEFAULT_LABELS = ["cat", "dog", "rabbit"]
|
|
|
|
# Absolute, symlink-resolved directory that contains this module; relative
# label/dataset paths are anchored here rather than at the working directory.
PROJECT_DIR = Path(__file__).resolve().parents[0]
|
|
|
|
def _resolve_project_path(path: str) -> Path:
    """Return *path* as a Path, anchoring relative paths at PROJECT_DIR."""
    candidate = Path(path)
    return candidate if candidate.is_absolute() else PROJECT_DIR / candidate


def load_labels(
    dataset_train_dir: str = "data/custom_dataset/train",
    labels_file: str = "labels.txt",
) -> List[str]:
    """Load class labels from a labels file or infer them from train subfolders.

    Sources are tried in order, and the first one that yields labels wins:

    1. ``labels_file``: one label per line. Surrounding whitespace is
       stripped, blank lines are skipped and a leading UTF-8 byte-order
       mark is ignored.
    2. The names of the immediate subdirectories of ``dataset_train_dir``,
       sorted alphabetically. Plain files in that directory are ignored.
    3. A fresh copy of ``DEFAULT_LABELS``.

    A source is skipped when it is missing, has the wrong type (for example,
    a directory where the labels file is expected) or yields no labels.
    Relative paths are resolved against ``PROJECT_DIR``, not the current
    working directory.

    Args:
        dataset_train_dir: Training-set directory whose subfolders name the classes.
        labels_file: Text file listing one class label per line.

    Returns:
        A new list of label strings that the caller may mutate freely.
    """
    labels_path = _resolve_project_path(labels_file)
    # is_file() rather than exists(): a directory at this path would make
    # read_text() raise instead of falling through to the next source.
    if labels_path.is_file():
        # "utf-8-sig" drops a leading BOM (e.g. from Windows editors) that
        # str.strip() would otherwise leave glued to the first label;
        # BOM-less files decode exactly as with plain "utf-8".
        text = labels_path.read_text(encoding="utf-8-sig")
        labels = [label for label in (line.strip() for line in text.splitlines()) if label]
        if labels:
            return labels

    train_path = _resolve_project_path(dataset_train_dir)
    # is_dir() is False for missing paths, so no separate exists() check.
    if train_path.is_dir():
        inferred = sorted(entry.name for entry in train_path.iterdir() if entry.is_dir())
        if inferred:
            return inferred

    # Return a copy so a caller mutating the result cannot corrupt the
    # shared module-level default for subsequent calls.
    return list(DEFAULT_LABELS)
|
|