general-deep-learning / test /models /segmentation_model_test.py
yetrun's picture
ver2: 扩展 CV 训练框架,支持分类、分割与目标检测任务
14f6839
Raw
History Blame Contribute Delete
613 Bytes
import numpy as np
import tensorflow as tf
from models.segmentation import build_segmentation_model
def test_build_segmentation_model():
"""验证分割模型保持输入分辨率,并为每个像素输出类别概率。"""
model = build_segmentation_model(
image_size=(32, 32),
num_classes=3,
filters=(8,)
)
images = tf.zeros((2, 32, 32, 3), dtype=tf.float32)
outputs = model(images)
assert outputs.shape == (2, 32, 32, 3)
np.testing.assert_allclose(
tf.reduce_sum(outputs, axis=-1).numpy(),
np.ones((2, 32, 32)),
atol=1e-5
)