|
|
"""
|
|
|
TensorFlow Decoder using stegastamp_pretrained model
|
|
|
This model contains BOTH encoder and decoder in one checkpoint
|
|
|
"""
|
|
|
import tensorflow as tf
|
|
|
import numpy as np
|
|
|
from PIL import Image as PILImage
|
|
|
|
|
|
|
|
|
class TFStegaStampDecoderPretrained:
    """
    TensorFlow-based StegaStamp decoder using the pretrained model.

    This is the CORRECT decoder for extracting 100-bit secrets. The SavedModel
    is loaded once and queried through its ``serving_default`` signature;
    ``decode`` normalises the input to a float32 NHWC batch of 400x400 RGB
    images before running it and thresholds the first decoded vector.
    """

    # Side length (pixels) every input image is resized to before inference.
    _INPUT_SIZE = 400

    def __init__(self, model_path='saved_models/stegastamp_pretrained'):
        """
        Load the TensorFlow pretrained model (contains both encoder and decoder).

        Args:
            model_path: Path to the pretrained SavedModel directory.
        """
        print(f"Loading TensorFlow pretrained model from: {model_path}")

        loaded = tf.saved_model.load(model_path)
        # Keep the loaded object alive alongside its signature: functions
        # obtained from ``.signatures`` may rely on it to keep their
        # variables from being garbage-collected.
        self._loaded = loaded
        self._infer_fn = loaded.signatures['serving_default']

        # structured_input_signature is an (args, kwargs) pair; signature
        # functions take keyword arguments only, hence index 1.
        self._input_signature = self._infer_fn.structured_input_signature[1]
        output_signature = self._infer_fn.structured_outputs

        self._input_names = list(self._input_signature.keys())
        self._output_names = list(output_signature.keys())

        print(f" Input tensors: {self._input_names}")
        print(f" Output tensors: {self._output_names}")
        print("TensorFlow decoder loaded successfully!")

    def decode(self, image):
        """
        Decode a 100-bit secret from an image.

        Args:
            image: numpy array of shape (H, W, 3) or (B, H, W, 3), channel-first
                (3, H, W) / (B, 3, H, W), or grayscale (H, W);
                OR a PIL Image (any mode; it is converted to RGB).
                Values should be in [0, 1] range for numpy, [0, 255] for PIL.
                Images that are not 400x400 are resized.

        Returns:
            numpy float32 array of shape (100,) containing decoded bits (0 or 1)
            for the FIRST image of the batch.

        Raises:
            ValueError: if the input cannot be interpreted as RGB image(s).
        """
        batch = self._prepare_batch(image, self._INPUT_SIZE)
        image_tensor = tf.convert_to_tensor(batch)

        result = self._infer_fn(**self._build_feed(image_tensor, batch.shape[0]))

        if 'decoded' in self._output_names:
            decoded = result['decoded'].numpy()
        else:
            # No output called 'decoded': fall back to the first output the
            # signature reports (a best guess for differently-named exports).
            decoded = next(iter(result.values())).numpy()

        # Only the first image's secret is returned, even for batched input.
        secret = decoded[0]
        # NOTE(review): assumes the output is a per-bit score in [0, 1] so that
        # 0.5 is the decision boundary — confirm against the exported graph.
        return (secret > 0.5).astype(np.float32)

    def _build_feed(self, image_tensor, batch_size):
        """Map signature input names to the tensors passed to ``_infer_fn``."""
        image_name = 'image' if 'image' in self._input_names else self._input_names[0]
        feed = {image_name: image_tensor}

        # A combined encoder+decoder signature may declare inputs the decoder
        # does not use (e.g. the encoder's secret). Signature functions expect
        # every declared input as a keyword argument, so feed zeros for those.
        # NOTE(review): assumes the decoded output does not depend on them.
        for name in self._input_names:
            if name == image_name:
                continue
            spec = self._input_signature.get(name)
            if spec is None or spec.shape.rank is None:
                continue  # Shape unknown; let TensorFlow report the missing input.
            shape = [
                (batch_size if index == 0 else 1) if dim is None else dim
                for index, dim in enumerate(spec.shape.as_list())
            ]
            feed[name] = tf.zeros(shape, dtype=spec.dtype)
        return feed

    @classmethod
    def _prepare_batch(cls, image, size=400):
        """
        Normalise ``image`` to a float32 NHWC batch of ``size`` x ``size`` RGB images.

        PIL images (detected by a ``mode`` attribute) are converted to RGB and
        scaled from [0, 255] to [0, 1]; array input is taken as already in [0, 1].

        Raises:
            ValueError: if the array rank or channel layout is unsupported.
        """
        if hasattr(image, 'mode'):
            # PIL image: force three channels ('L', 'RGBA', 'P', ... -> 'RGB').
            image = np.asarray(image.convert('RGB'), dtype=np.float32) / 255.0
        else:
            image = np.asarray(image, dtype=np.float32)

        if image.ndim == 2:
            # Grayscale (H, W): replicate into three identical channels.
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)
        if image.ndim == 3:
            image = image[np.newaxis]  # add the batch dimension
        if image.ndim != 4:
            raise ValueError(
                f"Expected an image array with 2 to 4 dimensions, got shape {image.shape}"
            )

        if image.shape[-1] != 3:
            # Only reinterpret as channel-first when axis 1 really holds 3 channels.
            if image.shape[1] != 3:
                raise ValueError(
                    f"Expected 3 colour channels (NHWC or NCHW), got shape {image.shape}"
                )
            image = np.transpose(image, (0, 2, 3, 1))  # NCHW -> NHWC

        if image.shape[1:3] != (size, size):
            image = np.stack([cls._resize_frame(frame, size) for frame in image])

        return np.ascontiguousarray(image, dtype=np.float32)

    @staticmethod
    def _to_uint8(frame):
        """Quantise a [0, 1] float image to uint8, saturating out-of-range values."""
        # Clip before casting: float -> uint8 casts of values outside [0, 255]
        # wrap around (or are undefined) instead of saturating.
        return np.clip(np.rint(np.asarray(frame) * 255.0), 0, 255).astype(np.uint8)

    @classmethod
    def _resize_frame(cls, frame, size):
        """Resize one HWC [0, 1] float image to ``size`` x ``size`` using PIL."""
        resized = PILImage.fromarray(cls._to_uint8(frame)).resize((size, size))
        return np.asarray(resized, dtype=np.float32) / 255.0

    def __call__(self, image):
        """Make the decoder callable; equivalent to ``decode(image)``."""
        return self.decode(image)

    def close(self):
        """Cleanup resources (nothing is released explicitly at present)."""
        pass

    def __del__(self):
        """Cleanup on deletion; best-effort, never lets an error escape."""
        try:
            self.close()
        except Exception:
            pass
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the model and decode a blank image. Exits with a
    # non-zero status on failure so the error is visible to the caller.
    import sys
    import traceback

    print("\n" + "="*80)
    print("Testing TF Decoder with Pretrained Model")
    print("="*80 + "\n")

    try:
        decoder = TFStegaStampDecoderPretrained()

        # Uniform mid-grey frame: presumably carries no embedded secret, so the
        # decoded bits are arbitrary — this only exercises the load/decode path.
        test_image = np.full((400, 400, 3), 0.5, dtype=np.float32)

        print("\nTesting decoding...")
        try:
            decoded = decoder.decode(test_image)
        finally:
            decoder.close()

        # decode() always returns an array or raises, so no None check is needed.
        print(f"✓ Decoded {len(decoded)} bits")
        print(f" Sample values: {decoded[:20]}")
        print(f" Mean: {decoded.mean():.3f}")
        print("\n✓ Test complete!")

    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        sys.exit(1)