# Source snapshot metadata: commit 6b430c5, 2,630 bytes.
"""

Example usage of StegaStamp model

"""
from stegastamp_model import StegaStampModel
from PIL import Image
import torch

print("=" * 80)
print("StegaStamp Model Examples")
print("=" * 80)

# Load model
print("\n1. Loading model...")
model = StegaStampModel.from_pretrained(".")  # Load from current directory

# Load or create test image.
# Every example below works on a 400x400 RGB cover image: use test.jpg when it
# exists and is readable, otherwise fall back to a flat gray image.
print("\n2. Loading image...")
try:
    with Image.open("test.jpg") as src:
        # convert() copies the pixels into memory, so the file handle can be
        # closed here; forcing RGB makes grayscale/RGBA/palette inputs match
        # the mode of the synthetic fallback below.
        image = src.convert("RGB")
except OSError:
    # FileNotFoundError and Pillow's UnidentifiedImageError both subclass
    # OSError. A bare `except:` would also swallow KeyboardInterrupt.
    image = Image.new('RGB', (400, 400), color=(128, 128, 128))
    print("   Created test image (400x400 gray)")
else:
    print(f"   Loaded test.jpg ({image.size})")

if image.size != (400, 400):
    image = image.resize((400, 400))
    print("   Resized to 400x400")

# Example 1: Text encoding
# Embed a short string in the cover image, write it to a PNG, read the PNG
# back, and check that decode_text recovers the exact string.
print("\n" + "-" * 80)
print("Example 1: Text Encoding")
print("-" * 80)

text = "Hello!!"
print(f"Encoding text: '{text}'")

encoded = model.encode_text(image, text)
encoded.save("example_text_encoded.png")
print("✓ Saved to: example_text_encoded.png")

# IMPORTANT: Load the image back from disk so decoding runs on what was
# actually written to the PNG. The `with` block closes the file handle that
# Image.open otherwise keeps open for lazy loading.
with Image.open("example_text_encoded.png") as encoded_loaded:
    decoded = model.decode_text(encoded_loaded)
print(f"Decoded text: '{decoded}'")

if decoded == text:
    print("✓✓ Perfect match!")
else:
    print(f"✗ Mismatch (expected '{text}')")

# Example 2: Binary encoding
# Embed a raw 100-element bit pattern, round-trip it through a PNG on disk,
# and report the fraction of bits that decode correctly.
print("\n" + "-" * 80)
print("Example 2: Binary Data Encoding")
print("-" * 80)

# Create pattern: alternating 0s and 1s
secret = torch.tensor([i % 2 for i in range(100)], dtype=torch.float32)
print("Encoding pattern: [0,1,0,1,...]")

encoded = model.encode(image, secret)
encoded.save("example_binary_encoded.png")
print("✓ Saved to: example_binary_encoded.png")

# Load back from disk; `with` closes the file handle once decoding is done.
with Image.open("example_binary_encoded.png") as encoded_loaded:
    decoded = model.decode(encoded_loaded)

# Per-position match rate between decoded and original bits.
# NOTE(review): assumes `decoded` compares elementwise with a length-100
# NumPy array (e.g. an ndarray of 0/1 values) — confirm against
# StegaStampModel.decode.
accuracy = (decoded == secret.numpy()).mean()
print(f"Decoded accuracy: {accuracy*100:.1f}%")

if accuracy > 0.95:
    print("✓✓ Excellent accuracy!")
else:
    print("✗ Low accuracy (expected > 95%)")

# Example 3: Multiple messages
# Encode several strings into the same cover image, save each to its own
# PNG, read each back from disk, and report whether it round-trips.
print("\n" + "-" * 80)
print("Example 3: Multiple Messages")
print("-" * 80)

messages = ["Test123", "PyTorch", "CS ML", "2024"]

for msg in messages:
    encoded = model.encode_text(image, msg)
    # Spaces become underscores so the message can be embedded in a filename.
    filename = f"example_{msg.replace(' ', '_')}.png"
    encoded.save(filename)

    # Load back from disk; `with` closes the file handle each iteration
    # instead of leaking one per message.
    with Image.open(filename) as encoded_loaded:
        decoded = model.decode_text(encoded_loaded)

    status = "✓" if decoded == msg else "✗"
    print(f"  {status} '{msg}' → '{decoded}'")

# Cleanup
# NOTE(review): model.close() presumably releases resources held by the
# model — confirm in StegaStampModel.close.
print("\n" + "-" * 80)
model.close()
print("✓ Examples complete!")
print("=" * 80)