yetrun's picture
ver2: 扩展 CV 训练框架,支持分类、分割与目标检测任务
14f6839
Raw
History Blame Contribute Delete
1.35 kB
"""
通用目录图片分类数据集。
这个文件把 train/val/test 三段式目录整理成 Keras 可训练的数据集。
每个数据段下面按类别建子目录,图片分类流水线通过它读取训练、验证和测试输入。
"""
from dataclasses import dataclass
from pathlib import Path
import keras
import tensorflow as tf
@dataclass
class ImageClassificationDirectoryDataset:
    """Image-classification datasets loaded from train/validation/test directories.

    Every split is read with ``keras.utils.image_dataset_from_directory`` using
    the same ``image_size`` and ``label_mode``; only the root directory, the
    batch size and the shuffling policy differ between splits.

    Attributes:
        train_path: Root directory of the training split.
        validation_path: Root directory of the validation split.
        test_path: Root directory of the test split.
        image_size: Forwarded to Keras as ``image_size`` for every split.
        label_mode: Forwarded to Keras as ``label_mode`` for every split.
    """

    train_path: Path
    validation_path: Path
    test_path: Path
    image_size: tuple[int, int] = (180, 180)
    label_mode: str = "binary"

    def training_ds(self, batch_size: int) -> tf.data.Dataset:
        """Return the training split batched by ``batch_size`` (shuffled)."""
        return self._build_dataset(self.train_path, batch_size, shuffle=True)

    def validation_ds(self, batch_size: int) -> tf.data.Dataset:
        """Return the validation split batched by ``batch_size``.

        Shuffling is disabled so the sample order is the same on every pass;
        otherwise labels gathered in one iteration cannot be matched with
        predictions produced in another.
        """
        return self._build_dataset(
            self.validation_path, batch_size, shuffle=False
        )

    def test_ds(self, batch_size: int) -> tf.data.Dataset:
        """Return the test split batched by ``batch_size``.

        Shuffling is disabled for the same reason as in ``validation_ds``:
        evaluation needs a stable sample order.
        """
        return self._build_dataset(self.test_path, batch_size, shuffle=False)

    def class_names(self) -> list[str]:
        """Return the class names Keras reports for the training directory."""
        # Only the dataset's metadata is needed here, so skip shuffling.
        dataset = self._build_dataset(
            self.train_path, batch_size=1, shuffle=False
        )
        return dataset.class_names

    def _build_dataset(
        self, path: Path, batch_size: int, *, shuffle: bool = True
    ) -> tf.data.Dataset:
        """Build a Keras image dataset for the directory at ``path``.

        Args:
            path: Root directory of the split to load.
            batch_size: Forwarded to Keras as ``batch_size``.
            shuffle: Forwarded to Keras as ``shuffle``. Defaults to ``True``,
                which is also the Keras default, so calls that omit it keep
                their previous behaviour.

        Returns:
            Whatever ``keras.utils.image_dataset_from_directory`` returns for
            these settings.
        """
        return keras.utils.image_dataset_from_directory(
            path,
            image_size=self.image_size,
            batch_size=batch_size,
            label_mode=self.label_mode,
            shuffle=shuffle,
        )