File size: 5,108 Bytes
6b430c5 2bb3322 6b430c5 2bb3322 6b430c5 2bb3322 6b430c5 2bb3322 6b430c5 2bb3322 6b430c5 2bb3322 6b430c5 2bb3322 6b430c5 2bb3322 6b430c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
"""
TensorFlow Decoder using stegastamp_pretrained model
This model contains BOTH encoder and decoder in one checkpoint
"""
import tensorflow as tf
import numpy as np
from PIL import Image as PILImage
class TFStegaStampDecoderPretrained:
    """
    TensorFlow-based StegaStamp decoder using the pretrained model.

    Wraps the ``serving_default`` signature of a TensorFlow SavedModel and
    exposes :meth:`decode`, which thresholds the model's decoded output at 0.5
    and returns the recovered secret as float32 0.0 / 1.0 bits.
    """

    # (height, width) fed to the network; other sizes are resized with PIL.
    _INPUT_HW = (400, 400)

    def __init__(self, model_path='saved_models/stegastamp_pretrained'):
        """
        Load the TensorFlow pretrained model (contains both encoder and decoder).

        Args:
            model_path: Path to the pretrained SavedModel directory.
        """
        print(f"Loading TensorFlow pretrained model from: {model_path}")
        # Load model with TF2 API.
        loaded = tf.saved_model.load(model_path)
        # Keep the loaded object alive as long as the decoder: the signature
        # function uses variables/resources owned by it, which may be
        # invalidated once the last reference to `loaded` is dropped.
        self._model = loaded
        self._infer_fn = loaded.signatures['serving_default']
        # structured_input_signature is (positional_specs, keyword_specs);
        # signature functions are called with keyword arguments only.
        self._input_specs = dict(self._infer_fn.structured_input_signature[1])
        self._input_names = list(self._input_specs)
        self._output_names = list(self._infer_fn.structured_outputs)
        print(f" Input tensors: {self._input_names}")
        print(f" Output tensors: {self._output_names}")
        print("TensorFlow decoder loaded successfully!")

    def decode(self, image):
        """
        Decode a 100-bit secret from an image.

        Args:
            image: numpy array (or array-like) of shape (H, W, 3) or
                (B, H, W, 3) -- channels-first (3, H, W) / (B, 3, H, W) is
                also accepted -- with values in [0, 1];
                OR a PIL Image with values in [0, 255] (converted to RGB).
                Inputs whose spatial size is not 400x400 are resized with PIL.

        Returns:
            numpy float32 array of decoded bits (0.0 or 1.0) for the FIRST
            image of the batch; shape (100,) for the pretrained model.

        Raises:
            ValueError: if the input cannot be interpreted as a non-empty
                batch of 3-channel images.
        """
        batch = self._to_batch(image)
        image_tensor = tf.convert_to_tensor(batch)
        result = self._infer_fn(**self._build_feeds(image_tensor, batch.shape[0]))
        # Prefer the 'decoded' output; otherwise fall back to the first output.
        key = 'decoded' if 'decoded' in self._output_names else next(iter(result))
        decoded = result[key].numpy()
        # Output is (batch, bits); only the first image's secret is returned.
        secret = decoded[0]
        # Threshold to get binary values.
        return (secret > 0.5).astype(np.float32)

    @classmethod
    def _to_batch(cls, image):
        """Convert *image* to a float32 (B, 400, 400, 3) array (see decode)."""
        if isinstance(image, PILImage.Image):
            # Force three channels: modes such as 'RGBA', 'L' or 'P' would
            # otherwise give 4 channels or a 2-D array and break the
            # layout handling below.
            array = np.asarray(image.convert('RGB')).astype(np.float32) / 255.0
        else:
            array = np.asarray(image, dtype=np.float32)

        if array.ndim == 3:
            array = array[np.newaxis]  # add the batch dimension
        if array.ndim != 4:
            raise ValueError(
                f"Expected an image of shape (H, W, 3) or (B, H, W, 3), got {array.shape}"
            )
        if array.shape[-1] != 3:
            # Treat the input as channels-first (B, C, H, W) and move channels last.
            array = np.transpose(array, (0, 2, 3, 1))
        if array.shape[-1] != 3:
            raise ValueError(f"Expected 3 colour channels, got shape {array.shape}")
        if array.shape[0] == 0:
            raise ValueError("Received an empty image batch")

        if array.shape[1:3] != cls._INPUT_HW:
            # Resize every image in the batch, not just the first one.
            array = np.stack([cls._resize_frame(frame) for frame in array])
        return np.ascontiguousarray(array, dtype=np.float32)

    @classmethod
    def _resize_frame(cls, frame):
        """Resize one (H, W, 3) float image in [0, 1] to the network input size."""
        # Clip and round before the uint8 cast: NumPy does not clamp
        # out-of-range float->uint8 casts, so values slightly above 1.0 would
        # otherwise overflow into dark pixels.
        pixels = np.clip(np.rint(frame * 255.0), 0, 255).astype(np.uint8)
        height, width = cls._INPUT_HW
        # PIL's resize takes (width, height).
        resized = PILImage.fromarray(pixels).resize((width, height))
        return np.asarray(resized).astype(np.float32) / 255.0

    def _build_feeds(self, image_tensor, batch_size):
        """Return the keyword arguments for the signature call, keyed by input name."""
        # Use the 'image' input if the signature has one, else the first input.
        image_name = 'image' if 'image' in self._input_names else self._input_names[0]
        feeds = {image_name: image_tensor}
        for name, spec in self._input_specs.items():
            if name in feeds or spec.shape.rank is None:
                # Unknown-rank inputs cannot be synthesized; TF reports them.
                continue
            # NOTE(review): TF2 signature functions raise if a declared input
            # is missing. Extra inputs (presumably the encoder's secret, since
            # this checkpoint bundles encoder and decoder) are fed zeros, with
            # unknown dimensions set to the batch size. Confirm the decoded
            # output does not depend on them.
            shape = [batch_size if dim is None else dim for dim in spec.shape.as_list()]
            feeds[name] = tf.zeros(shape, dtype=spec.dtype)
        return feeds

    def __call__(self, image):
        """Make the decoder callable; equivalent to :meth:`decode`."""
        return self.decode(image)

    def close(self):
        """Cleanup resources (no-op: TF2 models need no explicit session closing)."""
        pass

    def __del__(self):
        """Best-effort cleanup on deletion."""
        try:
            self.close()
        except Exception:
            # Never let cleanup errors escape a finalizer (e.g. at interpreter shutdown).
            pass
# Quick test
def _main():
    """Smoke-test the decoder on a uniform grey 400x400 image.

    Returns:
        Process exit code: 0 on success, 1 if loading or decoding raised.
    """
    import traceback

    print("\n" + "="*80)
    print("Testing TF Decoder with Pretrained Model")
    print("="*80 + "\n")
    decoder = None
    try:
        # Initialize decoder
        decoder = TFStegaStampDecoderPretrained()
        # Create test image: mid-grey, already 400x400 with values in [0, 1]
        test_image = np.ones((400, 400, 3), dtype=np.float32) * 0.5
        print("\nTesting decoding...")
        # decode() always returns an array (it raises on failure, never None).
        decoded = decoder.decode(test_image)
        print(f"✓ Decoded {len(decoded)} bits")
        print(f" Sample values: {decoded[:20]}")
        print(f" Mean: {decoded.mean():.3f}")
        print("\n✓ Test complete!")
        return 0
    except Exception as e:
        # Top-level boundary of the script: report and signal failure.
        print(f"\n❌ Error: {e}")
        traceback.print_exc()
        return 1
    finally:
        if decoder is not None:
            decoder.close()


if __name__ == "__main__":
    raise SystemExit(_main())