# stegastamp / example_usage.py
# Committed by KingTechnician: "Upload folder using huggingface_hub" (6b430c5, verified)
"""
Example usage of StegaStamp model
"""
from stegastamp_model import StegaStampModel
from PIL import Image
import torch
# Title banner: a rule, the title, and a closing rule on three lines.
print("=" * 80, "StegaStamp Model Examples", "=" * 80, sep="\n")
# Step 1: instantiate the pretrained model. "." is resolved against the
# current working directory (presumably the folder holding the model files).
print("\n1. Loading model...")
model = StegaStampModel.from_pretrained(".")
# Step 2: obtain a 400x400 RGB cover image -- test.jpg if it can be read,
# otherwise a flat gray placeholder.
print("\n2. Loading image...")
try:
    # `with` releases the file handle once the pixels are decoded; convert()
    # forces that decode and normalises the mode (e.g. grayscale, RGBA, CMYK)
    # to match the RGB placeholder below.
    # NOTE(review): assumes the model expects 3-channel RGB input -- confirm
    # against StegaStampModel.encode / encode_text.
    with Image.open("test.jpg") as src:
        image = src.convert("RGB")
except OSError:
    # A missing file (FileNotFoundError), an unrecognised image
    # (PIL.UnidentifiedImageError) and truncated data all raise OSError or a
    # subclass of it. Unlike a bare `except:`, this does not swallow
    # KeyboardInterrupt or SystemExit.
    image = Image.new('RGB', (400, 400), color=(128, 128, 128))
    print(" Created test image (400x400 gray)")
else:
    print(f" Loaded test.jpg ({image.size})")
# The examples below all work on a 400x400 cover image.
if image.size != (400, 400):
    image = image.resize((400, 400))
    print(" Resized to 400x400")
# Example 1: Text encoding -- hide a short string in the cover image, save it,
# reload it from disk and check that the decoded text matches.
print("\n" + "-" * 80)
print("Example 1: Text Encoding")
print("-" * 80)
text = "Hello!!"
print(f"Encoding text: '{text}'")
encoded = model.encode_text(image, text)
encoded.save("example_text_encoded.png")
print("✓ Saved to: example_text_encoded.png")
# IMPORTANT: decode from the saved file rather than the in-memory result, so
# the PNG round trip (presumably including any quantisation on save) is
# exercised. `with` closes the file once decoding is done.
with Image.open("example_text_encoded.png") as encoded_loaded:
    decoded = model.decode_text(encoded_loaded)
print(f"Decoded text: '{decoded}'")
if decoded == text:
    print("✓✓ Perfect match!")
else:
    print(f"✗ Mismatch (expected '{text}')")
# Example 2: Binary encoding -- hide a 100-value 0/1 pattern in the cover
# image, reload it from disk and report the fraction of values recovered.
print("\n" + "-" * 80)
print("Example 2: Binary Data Encoding")
print("-" * 80)
# Alternating pattern 0,1,0,1,... of length 100 as a float32 tensor.
secret = torch.arange(100, dtype=torch.float32) % 2
print("Encoding pattern: [0,1,0,1,...]")
encoded = model.encode(image, secret)
encoded.save("example_binary_encoded.png")
print("✓ Saved to: example_binary_encoded.png")
# Decode from the saved PNG rather than the in-memory result so the on-disk
# round trip is exercised; `with` closes the file afterwards.
with Image.open("example_binary_encoded.png") as encoded_loaded:
    decoded = model.decode(encoded_loaded)
# NOTE(review): assumes decode() returns a NumPy array of 0/1 values with the
# same length as `secret` -- confirm against StegaStampModel.decode.
accuracy = (decoded == secret.numpy()).mean()
print(f"Decoded accuracy: {accuracy*100:.1f}%")
if accuracy > 0.95:
    print("✓✓ Excellent accuracy!")
else:
    # Report the failure case too, mirroring Example 1's mismatch message.
    print("✗ Accuracy at or below 95%")
# Example 3: Multiple messages -- run the text round trip for several strings,
# one output PNG per message, and print a pass/fail line for each.
print("\n" + "-" * 80)
print("Example 3: Multiple Messages")
print("-" * 80)
# NOTE(review): every message here (and in Example 1) is at most 7 characters,
# presumably to stay within the model's text capacity -- confirm before adding
# longer ones.
messages = ["Test123", "PyTorch", "CS ML", "2024"]
for msg in messages:
    encoded = model.encode_text(image, msg)
    # Spaces become underscores so the message is usable in a file name.
    filename = f"example_{msg.replace(' ', '_')}.png"
    encoded.save(filename)
    # Decode from disk so the PNG round trip is part of the check; `with`
    # closes each file instead of leaving one handle open per message.
    with Image.open(filename) as encoded_loaded:
        decoded = model.decode_text(encoded_loaded)
    status = "✓" if decoded == msg else "✗"
    print(f" {status} '{msg}' → '{decoded}'")
# Cleanup: release the model via StegaStampModel.close() now that all examples
# have run, then print the closing banner.
print("\n" + "-" * 80)
model.close()
print("✓ Examples complete!")
print("=" * 80)