"""
Example usage of StegaStamp model.

Runs three round trips with ``StegaStampModel``:

1. Text encoding / decoding of a single string.
2. Binary encoding / decoding of an alternating 0/1 bit tensor.
3. Text encoding / decoding of several short messages.

Every encoded image is saved as PNG and re-opened from disk before it is
decoded, so the PNG write/read step is part of each round trip.
"""

from PIL import Image
import torch

from stegastamp_model import StegaStampModel

# Size every input image is resized to before encoding (the script has always
# used 400x400).
IMAGE_SIZE = (400, 400)


def _print_header(title):
    """Print *title* framed by two 80-character dashed rules."""
    print("\n" + "-" * 80)
    print(title)
    print("-" * 80)


def _load_input_image(path="test.jpg", size=IMAGE_SIZE):
    """Return *path* as an RGB image of *size*, or a gray RGB image if it can't be read.

    Args:
        path: Image file to load.
        size: ``(width, height)`` the returned image is resized to.

    Returns:
        A ``PIL.Image.Image`` in RGB mode with ``image.size == size``.
    """
    try:
        with Image.open(path) as img:
            # convert() materializes the pixels into a new image, so the file
            # handle can be closed when the ``with`` block exits. RGB matches
            # the mode of the fallback image below.
            image = img.convert("RGB")
    except OSError:
        # FileNotFoundError and PIL's UnidentifiedImageError are both
        # OSError subclasses; anything else should propagate.
        width, height = size
        image = Image.new('RGB', size, color=(128, 128, 128))
        print(f" Created test image ({width}x{height} gray)")
    else:
        print(f" Loaded {path} ({image.size})")

    if image.size != size:
        image = image.resize(size)
        print(f" Resized to {size[0]}x{size[1]}")
    return image


def _decode_from_disk(filename, decode):
    """Re-open *filename* and return ``decode(image)``, closing the file afterwards."""
    with Image.open(filename) as reloaded:
        return decode(reloaded)


def _example_text(model, image):
    """Example 1: hide a short string in *image* and read it back from disk."""
    _print_header("Example 1: Text Encoding")
    text = "Hello!!"
    print(f"Encoding text: '{text}'")
    filename = "example_text_encoded.png"
    model.encode_text(image, text).save(filename)
    print(f"✓ Saved to: {filename}")

    # IMPORTANT: decode the copy loaded back from disk, not the in-memory
    # image, so the PNG write/read is part of the round trip.
    decoded = _decode_from_disk(filename, model.decode_text)
    print(f"Decoded text: '{decoded}'")
    if decoded == text:
        print("✓✓ Perfect match!")
    else:
        print(f"✗ Mismatch (expected '{text}')")


def _example_binary(model, image, num_bits=100):
    """Example 2: hide an alternating 0/1 bit pattern and report bit accuracy.

    Args:
        model: Loaded ``StegaStampModel``.
        image: RGB carrier image.
        num_bits: Length of the secret tensor. The default of 100 is what the
            script always used; presumably this is the model's secret length
            — TODO confirm before changing it.
    """
    _print_header("Example 2: Binary Data Encoding")
    # Create pattern: alternating 0s and 1s, as float32 like the original.
    secret = (torch.arange(num_bits) % 2).to(torch.float32)
    print("Encoding pattern: [0,1,0,1,...]")
    filename = "example_binary_encoded.png"
    model.encode(image, secret).save(filename)
    print(f"✓ Saved to: {filename}")

    decoded = _decode_from_disk(filename, model.decode)
    # NOTE(review): assumes decode() returns something comparable element-wise
    # with a NumPy array (the original script relied on this too) — confirm.
    accuracy = (decoded == secret.numpy()).mean()
    print(f"Decoded accuracy: {accuracy*100:.1f}%")
    if accuracy > 0.95:
        print("✓✓ Excellent accuracy!")
    else:
        print("✗ Accuracy at or below 95%")


def _example_multiple(model, image):
    """Example 3: round-trip several short text messages, one PNG per message."""
    _print_header("Example 3: Multiple Messages")
    messages = ["Test123", "PyTorch", "CS ML", "2024"]
    for msg in messages:
        filename = f"example_{msg.replace(' ', '_')}.png"
        model.encode_text(image, msg).save(filename)
        # Load back from disk
        decoded = _decode_from_disk(filename, model.decode_text)
        status = "✓" if decoded == msg else "✗"
        print(f" {status} '{msg}' → '{decoded}'")


def main():
    """Run all StegaStamp examples; the model is closed even if one fails."""
    print("=" * 80)
    print("StegaStamp Model Examples")
    print("=" * 80)

    print("\n1. Loading model...")
    model = StegaStampModel.from_pretrained(".")  # Load from current directory
    try:
        print("\n2. Loading image...")
        image = _load_input_image()

        _example_text(model, image)
        _example_binary(model, image)
        _example_multiple(model, image)

        # Cleanup
        print("\n" + "-" * 80)
    finally:
        # Release the model even when an example raises.
        model.close()

    print("✓ Examples complete!")
    print("=" * 80)


if __name__ == "__main__":
    main()