File size: 5,108 Bytes
6b430c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb3322
 
6b430c5
2bb3322
 
6b430c5
2bb3322
 
 
6b430c5
2bb3322
 
6b430c5
2bb3322
 
6b430c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb3322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b430c5
 
2bb3322
6b430c5
 
 
 
 
 
 
 
 
 
 
2bb3322
 
 
6b430c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""

TensorFlow Decoder using stegastamp_pretrained model

This model contains BOTH encoder and decoder in one checkpoint

"""
import tensorflow as tf
import numpy as np
from PIL import Image as PILImage


class TFStegaStampDecoderPretrained:
    """
    TensorFlow-based StegaStamp decoder using the pretrained model.

    This is the CORRECT decoder for extracting 100-bit secrets. It wraps the
    ``serving_default`` signature of the SavedModel and exposes a NumPy / PIL
    friendly :meth:`decode` method.
    """

    # (height, width) the network is fed; images of other sizes are resized.
    INPUT_SIZE = (400, 400)

    def __init__(self, model_path='saved_models/stegastamp_pretrained'):
        """
        Load the TensorFlow pretrained model (contains both encoder and decoder).

        Args:
            model_path: Path to the pretrained SavedModel directory.
        """
        print(f"Loading TensorFlow pretrained model from: {model_path}")

        # Load model with TF2 API. Keep a reference to the loaded object: the
        # signature function may not keep the model's variables alive on its
        # own, so letting ``loaded`` be garbage collected is risky.
        self._model = tf.saved_model.load(model_path)

        # Get the serving function
        self._infer_fn = self._model.signatures['serving_default']

        # structured_input_signature is (args, kwargs); index 1 holds the
        # keyword-argument TensorSpecs the signature is called with.
        self._input_specs = dict(self._infer_fn.structured_input_signature[1])
        self._input_names = list(self._input_specs)
        self._output_names = list(self._infer_fn.structured_outputs)

        print(f"  Input tensors: {self._input_names}")
        print(f"  Output tensors: {self._output_names}")
        print("TensorFlow decoder loaded successfully!")

    def decode(self, image):
        """
        Decode a 100-bit secret from an image.

        Args:
            image: numpy array of shape (H, W, 3), (B, H, W, 3), (3, H, W) or
                   (B, 3, H, W) with values in [0, 1],
                   OR a PIL Image with values in [0, 255] (any mode; it is
                   converted to RGB).
                   Images that are not 400x400 are resized (first image only).

        Returns:
            numpy float32 array of decoded bits (0.0 or 1.0) for the FIRST
            image of the batch, shape (100,) for this model.

        Raises:
            ValueError: if the array is not 3-D/4-D or has no 3-channel axis
                (axis 1 or the last axis after batching).
        """
        batch = self._prepare_batch(image)
        result = self._infer_fn(**self._build_feeds(batch))

        # Get decoded output - use the correct output name
        if 'decoded' in self._output_names:
            decoded = result['decoded'].numpy()
        else:
            # Fall back to the first output
            decoded = next(iter(result.values())).numpy()

        # Output is (batch, 100); only the first image's secret is returned.
        secret = decoded[0]

        # Threshold to get binary values
        return (secret > 0.5).astype(np.float32)

    @staticmethod
    def _float_to_uint8(array):
        """Map [0, 1] floats to uint8 with rounding and saturation.

        A bare ``(x * 255).astype(np.uint8)`` wraps values outside [0, 1]
        around and truncates (e.g. 254.99 -> 254); clip + rint avoids both.
        """
        return np.clip(np.rint(np.asarray(array) * 255.0), 0, 255).astype(np.uint8)

    @classmethod
    def _prepare_batch(cls, image):
        """Return *image* as a float32 (B, H, W, 3) batch of size INPUT_SIZE."""
        if hasattr(image, 'mode'):  # It's a PIL Image
            # Force RGB: 'RGBA', 'L', 'P', ... would otherwise yield a
            # 4-channel or 2-D array that cannot be treated as RGB.
            image = np.asarray(image.convert('RGB'), dtype=np.float32) / 255.0
        else:
            image = np.asarray(image)

        # Ensure batch dimension
        if image.ndim == 3:
            image = image[np.newaxis]
        if image.ndim != 4:
            raise ValueError(
                f"Expected a 3-D or 4-D image array, got shape {image.shape}")

        # Ensure channels-last (B, H, W, C); accept channels-first (B, C, H, W).
        if image.shape[-1] != 3:
            if image.shape[1] != 3:
                raise ValueError(
                    "Expected 3 colour channels in axis 1 or the last axis, "
                    f"got shape {image.shape}")
            image = np.transpose(image, (0, 2, 3, 1))

        # Resize to the network input size. Only the first image is kept,
        # since decode() returns the first image's secret anyway.
        height, width = cls.INPUT_SIZE
        if image.shape[1:3] != (height, width):
            first = cls._float_to_uint8(image[0])
            resized = PILImage.fromarray(first).resize((width, height))  # PIL takes (W, H)
            image = np.asarray(resized, dtype=np.float32)[np.newaxis] / 255.0

        # Ensure float32 (also guarantees a copy, so the caller's array is untouched)
        return image.astype(np.float32)

    def _build_feeds(self, batch):
        """Build the keyword arguments for the serving signature.

        The image batch goes to the ``image`` input (or the first input when
        none is named ``image``). Any other inputs, such as an encoder-side
        secret in a combined encoder/decoder checkpoint, are fed zeros,
        because TF2 signature functions expect a value for every input.
        """
        image_name = 'image' if 'image' in self._input_names else self._input_names[0]
        feeds = {image_name: tf.convert_to_tensor(batch)}
        batch_size = batch.shape[0]
        for name, spec in self._input_specs.items():
            if name == image_name:
                continue
            dims = spec.shape.as_list() if spec.shape.rank is not None else [None]
            # Unknown leading dim -> batch size; any other unknown dim -> 1.
            # NOTE(review): assumes extra inputs are batch-major - confirm.
            shape = [
                (batch_size if axis == 0 else 1) if dim is None else dim
                for axis, dim in enumerate(dims)
            ]
            feeds[name] = tf.zeros(shape, dtype=spec.dtype)
        return feeds

    def __call__(self, image):
        """Make the decoder callable; alias for :meth:`decode`."""
        return self.decode(image)

    def close(self):
        """Cleanup resources (no-op: TF2 models need no explicit session close)."""

    def __del__(self):
        """Cleanup on deletion."""
        try:
            self.close()
        except Exception:
            pass  # Ignore cleanup errors during interpreter shutdown


# Quick test
if __name__ == "__main__":
    print("\n" + "="*80)
    print("Testing TF Decoder with Pretrained Model")
    print("="*80 + "\n")

    decoder = None
    try:
        # Initialize decoder
        decoder = TFStegaStampDecoderPretrained()

        # Uniform mid-grey 400x400 RGB test image with values in [0, 1]
        test_image = np.full((400, 400, 3), 0.5, dtype=np.float32)

        print("\nTesting decoding...")
        # decode() always returns an array (errors surface as exceptions),
        # so there is no None result to check for.
        decoded = decoder.decode(test_image)

        print(f"✓ Decoded {len(decoded)} bits")
        print(f"  Sample values: {decoded[:20]}")
        print(f"  Mean: {decoded.mean():.3f}")
        print("\n✓ Test complete!")

    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shell scripts / CI notice the failure.
        raise SystemExit(1)
    finally:
        # Release the decoder even when decoding failed.
        if decoder is not None:
            decoder.close()