Keras
astronomy
mervess commited on
Commit
3e3ffde
·
verified ·
1 Parent(s): ab90a0e

Upload filters.py

Browse files

Necessary to load the model. It includes the GaussianFilter used in the final layer of the network.

Files changed (1) hide show
  1. filters.py +34 -0
filters.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.layers import Layer
3
+
4
+
5
+
6
+ class GaussianFilter(Layer):
7
+ def __init__(self, kernel_size=5, sigma=1.0, **kwargs):
8
+ super(GaussianFilter, self).__init__(**kwargs)
9
+ self.kernel_size = kernel_size
10
+ self.sigma = sigma
11
+
12
+ def build(self, input_shape):
13
+ # Create a Gaussian kernel
14
+ def gaussian_kernel(size, sigma):
15
+ x = tf.range(-size // 2 + 1, size // 2 + 1, dtype=tf.float32)
16
+ x = tf.exp(-(x**2) / (2 * sigma**2))
17
+ kernel = tf.tensordot(x, x, axes=0)
18
+ return kernel / tf.reduce_sum(kernel)
19
+
20
+ kernel = gaussian_kernel(self.kernel_size, self.sigma)
21
+ kernel = kernel[:, :, tf.newaxis, tf.newaxis]
22
+ self.kernel = tf.tile(kernel, [1, 1, input_shape[-1], 1])
23
+ self.built = True
24
+
25
+ def call(self, inputs):
26
+ return tf.nn.depthwise_conv2d(inputs, self.kernel, strides=[1, 1, 1, 1], padding='SAME')
27
+
28
+ def compute_output_shape(self, input_shape):
29
+ return input_shape
30
+
31
+ def get_config(self):
32
+ config = super(GaussianFilter, self).get_config()
33
+ config.update({'kernel_size': self.kernel_size, 'sigma': self.sigma})
34
+ return config