"""
TensorFlow decoder using the ``stegastamp_pretrained`` SavedModel.

The checkpoint contains BOTH the encoder and the decoder; this module only
uses it to extract (decode) 100-bit secrets from images.
"""

import traceback

import numpy as np
import tensorflow as tf
from PIL import Image as PILImage


class TFStegaStampDecoderPretrained:
    """
    TensorFlow-based StegaStamp decoder using the pretrained model.

    This is the CORRECT decoder for extracting 100-bit secrets.
    """

    # (height, width) every image is resized to before inference.
    INPUT_SIZE = (400, 400)

    def __init__(self, model_path='saved_models/stegastamp_pretrained'):
        """
        Load the TensorFlow pretrained model (contains both encoder and decoder).

        Args:
            model_path: Path to the pretrained SavedModel directory.
        """
        print(f"Loading TensorFlow pretrained model from: {model_path}")

        # Load with the TF2 API. Keep a reference to the loaded object so the
        # resources owned by it stay alive while the signature is in use.
        self._model = tf.saved_model.load(model_path)
        self._infer_fn = self._model.signatures['serving_default']

        # structured_input_signature is (args, kwargs); signatures use kwargs.
        self._input_specs = dict(self._infer_fn.structured_input_signature[1])
        self._input_names = list(self._input_specs.keys())
        self._output_names = list(self._infer_fn.structured_outputs.keys())

        print(f"  Input tensors: {self._input_names}")
        print(f"  Output tensors: {self._output_names}")
        print("TensorFlow decoder loaded successfully!")

    def decode(self, image):
        """
        Decode a 100-bit secret from an image.

        Args:
            image: numpy array of shape (H, W), (H, W, 3), (3, H, W),
                (B, H, W, 3) or (B, 3, H, W) with values in [0, 1]; OR a PIL
                Image of any mode (converted to RGB, scaled from [0, 255]).

        Returns:
            float32 numpy array of 0.0/1.0 bits for the FIRST image of the
            batch (the model output is expected to be (batch, 100), giving
            shape (100,)).

        Raises:
            ValueError: if the input rank or channel layout is unsupported,
                or an extra signature input has an unknown shape.
        """
        batch = self._prepare_batch(image)
        image_tensor = tf.convert_to_tensor(batch)

        # Use the input named 'image' if present, otherwise the first input.
        image_key = 'image' if 'image' in self._input_names else self._input_names[0]
        feeds = {image_key: image_tensor}

        # Signature functions require every declared input. A combined
        # encoder/decoder signature may declare inputs the decoder does not
        # use (presumably the secret); feed zeros for those.
        # NOTE(review): assumes such inputs do not affect the decoded output.
        for name in self._input_names:
            if name != image_key:
                feeds[name] = self._zeros_for(name, batch.shape[0])

        result = self._infer_fn(**feeds)

        # Use the output named 'decoded' if present, otherwise the first one.
        if 'decoded' in self._output_names:
            decoded = result['decoded'].numpy()
        else:
            decoded = next(iter(result.values())).numpy()

        secret = decoded[0]  # drop the batch dimension
        return (secret > 0.5).astype(np.float32)

    def _prepare_batch(self, image):
        """Return ``image`` as a float32 (B, 400, 400, 3) array in [0, 1]."""
        if isinstance(image, PILImage.Image):
            # convert('RGB') normalizes L / RGBA / P / ... modes, which would
            # otherwise produce 2-D or 4-channel arrays.
            image = np.asarray(image.convert('RGB'), dtype=np.float32) / 255.0
        else:
            image = np.asarray(image, dtype=np.float32)

        if image.ndim == 2:
            # Grayscale (H, W): replicate into 3 channels.
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)
        if image.ndim == 3:
            image = image[np.newaxis]  # add batch dimension
        if image.ndim != 4:
            raise ValueError(
                f"Expected an image of rank 2, 3 or 4, got shape {image.shape}")

        if image.shape[-1] != 3:
            if image.shape[1] != 3:
                raise ValueError(
                    f"Expected 3 color channels (channels-last or channels-first), "
                    f"got shape {image.shape}")
            # (B, C, H, W) -> (B, H, W, C)
            image = np.transpose(image, (0, 2, 3, 1))

        if image.shape[1:3] != self.INPUT_SIZE:
            # Resize every frame so the batch is preserved.
            image = np.stack([self._resize_frame(frame) for frame in image])

        return image.astype(np.float32, copy=False)

    def _resize_frame(self, frame):
        """Resize one (H, W, 3) float image in [0, 1] to INPUT_SIZE via PIL."""
        # Clip before quantizing: out-of-range floats would otherwise overflow
        # when cast to uint8. Round instead of truncating.
        pixels = np.rint(np.clip(frame, 0.0, 1.0) * 255.0).astype(np.uint8)
        height, width = self.INPUT_SIZE
        resized = PILImage.fromarray(pixels).resize((width, height))  # PIL takes (w, h)
        return np.asarray(resized, dtype=np.float32) / 255.0

    def _zeros_for(self, name, batch_size):
        """Build an all-zero tensor matching the spec of signature input ``name``."""
        spec = self._input_specs[name]
        if spec.shape.rank is None:
            raise ValueError(f"Signature input '{name}' has unknown rank")
        dims = spec.shape.as_list()
        if dims and dims[0] is None:
            dims[0] = batch_size  # dynamic batch dimension
        if any(d is None for d in dims):
            raise ValueError(
                f"Cannot build a placeholder for signature input '{name}' "
                f"with shape {spec.shape}")
        return tf.zeros(dims, dtype=spec.dtype)

    def __call__(self, image):
        """Make the decoder callable; same as :meth:`decode`."""
        return self.decode(image)

    def close(self):
        """Cleanup resources (no-op: TF2 models need no explicit session close)."""
        pass

    def __del__(self):
        """Best-effort cleanup on deletion."""
        try:
            self.close()
        except Exception:
            pass  # ignore cleanup errors during interpreter shutdown


# Quick test
if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("Testing TF Decoder with Pretrained Model")
    print("=" * 80 + "\n")

    try:
        decoder = TFStegaStampDecoderPretrained()

        # Mid-gray test image in [0, 1].
        test_image = np.full((400, 400, 3), 0.5, dtype=np.float32)

        print("\nTesting decoding...")
        decoded = decoder.decode(test_image)
        print(f"✓ Decoded {len(decoded)} bits")
        print(f"  Sample values: {decoded[:20]}")
        print(f"  Mean: {decoded.mean():.3f}")

        decoder.close()
        print("\n✓ Test complete!")
    except Exception as e:
        print(f"\n❌ Error: {e}")
        traceback.print_exc()