Spaces:
Runtime error
Runtime error
| import tensorflow as tf | |
| from tensorflow.keras import losses | |
class SpatialConsistencyLoss(losses.Loss):
    """Unreduced loss comparing local filter responses of two image batches.

    Both inputs are averaged over axis 3, average-pooled with a 4x4 window
    (stride 4, ``VALID`` padding) and convolved with four fixed kernels.
    The result is the sum, over the four kernels, of the squared difference
    between the responses obtained for ``y_true`` and for ``y_pred``.

    The base class is always constructed with ``reduction="none"``.

    NOTE(review): every kernel literal below has shape ``(1, 3, 1, 3)``.
    ``tf.nn.conv2d`` interprets filters as ``[height, width, in_channels,
    out_channels]``, so these act as 1x3 horizontal filters producing three
    output channels rather than 3x3 single-channel stencils. The attribute
    names suggest that 3x3 left/right/up/down neighbour differences were
    intended (that would need shape ``(3, 3, 1, 1)``) -- confirm before
    changing, because it alters both the loss values and the output shape.
    """

    def __init__(self, **kwargs):
        """Create the loss and its four fixed convolution kernels.

        Args:
            **kwargs: Optional keyword arguments. Only ``name`` is honoured
                and forwarded to ``losses.Loss``; every other keyword
                (including ``reduction``) is accepted and ignored so that
                existing call sites keep working.
        """
        # The reduction is pinned to "none" whatever the caller passes;
        # name=None is the base-class default, so omitting it is unchanged.
        super().__init__(reduction="none", name=kwargs.get("name"))

        self.left_kernel = tf.constant(
            [[[[0, 0, 0]], [[-1, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.right_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, -1]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.up_kernel = tf.constant(
            [[[[0, -1, 0]], [[0, 1, 0]], [[0, 0, 0]]]], dtype=tf.float32
        )
        self.down_kernel = tf.constant(
            [[[[0, 0, 0]], [[0, 1, 0]], [[0, -1, 0]]]], dtype=tf.float32
        )

    @staticmethod
    def _pooled_intensity(images):
        """Average over axis 3, then 4x4 average-pool with stride 4."""
        channel_mean = tf.reduce_mean(images, 3, keepdims=True)
        return tf.nn.avg_pool2d(
            channel_mean, ksize=4, strides=4, padding="VALID"
        )

    @staticmethod
    def _response(pooled, kernel):
        """Convolve ``pooled`` with one fixed kernel (stride 1, SAME)."""
        return tf.nn.conv2d(
            pooled, kernel, strides=[1, 1, 1, 1], padding="SAME"
        )

    def call(self, y_true, y_pred):
        """Return the unreduced spatial-consistency loss.

        Args:
            y_true: Reference batch. Must be rank 4 with the channel axis at
                index 3 (required by the axis-3 mean and ``avg_pool2d``;
                NHWC layout presumed -- confirm against the caller).
            y_pred: Batch to compare; same layout as ``y_true``.

        Returns:
            Tensor holding, for every pooled location and kernel output
            channel, the sum over the four kernels of the squared response
            difference between ``y_true`` and ``y_pred``.
        """
        original_pool = self._pooled_intensity(y_true)
        enhanced_pool = self._pooled_intensity(y_pred)

        kernels = (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        )

        # Accumulate in left, right, up, down order so the floating-point
        # summation order matches a plain `left + right + up + down`.
        total = None
        for kernel in kernels:
            difference = self._response(original_pool, kernel) - self._response(
                enhanced_pool, kernel
            )
            squared = tf.square(difference)
            total = squared if total is None else total + squared
        return total