File size: 13,837 Bytes
fc6dcab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import streamlit as st
import os
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM
import spacy
import subprocess
import sys
import nltk
from nltk.tokenize import word_tokenize
from utils_final import extract_entities_and_pos, whole_context_process_sentence

# Download NLTK data if not available
def setup_nltk():
    """Download the NLTK resources used by this app, best-effort.

    Each resource is requested independently, so a failure on one does not
    stop the others from being attempted. Problems are reported on stdout
    instead of being silently discarded, and nothing is raised, so the app
    can still start when a download is not possible.
    """
    resources = ('punkt_tab', 'averaged_perceptron_tagger_eng', 'wordnet', 'omw-1.4')
    for resource in resources:
        try:
            downloaded = nltk.download(resource, quiet=True)
        except Exception as exc:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt / SystemExit still propagate.
            print(f"NLTK resource '{resource}' could not be downloaded: {exc}")
        else:
            # nltk.download reports failure through its return value.
            if not downloaded:
                print(f"NLTK resource '{resource}' was not downloaded (it may already be installed)")

# Runs at import time: fetch the NLTK data before the rest of the module is set up
setup_nltk()

# Set environment
# NOTE(review): cache_dir is a hard-coded absolute path and is not referenced
# anywhere else in this file - confirm whether it is still needed.
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Load spaCy model - download if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # The first load failed with OSError: install the model package with the
    # interpreter that is running this app, then retry the load once.
    print("Downloading spaCy model...")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")

# Define apply_replacements function (from Safeseal_gen_final.py)
def apply_replacements(sentence, replacements):
    """
    Return `sentence` with the requested token substitutions applied.

    The sentence is tokenized with the module-level spaCy pipeline and every
    token is kept together with its trailing whitespace, so joining the pieces
    back together reproduces the original spacing and punctuation.

    `replacements` is an iterable of (position, target, replacement) triples.
    A triple is applied only when `position` is below the number of tokens and
    the token at that position, ignoring surrounding whitespace, equals
    `target`; otherwise it is skipped.
    """
    pieces = [tok.text_with_ws for tok in nlp(sentence)]
    piece_count = len(pieces)

    for index, expected, substitute in replacements:
        # Position beyond the token list: nothing to replace
        if index >= piece_count:
            continue

        current = pieces[index]
        # The token found at this position is not the expected word
        if current.strip() != expected:
            continue

        # Carry over the trailing space of the token being replaced, if it had one
        pieces[index] = (substitute + " ") if current.endswith(" ") else substitute

    return "".join(pieces)

# Load the model once and reuse it across reruns via Streamlit's resource cache
@st.cache_resource
def load_model():
    """Build the RoBERTa tokenizer and masked-LM model used by the app.

    Wrapped in ``st.cache_resource`` so the objects are created once and
    reused instead of being reloaded on every run.

    Returns:
        A ``(tokenizer, lm_model)`` tuple. The model is switched to eval mode
        and moved to the GPU when CUDA is available.
    """
    print("Loading model...")

    checkpoint = 'roberta-base'
    length_limit = 512

    tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
    lm_model = RobertaForMaskedLM.from_pretrained(checkpoint, attn_implementation="eager")

    # Pin the sequence-length limits on the tokenizer ...
    tokenizer.model_max_length = length_limit
    tokenizer.max_len = length_limit

    # ... and on the model config, when the config exposes that field
    if hasattr(lm_model.config, 'max_position_embeddings'):
        lm_model.config.max_position_embeddings = length_limit

    lm_model.eval()

    if not torch.cuda.is_available():
        print("Model loaded on CPU")
        return tokenizer, lm_model

    lm_model = lm_model.cuda()
    print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
    return tokenizer, lm_model

# Sampling details of the most recent process_text_wrapper call
sampling_results = []

