rawanessam commited on
Commit
5fdff1b
·
verified ·
1 Parent(s): 3bf134a

Upload net.py

Browse files
Files changed (1) hide show
  1. net.py +42 -16
net.py CHANGED
@@ -31,22 +31,48 @@ def data_loader_bd_rm_from_tfrecord(batch_size=1):
31
  return loader_dict, num_batch
32
 
33
  class Network(object):
34
- def __init__(self, dtype=tf.float32):
35
- print('Initial nn network object...')
36
- self.dtype = dtype
37
- # ... existing code ...
38
-
39
- def convert_one_hot_to_image(self, one_hot, dtype='float', act=None):
40
- import tensorflow.compat.v1 as tf
41
- if act == 'softmax':
42
- one_hot = tf.nn.softmax(one_hot, axis=-1)
43
- [n, h, w, c] = one_hot.shape.as_list()
44
- im = tf.reshape(tf.argmax(one_hot, axis=-1), [n, h, w, 1])
45
- if dtype == 'int':
46
- im = tf.cast(im, dtype=tf.uint8)
47
- else:
48
- im = tf.cast(im, dtype=tf.float32)
49
- return im
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # basic layer
52
  def _he_uniform(self, shape, regularizer=None, trainable=None, name=None):
 
31
  return loader_dict, num_batch
32
 
33
  class Network(object):
34
+ """docstring for Network"""
35
+ def __init__(self, dtype=tf.float32):
36
+ print('Initial nn network object...')
37
+ self.dtype = dtype
38
+ self.pre_train_restore_map = {'vgg_16/conv1/conv1_1/weights':'FNet/conv1_1/W', # {'checkpoint_scope_var_name':'current_scope_var_name'} shape must be the same
39
+ 'vgg_16/conv1/conv1_1/biases':'FNet/conv1_1/b',
40
+ 'vgg_16/conv1/conv1_2/weights':'FNet/conv1_2/W',
41
+ 'vgg_16/conv1/conv1_2/biases':'FNet/conv1_2/b',
42
+ 'vgg_16/conv2/conv2_1/weights':'FNet/conv2_1/W',
43
+ 'vgg_16/conv2/conv2_1/biases':'FNet/conv2_1/b',
44
+ 'vgg_16/conv2/conv2_2/weights':'FNet/conv2_2/W',
45
+ 'vgg_16/conv2/conv2_2/biases':'FNet/conv2_2/b',
46
+ 'vgg_16/conv3/conv3_1/weights':'FNet/conv3_1/W',
47
+ 'vgg_16/conv3/conv3_1/biases':'FNet/conv3_1/b',
48
+ 'vgg_16/conv3/conv3_2/weights':'FNet/conv3_2/W',
49
+ 'vgg_16/conv3/conv3_2/biases':'FNet/conv3_2/b',
50
+ 'vgg_16/conv3/conv3_3/weights':'FNet/conv3_3/W',
51
+ 'vgg_16/conv3/conv3_3/biases':'FNet/conv3_3/b',
52
+ 'vgg_16/conv4/conv4_1/weights':'FNet/conv4_1/W',
53
+ 'vgg_16/conv4/conv4_1/biases':'FNet/conv4_1/b',
54
+ 'vgg_16/conv4/conv4_2/weights':'FNet/conv4_2/W',
55
+ 'vgg_16/conv4/conv4_2/biases':'FNet/conv4_2/b',
56
+ 'vgg_16/conv4/conv4_3/weights':'FNet/conv4_3/W',
57
+ 'vgg_16/conv4/conv4_3/biases':'FNet/conv4_3/b',
58
+ 'vgg_16/conv5/conv5_1/weights':'FNet/conv5_1/W',
59
+ 'vgg_16/conv5/conv5_1/biases':'FNet/conv5_1/b',
60
+ 'vgg_16/conv5/conv5_2/weights':'FNet/conv5_2/W',
61
+ 'vgg_16/conv5/conv5_2/biases':'FNet/conv5_2/b',
62
+ 'vgg_16/conv5/conv5_3/weights':'FNet/conv5_3/W',
63
+ 'vgg_16/conv5/conv5_3/biases':'FNet/conv5_3/b'}
64
+
65
+ def convert_one_hot_to_image(self, one_hot, dtype='float', act=None):
66
+ # This method was moved from MODEL in main.py for inference compatibility
67
+ if act == 'softmax':
68
+ one_hot = tf.nn.softmax(one_hot, axis=-1)
69
+ [n, h, w, c] = one_hot.shape.as_list()
70
+ im = tf.reshape(tf.argmax(one_hot, axis=-1), [n, h, w, 1])
71
+ if dtype == 'int':
72
+ im = tf.cast(im, dtype=tf.uint8)
73
+ else:
74
+ im = tf.cast(im, dtype=tf.float32)
75
+ return im
76
 
77
  # basic layer
78
  def _he_uniform(self, shape, regularizer=None, trainable=None, name=None):