Upload folder using huggingface_hub
Browse files- decoder.py +35 -30
decoder.py
CHANGED
|
@@ -5,8 +5,6 @@ This model contains BOTH encoder and decoder in one checkpoint
|
|
| 5 |
import tensorflow as tf
|
| 6 |
import numpy as np
|
| 7 |
from PIL import Image as PILImage
|
| 8 |
-
from tensorflow.python.saved_model import tag_constants
|
| 9 |
-
from tensorflow.python.saved_model import signature_constants
|
| 10 |
|
| 11 |
|
| 12 |
class TFStegaStampDecoderPretrained:
|
|
@@ -24,31 +22,21 @@ class TFStegaStampDecoderPretrained:
|
|
| 24 |
"""
|
| 25 |
print(f"Loading TensorFlow pretrained model from: {model_path}")
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
|
| 30 |
-
#
|
| 31 |
-
|
| 32 |
-
self.sess,
|
| 33 |
-
[tag_constants.SERVING],
|
| 34 |
-
model_path
|
| 35 |
-
)
|
| 36 |
|
| 37 |
-
# Get
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
self.input_image = self.sess.graph.get_tensor_by_name(input_image_name)
|
| 43 |
-
print(f" Input tensor: {input_image_name}")
|
| 44 |
-
print(f" Shape: {self.input_image.shape}")
|
| 45 |
-
|
| 46 |
-
# Get output tensor (decoded secret)
|
| 47 |
-
output_secret_name = signature_def.outputs['decoded'].name
|
| 48 |
-
self.output_secret = self.sess.graph.get_tensor_by_name(output_secret_name)
|
| 49 |
-
print(f" Output tensor: {output_secret_name}")
|
| 50 |
-
print(f" Shape: {self.output_secret.shape}")
|
| 51 |
|
|
|
|
|
|
|
| 52 |
print("TensorFlow decoder loaded successfully!")
|
| 53 |
|
| 54 |
def decode(self, image):
|
|
@@ -84,12 +72,29 @@ class TFStegaStampDecoderPretrained:
|
|
| 84 |
image = np.array(image_pil).astype(np.float32) / 255.0
|
| 85 |
image = np.expand_dims(image, axis=0)
|
| 86 |
|
| 87 |
-
#
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Output is (batch, 100)
|
| 92 |
-
secret =
|
| 93 |
|
| 94 |
# Threshold to get binary values
|
| 95 |
bits = (secret > 0.5).astype(np.float32)
|
|
@@ -101,9 +106,9 @@ class TFStegaStampDecoderPretrained:
|
|
| 101 |
return self.decode(image)
|
| 102 |
|
| 103 |
def close(self):
|
| 104 |
-
"""
|
| 105 |
-
|
| 106 |
-
|
| 107 |
|
| 108 |
def __del__(self):
|
| 109 |
"""Cleanup on deletion"""
|
|
|
|
| 5 |
import tensorflow as tf
|
| 6 |
import numpy as np
|
| 7 |
from PIL import Image as PILImage
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class TFStegaStampDecoderPretrained:
|
|
|
|
| 22 |
"""
|
| 23 |
print(f"Loading TensorFlow pretrained model from: {model_path}")
|
| 24 |
|
| 25 |
+
# Load model with TF2 API
|
| 26 |
+
loaded = tf.saved_model.load(model_path)
|
| 27 |
|
| 28 |
+
# Get the serving function
|
| 29 |
+
self._infer_fn = loaded.signatures['serving_default']
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
+
# Get input/output info
|
| 32 |
+
input_signature = self._infer_fn.structured_input_signature[1]
|
| 33 |
+
output_signature = self._infer_fn.structured_outputs
|
| 34 |
|
| 35 |
+
self._input_names = list(input_signature.keys())
|
| 36 |
+
self._output_names = list(output_signature.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
+
print(f" Input tensors: {self._input_names}")
|
| 39 |
+
print(f" Output tensors: {self._output_names}")
|
| 40 |
print("TensorFlow decoder loaded successfully!")
|
| 41 |
|
| 42 |
def decode(self, image):
|
|
|
|
| 72 |
image = np.array(image_pil).astype(np.float32) / 255.0
|
| 73 |
image = np.expand_dims(image, axis=0)
|
| 74 |
|
| 75 |
+
# Ensure float32
|
| 76 |
+
image = image.astype(np.float32)
|
| 77 |
+
|
| 78 |
+
# Run inference with TF2 API
|
| 79 |
+
# Convert to tensor
|
| 80 |
+
image_tensor = tf.convert_to_tensor(image)
|
| 81 |
+
|
| 82 |
+
# Call the function - use the correct input name from signature
|
| 83 |
+
if 'image' in self._input_names:
|
| 84 |
+
result = self._infer_fn(image=image_tensor)
|
| 85 |
+
else:
|
| 86 |
+
# Use the first input name
|
| 87 |
+
result = self._infer_fn(**{self._input_names[0]: image_tensor})
|
| 88 |
+
|
| 89 |
+
# Get decoded output - use the correct output name
|
| 90 |
+
if 'decoded' in self._output_names:
|
| 91 |
+
decoded = result['decoded'].numpy()
|
| 92 |
+
else:
|
| 93 |
+
# Use the first output
|
| 94 |
+
decoded = list(result.values())[0].numpy()
|
| 95 |
|
| 96 |
# Output is (batch, 100)
|
| 97 |
+
secret = decoded[0] # Remove batch dimension
|
| 98 |
|
| 99 |
# Threshold to get binary values
|
| 100 |
bits = (secret > 0.5).astype(np.float32)
|
|
|
|
| 106 |
return self.decode(image)
|
| 107 |
|
| 108 |
def close(self):
|
| 109 |
+
"""Cleanup resources"""
|
| 110 |
+
# TF2 models don't need explicit session closing
|
| 111 |
+
pass
|
| 112 |
|
| 113 |
def __del__(self):
|
| 114 |
"""Cleanup on deletion"""
|