def process_text_wrapper(text, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, batch_size=32, max_length=512, similarity_context_mode='whole'):
    """
    Watermark `text` line by line and report which words were changed.

    Every non-blank line is passed, stripped, to
    `whole_context_process_sentence` together with the full `text`; the
    replacements it returns are applied with `apply_replacements`. Blank lines
    and lines without replacements are copied through unchanged. All other
    arguments are forwarded as-is to `whole_context_process_sentence`.

    Returns:
        A 5-tuple ``(watermarked_text, total_randomized_words, total_words,
        changed_words, sampling_results)`` where `changed_words` is a list of
        ``(original, replacement, position)`` for replacements whose
        replacement differs from the original, `total_randomized_words` is
        the length of that list, `total_words` is the NLTK token count of
        `text`, and `sampling_results` concatenates the per-line sampling
        results.
    """
    global sampling_results
    # Accumulate into a local list so that a concurrent call rebinding the
    # module-level name cannot mix its results into this call's output. The
    # global is still pointed at this list, as before.
    collected_results = []
    sampling_results = collected_results

    final_text = []
    total_randomized_words = 0
    total_words = len(word_tokenize(text))

    # Track changed words and their positions
    changed_words = []  # List of (original, replacement, position)

    for line in text.splitlines(keepends=True):
        stripped = line.strip()
        if not stripped:
            final_text.append(line)
            continue

        replacements, sampling_results_line = whole_context_process_sentence(
            text,
            stripped,
            tokenizer, lm_model, Top_K, threshold,
            secret_key, m, c, h, alpha, "output",
            batch_size=batch_size, max_length=max_length, similarity_context_mode=similarity_context_mode
        )
        collected_results.extend(sampling_results_line)

        if not replacements:
            final_text.append(line)
            continue

        # The replacements were computed for the stripped line, so apply them
        # to the stripped line as well: leading whitespace would otherwise
        # become an extra token and shift every position by one, silently
        # dropping the replacements. The surrounding whitespace (including
        # the line ending) is re-attached afterwards.
        # NOTE(review): assumes positions index the tokens of the stripped
        # line as seen by apply_replacements - confirm against utils_final.
        leading = line[:len(line) - len(line.lstrip())]
        trailing = line[len(line.rstrip()):]
        final_text.append(leading + apply_replacements(stripped, replacements) + trailing)

        # Track ONLY actual changes (where original != replacement)
        for position, original, replacement in replacements:
            if original != replacement:
                changed_words.append((original, replacement, position))
                total_randomized_words += 1

    return "".join(final_text), total_randomized_words, total_words, changed_words, collected_results

def create_html_with_highlights(original_text, watermarked_text, changed_words_info, sampling_results):
    """
    Render `watermarked_text` as an HTML fragment with changed words highlighted.

    The text is tokenized with the module-level spaCy pipeline. A token is
    wrapped in a ``<mark>`` element (with an "Original -> New" tooltip) when,
    after stripping surrounding ``.,!?;:`` and lower-casing, it equals a
    replacement word from `changed_words_info` that differs from its original
    ignoring case. Each distinct replacement word is highlighted only the
    first time it occurs.

    All text taken from the inputs is HTML-escaped, because the caller renders
    the fragment with ``unsafe_allow_html=True``.

    Args:
        original_text: Not used; kept for interface compatibility.
        watermarked_text: The text to render.
        changed_words_info: Iterable of ``(original, replacement, position)``.
        sampling_results: Not used; kept for interface compatibility.

    Returns:
        The inner HTML as a string, without any wrapping container.
    """
    import html  # local import keeps this block self-contained

    # Lower-cased replacement word -> original word, for real changes only
    replacement_to_original = {}
    for original, replacement, _ in changed_words_info:
        if original.lower() != replacement.lower():
            replacement_to_original[replacement.lower()] = original

    result_html = []
    words_highlighted = set()  # replacement words that already got a highlight

    for token in nlp(watermarked_text):
        text = token.text_with_ws
        text_clean = token.text.strip('.,!?;:')
        text_lower = text_clean.lower()

        original_word = replacement_to_original.get(text_lower)
        if original_word is None or text_lower in words_highlighted:
            result_html.append(html.escape(text, quote=False))
            continue

        # Split the token around the cleaned word so punctuation stripped from
        # either side, and the trailing whitespace, stay outside the <mark>.
        start = text.find(text_clean)
        before = text[:start]
        after = text[start + len(text_clean):]

        # quote=True also escapes the single quote that delimits the attribute
        tooltip = html.escape(f"Original: {original_word} → New: {text_clean}", quote=True)
        marked_word = html.escape(text_clean, quote=False)
        highlighted_text = f"<mark style='background: linear-gradient(120deg, #84fab0 0%, #8fd3f4 100%); padding: 2px 6px; border-radius: 4px; font-weight: 500; box-shadow: 0 1px 2px rgba(0,0,0,0.1);' title='{tooltip}'>{marked_word}</mark>"

        result_html.append(
            html.escape(before, quote=False) + highlighted_text + html.escape(after, quote=False)
        )
        words_highlighted.add(text_lower)

    # Return just the inner content without the outer div (added by caller)
    return "".join(result_html)

