File size: 613 Bytes
14f6839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import tensorflow as tf

from models.segmentation import build_segmentation_model


def test_build_segmentation_model():
    """验证分割模型保持输入分辨率,并为每个像素输出类别概率。"""
    model = build_segmentation_model(
        image_size=(32, 32),
        num_classes=3,
        filters=(8,)
    )
    images = tf.zeros((2, 32, 32, 3), dtype=tf.float32)

    outputs = model(images)

    assert outputs.shape == (2, 32, 32, 3)
    np.testing.assert_allclose(
        tf.reduce_sum(outputs, axis=-1).numpy(),
        np.ones((2, 32, 32)),
        atol=1e-5
    )