|
|
"""
|
|
|
Example usage of StegaStamp model
|
|
|
"""
|
|
|
from stegastamp_model import StegaStampModel
|
|
|
from PIL import Image
|
|
|
import torch
|
|
|
|
|
|
# Banner announcing the example run: a rule, the title, and another rule,
# each on its own line.
print("=" * 80, "StegaStamp Model Examples", "=" * 80, sep="\n")
|
|
|
|
|
|
|
|
|
# Step 1: load the model via StegaStampModel.from_pretrained. The path is
# relative to the current working directory — presumably the folder holding
# the saved model files; run the script from there.
MODEL_DIR = "."
print("\n1. Loading model...")
model = StegaStampModel.from_pretrained(MODEL_DIR)
|
|
|
|
|
|
|
|
|
# Step 2: obtain a 400x400 RGB cover image. Uses test.jpg when it can be read,
# otherwise falls back to a flat gray image.
print("\n2. Loading image...")
try:
    # The context manager releases the file handle; convert() forces a full
    # load into a new image and normalizes grayscale/RGBA/CMYK files to RGB.
    with Image.open("test.jpg") as source_image:
        image = source_image.convert("RGB")
    print(f" Loaded test.jpg ({image.size})")
except OSError:
    # FileNotFoundError and PIL's UnidentifiedImageError are both OSError
    # subclasses. A bare ``except:`` would also swallow KeyboardInterrupt and
    # SystemExit.
    image = Image.new('RGB', (400, 400), color=(128, 128, 128))
    print(" Created test image (400x400 gray)")

# NOTE(review): 400x400 is presumably the model's expected input size —
# confirm against StegaStampModel. resize() does not preserve aspect ratio.
if image.size != (400, 400):
    image = image.resize((400, 400))
    print(" Resized to 400x400")
|
|
|
|
|
|
|
|
|
print("\n" + "-"*80)
print("Example 1: Text Encoding")
print("-"*80)

# Round-trip a short string: embed it, save to PNG, reload from disk, decode.
text = "Hello!!"
print(f"Encoding text: '{text}'")

# Assumes encode_text returns a PIL image (it is saved via .save()).
encoded = model.encode_text(image, text)
encoded.save("example_text_encoded.png")
print("✓ Saved to: example_text_encoded.png")

# Decode the file as written to disk rather than the in-memory result. The
# context manager closes the file handle once decoding is done.
with Image.open("example_text_encoded.png") as encoded_loaded:
    decoded = model.decode_text(encoded_loaded)
print(f"Decoded text: '{decoded}'")

if decoded == text:
    print("✓✓ Perfect match!")
else:
    print(f"✗ Mismatch (expected '{text}')")
|
|
|
|
|
|
|
|
|
print("\n" + "-"*80)
print("Example 2: Binary Data Encoding")
print("-"*80)

# Alternating 0/1 pattern, 100 values long, as float32.
secret = torch.tensor([i % 2 for i in range(100)], dtype=torch.float32)
print("Encoding pattern: [0,1,0,1,...]")

encoded = model.encode(image, secret)
encoded.save("example_binary_encoded.png")
print("✓ Saved to: example_binary_encoded.png")

# Decode the file as written to disk. The context manager closes the handle
# afterwards.
with Image.open("example_binary_encoded.png") as encoded_loaded:
    decoded = model.decode(encoded_loaded)

# Normalize whatever decode() returns (NumPy array, list, or tensor) to a CPU
# float32 tensor so the element-wise comparison with `secret` works for each.
# NOTE(review): assumes decode() yields hard 0/1 bits rather than
# probabilities — confirm against StegaStampModel.decode.
decoded_bits = torch.as_tensor(decoded, dtype=torch.float32).cpu()
accuracy = (decoded_bits == secret).float().mean().item()
print(f"Decoded accuracy: {accuracy*100:.1f}%")

if accuracy > 0.95:
    print("✓✓ Excellent accuracy!")
else:
    print("✗ Accuracy below 95%")
|
|
|
|
|
|
|
|
|
print("\n" + "-"*80)
print("Example 3: Multiple Messages")
print("-"*80)

# Each message is round-tripped through its own PNG file.
messages = ["Test123", "PyTorch", "CS ML", "2024"]

for msg in messages:
    encoded = model.encode_text(image, msg)
    # Spaces are replaced so the message can be embedded in a file name.
    filename = f"example_{msg.replace(' ', '_')}.png"
    encoded.save(filename)

    # Decode the file as written to disk. The context manager closes the handle.
    with Image.open(filename) as encoded_loaded:
        decoded = model.decode_text(encoded_loaded)

    # Distinct markers so a mismatch is visible at a glance.
    status = "✓" if decoded == msg else "✗"
    print(f" {status} '{msg}' → '{decoded}'")
|
|
|
|
|
|
|
|
|
print("\n" + "-"*80)

# Presumably releases resources held by the model.
# NOTE(review): this is skipped if any example above raises — consider
# try/finally around the examples if close() frees something important.
model.close()

print("✓ Examples complete!")
print("="*80)