# Streamlit UI
def main():
    """Render the SafeSeal Streamlit page.

    Layout: a sidebar with the watermarking hyperparameters, a left column
    with the input text area and the "Generate Watermark" button, and a right
    column showing the watermarked text with highlighted changes. The model
    and the latest results are kept in ``st.session_state`` so they survive
    Streamlit reruns.
    """
    st.set_page_config(
        page_title="Watermarked Text Generator",
        page_icon="🔒",
        layout="wide"
    )

    # Centered and styled title
    st.markdown(
        """
        <div style="text-align: center; margin-bottom: 10px;">
            <h1 style="color: #4A90E2; font-size: 2.5rem; font-weight: bold; margin: 0;">
                🔒 SafeSeal Watermark
            </h1>
        </div>
        <div style="text-align: center; margin-bottom: 20px; color: #666; font-size: 1.1rem;">
            Content-Preserving Watermarking for Large Language Model Deployments.
        </div>
        """,
        unsafe_allow_html=True
    )

    # Add a nice separator
    st.markdown("---")

    # Sidebar for hyperparameters
    with st.sidebar:
        st.markdown("### ⚙️ Hyperparameters")
        st.caption("Configure the watermarking algorithm")

        # Main inputs
        secret_key = st.text_input(
            "🔑 Secret Key",
            value="My_Secret_Key",
            help="Secret key for deterministic randomization"
        )

        threshold = st.slider(
            "📊 Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.98,
            step=0.01,
            help="BERTScore similarity threshold (higher = more similar replacements)"
        )

        st.divider()

        # Tournament Sampling parameters
        st.markdown("### 🏆 Tournament Sampling")
        st.caption("Control the randomization process")

        # Top_K is not exposed in the UI; it is fixed at 6
        Top_K = 6

        m = st.number_input(
            "m (Tournament Rounds)",
            min_value=1,
            max_value=20,
            value=10,
            help="Number of tournament rounds"
        )

        c = st.number_input(
            "c (Competitors per Round)",
            min_value=2,
            max_value=10,
            value=2,
            help="Number of competitors per tournament match"
        )

        h = st.number_input(
            "h (Context Size)",
            min_value=1,
            max_value=20,
            value=6,
            help="Number of left context tokens to consider"
        )

        alpha = st.slider(
            "Alpha (Temperature)",
            min_value=0.1,
            max_value=5.0,
            value=1.1,
            step=0.1,
            help="Temperature scaling factor for softmax"
        )

    # Main content area
    col1, col2 = st.columns(2)

    # Load the model the first time this session runs and keep it in session state
    if 'tokenizer' not in st.session_state:
        with st.spinner("Loading model... This may take a minute"):
            tokenizer, lm_model = load_model()
            st.session_state.tokenizer = tokenizer
            st.session_state.lm_model = lm_model

    with col1:
        st.markdown("### 📝 Input Text")
        input_text = st.text_area(
            "Enter text to watermark",
            height=400,
            placeholder="Paste your text here to generate a watermarked version...",
            label_visibility="collapsed"
        )

        # Process button at the bottom of input column
        if st.button("🚀 Generate Watermark", type="primary", use_container_width=True):
            if not input_text or len(input_text.strip()) == 0:
                st.warning("Please enter some text to watermark.")
            else:
                with st.spinner("Generating watermarked text... This may take a few moments"):
                    try:
                        # Process the text
                        watermarked_text, total_randomized_words, total_words, changed_words, sampling_results = process_text_wrapper(
                            input_text,
                            st.session_state.tokenizer,
                            st.session_state.lm_model,
                            Top_K=int(Top_K),
                            threshold=float(threshold),
                            secret_key=secret_key,
                            m=int(m),
                            c=int(c),
                            h=int(h),
                            alpha=float(alpha),
                            batch_size=32,
                            max_length=512,
                            similarity_context_mode='whole'
                        )

                        # Store results in session state so the right column
                        # can still show them on later reruns
                        st.session_state.watermarked_text = watermarked_text
                        st.session_state.changed_words = changed_words
                        st.session_state.sampling_results = sampling_results
                        st.session_state.total_randomized = total_randomized_words
                        st.session_state.total_words = total_words

                        # max(total_words, 1) avoids dividing by zero
                        st.success(f"Watermark generated! Changed {total_randomized_words} out of {total_words} words ({100*total_randomized_words/max(total_words,1):.1f}%)")
                    except Exception as e:
                        # UI boundary: show the failure and its traceback on the page
                        st.error(f"Error generating watermark: {str(e)}")
                        import traceback
                        st.code(traceback.format_exc())

    with col2:
        st.markdown("### 🔒 Watermarked Text")

        # Display watermarked text with highlights
        if 'watermarked_text' in st.session_state:
            highlight_html = create_html_with_highlights(
                input_text,
                st.session_state.watermarked_text,
                st.session_state.changed_words,
                st.session_state.sampling_results
            )
            # Show highlighted version with border - wrap the complete HTML
            full_html = f"""
            <div style='padding: 15px; background-color: #f8f9fa; border-radius: 8px; border: 1px solid #e0e0e0; min-height: 400px; max-height: 400px; overflow-y: auto; line-height: 1.8; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; font-size: 15px; white-space: pre-wrap; word-wrap: break-word;'>
            {highlight_html}
            </div>
            """
            st.markdown(full_html, unsafe_allow_html=True)
        else:
            st.info("👈 Enter text in the left panel and click 'Generate Watermark' to start")

    # Footer
    st.divider()
    st.caption("🔒 Secure AI Watermarking Tool | Built with SafeSeal")

    # Demo warning at the bottom
    st.markdown(
        """
        <div style="text-align: center; margin-top: 20px; padding: 10px; font-size: 0.85rem; color: #666;">
            ⚠️ <strong>Demo Version</strong>: This is a demonstration using a light model to showcase the watermarking pipeline. 
            Results may not be perfect and are intended for testing purposes only.
        </div>
        """,
        unsafe_allow_html=True
    )

if __name__ == "__main__":
    # Entry point: render the Streamlit page when this file is executed as the main script
    main()