KingTechnician commited on
Commit
2bb3322
·
verified ·
1 Parent(s): 20f3ee9

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. decoder.py +35 -30
decoder.py CHANGED
@@ -5,8 +5,6 @@ This model contains BOTH encoder and decoder in one checkpoint
5
  import tensorflow as tf
6
  import numpy as np
7
  from PIL import Image as PILImage
8
- from tensorflow.python.saved_model import tag_constants
9
- from tensorflow.python.saved_model import signature_constants
10
 
11
 
12
  class TFStegaStampDecoderPretrained:
@@ -24,31 +22,21 @@ class TFStegaStampDecoderPretrained:
24
  """
25
  print(f"Loading TensorFlow pretrained model from: {model_path}")
26
 
27
- # Create session and load model
28
- self.sess = tf.compat.v1.Session(graph=tf.Graph())
29
 
30
- # Load the model
31
- model = tf.compat.v1.saved_model.loader.load(
32
- self.sess,
33
- [tag_constants.SERVING],
34
- model_path
35
- )
36
 
37
- # Get signature
38
- signature_def = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
 
39
 
40
- # Get input tensor (image)
41
- input_image_name = signature_def.inputs['image'].name
42
- self.input_image = self.sess.graph.get_tensor_by_name(input_image_name)
43
- print(f" Input tensor: {input_image_name}")
44
- print(f" Shape: {self.input_image.shape}")
45
-
46
- # Get output tensor (decoded secret)
47
- output_secret_name = signature_def.outputs['decoded'].name
48
- self.output_secret = self.sess.graph.get_tensor_by_name(output_secret_name)
49
- print(f" Output tensor: {output_secret_name}")
50
- print(f" Shape: {self.output_secret.shape}")
51
 
 
 
52
  print("TensorFlow decoder loaded successfully!")
53
 
54
  def decode(self, image):
@@ -84,12 +72,29 @@ class TFStegaStampDecoderPretrained:
84
  image = np.array(image_pil).astype(np.float32) / 255.0
85
  image = np.expand_dims(image, axis=0)
86
 
87
- # Run inference
88
- feed_dict = {self.input_image: image}
89
- secret = self.sess.run(self.output_secret, feed_dict=feed_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  # Output is (batch, 100)
92
- secret = secret[0] # Remove batch dimension
93
 
94
  # Threshold to get binary values
95
  bits = (secret > 0.5).astype(np.float32)
@@ -101,9 +106,9 @@ class TFStegaStampDecoderPretrained:
101
  return self.decode(image)
102
 
103
  def close(self):
104
- """Close the TensorFlow session"""
105
- if hasattr(self, 'sess') and self.sess is not None:
106
- self.sess.close()
107
 
108
  def __del__(self):
109
  """Cleanup on deletion"""
 
5
  import tensorflow as tf
6
  import numpy as np
7
  from PIL import Image as PILImage
 
 
8
 
9
 
10
  class TFStegaStampDecoderPretrained:
 
22
  """
23
  print(f"Loading TensorFlow pretrained model from: {model_path}")
24
 
25
+ # Load model with TF2 API
26
+ loaded = tf.saved_model.load(model_path)
27
 
28
+ # Get the serving function
29
+ self._infer_fn = loaded.signatures['serving_default']
 
 
 
 
30
 
31
+ # Get input/output info
32
+ input_signature = self._infer_fn.structured_input_signature[1]
33
+ output_signature = self._infer_fn.structured_outputs
34
 
35
+ self._input_names = list(input_signature.keys())
36
+ self._output_names = list(output_signature.keys())
 
 
 
 
 
 
 
 
 
37
 
38
+ print(f" Input tensors: {self._input_names}")
39
+ print(f" Output tensors: {self._output_names}")
40
  print("TensorFlow decoder loaded successfully!")
41
 
42
  def decode(self, image):
 
72
  image = np.array(image_pil).astype(np.float32) / 255.0
73
  image = np.expand_dims(image, axis=0)
74
 
75
+ # Ensure float32
76
+ image = image.astype(np.float32)
77
+
78
+ # Run inference with TF2 API
79
+ # Convert to tensor
80
+ image_tensor = tf.convert_to_tensor(image)
81
+
82
+ # Call the function - use the correct input name from signature
83
+ if 'image' in self._input_names:
84
+ result = self._infer_fn(image=image_tensor)
85
+ else:
86
+ # Use the first input name
87
+ result = self._infer_fn(**{self._input_names[0]: image_tensor})
88
+
89
+ # Get decoded output - use the correct output name
90
+ if 'decoded' in self._output_names:
91
+ decoded = result['decoded'].numpy()
92
+ else:
93
+ # Use the first output
94
+ decoded = list(result.values())[0].numpy()
95
 
96
  # Output is (batch, 100)
97
+ secret = decoded[0] # Remove batch dimension
98
 
99
  # Threshold to get binary values
100
  bits = (secret > 0.5).astype(np.float32)
 
106
  return self.decode(image)
107
 
108
  def close(self):
109
+ """Cleanup resources"""
110
+ # TF2 models don't need explicit session closing
111
+ pass
112
 
113
  def __del__(self):
114
  """Cleanup on deletion"""