File size: 1,354 Bytes
14f6839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
通用目录图片分类数据集。

这个文件把 train/val/test 三段式目录整理成 Keras 可训练的数据集。
每个数据段下面按类别建子目录,图片分类流水线通过它读取训练、验证和测试输入。
"""

from dataclasses import dataclass
from pathlib import Path

import keras
import tensorflow as tf


@dataclass
class ImageClassificationDirectoryDataset:
    """Image-classification dataset backed by train/validation/test directories.

    Every split directory is handed to
    ``keras.utils.image_dataset_from_directory`` together with the
    ``image_size`` and ``label_mode`` configured on this instance.

    Attributes:
        train_path: Directory holding the training split.
        validation_path: Directory holding the validation split.
        test_path: Directory holding the test split.
        image_size: Value forwarded to Keras as ``image_size``.
        label_mode: Value forwarded to Keras as ``label_mode``.
    """

    train_path: Path
    validation_path: Path
    test_path: Path
    image_size: tuple[int, int] = (180, 180)
    label_mode: str = "binary"

    def _build_dataset(self, path: Path, batch_size: int) -> tf.data.Dataset:
        """Load the images under ``path`` as a dataset with the given batch size.

        Args:
            path: Split directory to read from.
            batch_size: Batch size forwarded to Keras.

        Returns:
            Whatever ``keras.utils.image_dataset_from_directory`` produces
            for this directory and configuration.
        """
        # Options shared by every split; only the directory and the batch
        # size vary between calls.
        loader_options = {
            "image_size": self.image_size,
            "batch_size": batch_size,
            "label_mode": self.label_mode,
        }
        return keras.utils.image_dataset_from_directory(path, **loader_options)

    def training_ds(self, batch_size: int) -> tf.data.Dataset:
        """Return the dataset built from ``train_path``."""
        split_dir = self.train_path
        return self._build_dataset(split_dir, batch_size)

    def validation_ds(self, batch_size: int) -> tf.data.Dataset:
        """Return the dataset built from ``validation_path``."""
        split_dir = self.validation_path
        return self._build_dataset(split_dir, batch_size)

    def test_ds(self, batch_size: int) -> tf.data.Dataset:
        """Return the dataset built from ``test_path``."""
        split_dir = self.test_path
        return self._build_dataset(split_dir, batch_size)

    def class_names(self) -> list[str]:
        """Return the class names reported for the training split.

        A dataset is built from ``train_path`` with a batch size of 1 solely
        to read its ``class_names`` attribute.
        """
        return self._build_dataset(self.train_path, batch_size=1).class_names