rawanessam commited on
Commit
3bf134a
·
verified ·
1 Parent(s): 3573186

Update net.py

Browse files
Files changed (1) hide show
  1. net.py +16 -30
net.py CHANGED
@@ -31,36 +31,22 @@ def data_loader_bd_rm_from_tfrecord(batch_size=1):
31
  return loader_dict, num_batch
32
 
33
  class Network(object):
34
- """docstring for Network"""
35
- def __init__(self, dtype=tf.float32):
36
- print('Initial nn network object...')
37
- self.dtype = dtype
38
- self.pre_train_restore_map = {'vgg_16/conv1/conv1_1/weights':'FNet/conv1_1/W', # {'checkpoint_scope_var_name':'current_scope_var_name'} shape must be the same
39
- 'vgg_16/conv1/conv1_1/biases':'FNet/conv1_1/b',
40
- 'vgg_16/conv1/conv1_2/weights':'FNet/conv1_2/W',
41
- 'vgg_16/conv1/conv1_2/biases':'FNet/conv1_2/b',
42
- 'vgg_16/conv2/conv2_1/weights':'FNet/conv2_1/W',
43
- 'vgg_16/conv2/conv2_1/biases':'FNet/conv2_1/b',
44
- 'vgg_16/conv2/conv2_2/weights':'FNet/conv2_2/W',
45
- 'vgg_16/conv2/conv2_2/biases':'FNet/conv2_2/b',
46
- 'vgg_16/conv3/conv3_1/weights':'FNet/conv3_1/W',
47
- 'vgg_16/conv3/conv3_1/biases':'FNet/conv3_1/b',
48
- 'vgg_16/conv3/conv3_2/weights':'FNet/conv3_2/W',
49
- 'vgg_16/conv3/conv3_2/biases':'FNet/conv3_2/b',
50
- 'vgg_16/conv3/conv3_3/weights':'FNet/conv3_3/W',
51
- 'vgg_16/conv3/conv3_3/biases':'FNet/conv3_3/b',
52
- 'vgg_16/conv4/conv4_1/weights':'FNet/conv4_1/W',
53
- 'vgg_16/conv4/conv4_1/biases':'FNet/conv4_1/b',
54
- 'vgg_16/conv4/conv4_2/weights':'FNet/conv4_2/W',
55
- 'vgg_16/conv4/conv4_2/biases':'FNet/conv4_2/b',
56
- 'vgg_16/conv4/conv4_3/weights':'FNet/conv4_3/W',
57
- 'vgg_16/conv4/conv4_3/biases':'FNet/conv4_3/b',
58
- 'vgg_16/conv5/conv5_1/weights':'FNet/conv5_1/W',
59
- 'vgg_16/conv5/conv5_1/biases':'FNet/conv5_1/b',
60
- 'vgg_16/conv5/conv5_2/weights':'FNet/conv5_2/W',
61
- 'vgg_16/conv5/conv5_2/biases':'FNet/conv5_2/b',
62
- 'vgg_16/conv5/conv5_3/weights':'FNet/conv5_3/W',
63
- 'vgg_16/conv5/conv5_3/biases':'FNet/conv5_3/b'}
64
 
65
  # basic layer
66
  def _he_uniform(self, shape, regularizer=None, trainable=None, name=None):
 
31
  return loader_dict, num_batch
32
 
33
  class Network(object):
34
+ def __init__(self, dtype=tf.float32):
35
+ print('Initial nn network object...')
36
+ self.dtype = dtype
37
+ # ... existing code ...
38
+
39
+ def convert_one_hot_to_image(self, one_hot, dtype='float', act=None):
40
+ import tensorflow.compat.v1 as tf
41
+ if act == 'softmax':
42
+ one_hot = tf.nn.softmax(one_hot, axis=-1)
43
+ [n, h, w, c] = one_hot.shape.as_list()
44
+ im = tf.reshape(tf.argmax(one_hot, axis=-1), [n, h, w, 1])
45
+ if dtype == 'int':
46
+ im = tf.cast(im, dtype=tf.uint8)
47
+ else:
48
+ im = tf.cast(im, dtype=tf.float32)
49
+ return im
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # basic layer
52
  def _he_uniform(self, shape, regularizer=None, trainable=None, name=None):