# stegastamp / decoder.py
# Uploaded by KingTechnician: "Upload folder using huggingface_hub" (commit 2bb3322, verified)
"""
TensorFlow Decoder using stegastamp_pretrained model
This model contains BOTH encoder and decoder in one checkpoint
"""
import tensorflow as tf
import numpy as np
from PIL import Image as PILImage
class TFStegaStampDecoderPretrained:
    """
    TensorFlow-based StegaStamp decoder using the pretrained model.

    Wraps the ``serving_default`` signature of the ``stegastamp_pretrained``
    SavedModel (a single checkpoint holding both encoder and decoder) and
    exposes :meth:`decode` / ``__call__`` to extract the hidden secret bits.
    """

    # Spatial size (height and width) that images are resized to before
    # inference.
    _IMAGE_SIZE = 400

    def __init__(self, model_path='saved_models/stegastamp_pretrained'):
        """
        Load the TensorFlow pretrained model (contains both encoder and decoder).

        Args:
            model_path: Path to the pretrained SavedModel directory.
        """
        print(f"Loading TensorFlow pretrained model from: {model_path}")
        # Load model with TF2 API. Keep the loaded object alive on the
        # instance: signature functions may reference its variables only
        # weakly, so letting it be garbage-collected when __init__ returns
        # can make later inference calls fail.
        self._loaded = tf.saved_model.load(model_path)
        self._infer_fn = self._loaded.signatures['serving_default']
        # structured_input_signature is (positional_specs, keyword_specs);
        # signature functions are called with keyword arguments only.
        self._input_specs = dict(self._infer_fn.structured_input_signature[1])
        output_signature = self._infer_fn.structured_outputs
        self._input_names = list(self._input_specs.keys())
        self._output_names = list(output_signature.keys())
        print(f" Input tensors: {self._input_names}")
        print(f" Output tensors: {self._output_names}")
        print("TensorFlow decoder loaded successfully!")

    def decode(self, image):
        """
        Decode the secret bits (100 for the pretrained model) from an image.

        Args:
            image: numpy array of shape (H, W, 3), (B, H, W, 3) or
                (B, 3, H, W) with values in [0, 1], OR a PIL Image with values
                in [0, 255] (any PIL mode; it is converted to RGB first).

        Returns:
            float32 numpy array of shape (num_bits,) holding 0.0/1.0 values,
            for the FIRST image of the batch only.

        Raises:
            ValueError: if the input cannot be interpreted as RGB image(s).
        """
        batch = self._to_batch(image)
        result = self._infer_fn(**self._build_feeds(batch))
        # Prefer the output named 'decoded'; otherwise fall back to the first.
        if 'decoded' in self._output_names:
            decoded = result['decoded'].numpy()
        else:
            decoded = list(result.values())[0].numpy()
        # Output is (batch, num_bits); keep only the first image's secret.
        secret = decoded[0]
        # Threshold to get binary values.
        return (secret > 0.5).astype(np.float32)

    @classmethod
    def _to_batch(cls, image):
        """Normalize ``image`` to a float32 (B, H, W, 3) array sized for the model."""
        if hasattr(image, 'mode'):  # It's a PIL Image
            # convert('RGB') so grayscale / RGBA / palette images also yield
            # exactly three channels.
            batch = np.asarray(image.convert('RGB'), dtype=np.float32) / 255.0
        else:
            batch = np.asarray(image, dtype=np.float32)

        # Ensure batch dimension.
        if batch.ndim == 3:
            batch = batch[np.newaxis]
        if batch.ndim != 4:
            raise ValueError(
                f"Expected image of shape (H, W, 3) or (B, H, W, 3), got {batch.shape}"
            )

        # Ensure (B, H, W, C) layout; accept channel-first (B, C, H, W) input.
        if batch.shape[-1] != 3:
            if batch.shape[1] != 3:
                raise ValueError(
                    f"Expected 3 colour channels (last or second axis), got {batch.shape}"
                )
            batch = np.transpose(batch, (0, 2, 3, 1))

        size = cls._IMAGE_SIZE
        if batch.shape[1:3] != (size, size):
            # Only the first image's secret is returned, so only it is resized.
            # Round (not truncate) and clip, so the float -> uint8 conversion
            # neither drops a grey level to float error nor wraps out-of-range
            # values around.
            first = np.clip(np.rint(batch[0] * 255.0), 0, 255).astype(np.uint8)
            resized = PILImage.fromarray(first).resize((size, size))
            batch = (np.asarray(resized, dtype=np.float32) / 255.0)[np.newaxis]
        return batch

    def _build_feeds(self, batch):
        """Map every signature input name to a tensor for one inference call."""
        image_name = 'image' if 'image' in self._input_names else self._input_names[0]
        feeds = {image_name: tf.convert_to_tensor(batch)}
        for name in self._input_names:
            if name in feeds:
                continue
            # Signature functions require every declared input. If the
            # combined encoder/decoder signature also declares another input
            # (e.g. the secret to embed), feed zeros. Presumably it only
            # influences encoder outputs, which are not read here; TODO
            # confirm against the model.
            spec = self._input_specs[name]
            shape = [
                (batch.shape[0] if axis == 0 else 1) if dim is None else dim
                for axis, dim in enumerate(spec.shape.as_list())
            ]
            feeds[name] = tf.zeros(shape, dtype=spec.dtype)
        return feeds

    def __call__(self, image):
        """Make the decoder callable (alias for :meth:`decode`)."""
        return self.decode(image)

    def close(self):
        """Cleanup resources (no-op: TF2 models need no explicit session closing)."""

    def __del__(self):
        """Best-effort cleanup on deletion."""
        try:
            self.close()
        except Exception:
            pass  # Ignore cleanup errors during interpreter shutdown
# Quick test: smoke-test model loading and decoding end to end.
if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("Testing TF Decoder with Pretrained Model")
    print("=" * 80 + "\n")
    try:
        # Initialize decoder
        decoder = TFStegaStampDecoderPretrained()
        try:
            # Uniform mid-grey image in [0, 1]. It is not an encoded image, so
            # the decoded bits are meaningless; this only checks the pipeline.
            test_image = np.ones((400, 400, 3), dtype=np.float32) * 0.5
            print("\nTesting decoding...")
            decoded = decoder.decode(test_image)
            print(f"✓ Decoded {len(decoded)} bits")
            print(f" Sample values: {decoded[:20]}")
            print(f" Mean: {decoded.mean():.3f}")
        finally:
            decoder.close()
        print("\n✓ Test complete!")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells/CI notice the failure.
        raise SystemExit(1)