Spaces:
Running
Running
File size: 1,440 Bytes
14f6839 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | """
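Builder function for a lightweight image classification model.
This file hosts the small Xception-style binary classification network from chap09 so that the classification pipeline can build the training model directly.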
"""
import keras
from keras.layers import BatchNormalization, Conv2D, Dense, Dropout, GlobalAveragePooling2D, MaxPooling2D, Rescaling, SeparableConv2D
def build_image_classification_model(
    image_size: tuple[int, int] = (180, 180),
    filters: tuple[int, ...] = (128, 256, 512, 728),
    initial_filters: int = 32,
    dropout_rate: float = 0.5,
) -> keras.Model:
    """Build a small Xception-style network for binary image classification.

    The model divides pixel values by 255, applies a strided stem convolution,
    then runs one residual stage per entry in ``filters``. Each stage consists
    of two separable convolutions followed by a stride-2 max pooling, added to
    a stride-2 1x1 convolution of the stage input. The head is global average
    pooling, dropout, and a single sigmoid unit.

    The returned model is not compiled; the caller chooses the optimizer,
    loss, and metrics.

    Args:
        image_size: Two spatial dimensions of the input images. Any
            two-element sequence is accepted. Three channels are appended.
        filters: Output channel count of each residual stage, in order. An
            empty sequence yields a model with only the stem and the head.
        initial_filters: Output channel count of the stem convolution.
        dropout_rate: Dropout rate applied before the final dense layer.

    Returns:
        A ``keras.Model`` named ``"image_classification"`` with one sigmoid
        output per input image.

    Raises:
        ValueError: If ``image_size`` does not contain exactly two values.
    """
    # Unpack rather than concatenate with ``+ (3,)``: this accepts lists as
    # well as tuples and rejects a wrong number of dimensions up front.
    height, width = image_size
    inputs = keras.Input(shape=(height, width, 3))

    x = Rescaling(1.0 / 255)(inputs)
    x = Conv2D(initial_filters, 3, strides=2, padding="same", use_bias=False)(x)

    for filter_count in filters:
        # Shortcut branch: the 1x1 stride-2 convolution matches both the
        # channel count and the downsampled spatial size of the main branch.
        residual = Conv2D(filter_count, 1, strides=2, padding="same", use_bias=False)(x)
        residual = BatchNormalization()(residual)

        # Main branch. Biases are disabled because each convolution is
        # immediately followed by batch normalization.
        x = SeparableConv2D(filter_count, 3, padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        # Use a layer instead of calling the raw activation function so the
        # ReLU is a regular, named node of the functional graph.
        x = keras.layers.Activation("relu")(x)
        x = SeparableConv2D(filter_count, 3, padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        x = MaxPooling2D(3, strides=2, padding="same")(x)

        x = keras.layers.add([x, residual])

    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs, name="image_classification")
|