"""Streamlit front end for the SafeSeal watermarking demo.

The page takes free text, runs it line by line through
``whole_context_process_sentence`` (from ``utils_final``) and shows the
resulting text with the replaced words highlighted.
"""

import html
import logging
import os
import subprocess
import sys

import nltk
import spacy
import streamlit as st
import torch
from nltk.tokenize import word_tokenize
from transformers import RobertaTokenizer, RobertaForMaskedLM

from utils_final import extract_entities_and_pos, whole_context_process_sentence

_logger = logging.getLogger(__name__)

# NLTK resources requested at start-up.
_NLTK_RESOURCES = (
    'punkt_tab',
    'averaged_perceptron_tagger_eng',
    'wordnet',
    'omw-1.4',
)

# Punctuation trimmed from a token before it is compared with a replacement.
_TRIM_PUNCT = '.,!?;:'


# Download NLTK data if not available
def setup_nltk():
    """Request each NLTK resource the app uses, on a best-effort basis.

    A failed download is logged and otherwise ignored so the app can still
    start when the data is already installed or the network is unavailable.
    """
    for resource in _NLTK_RESOURCES:
        try:
            nltk.download(resource, quiet=True)
        except Exception:
            # Deliberately non-fatal; unlike a bare ``except`` this does not
            # swallow KeyboardInterrupt / SystemExit.
            _logger.warning(
                "Could not download NLTK resource %r", resource, exc_info=True
            )


setup_nltk()

# Set environment
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Load spaCy model - download if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    print("Downloading spaCy model...")
    subprocess.check_call(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"]
    )
    nlp = spacy.load("en_core_web_sm")


# Define apply_replacements function (from Safeseal_gen_final.py)
def apply_replacements(sentence, replacements):
    """Apply replacements to ``sentence`` while preserving its formatting.

    Args:
        sentence: Text to rewrite; it is tokenized with the module-level
            spaCy pipeline ``nlp``.
        replacements: Iterable of ``(position, target, replacement)`` where
            ``position`` is an index into the spaCy tokens of ``sentence``.

    Returns:
        The sentence with every matching token replaced. A replacement is
        applied only when ``position`` is in range and the token at that
        position (ignoring surrounding whitespace) equals ``target``.
    """
    doc = nlp(sentence)
    # text_with_ws keeps each token's trailing space, so joining the list
    # reproduces the input spacing.
    tokens = [token.text_with_ws for token in doc]

    for position, target, replacement in replacements:
        # The lower bound stops a negative position from silently indexing
        # from the end of the token list.
        if 0 <= position < len(tokens) and tokens[position].strip() == target:
            trailing = " " if tokens[position].endswith(" ") else ""
            tokens[position] = replacement + trailing

    return "".join(tokens)


# Initialize session state for model caching
@st.cache_resource
def load_model():
    """Load the ``roberta-base`` tokenizer and masked language model.

    Decorated with ``st.cache_resource`` so the load is not repeated on every
    script run.

    Returns:
        ``(tokenizer, lm_model)``; the model is in eval mode and has been
        moved to the GPU when ``torch.cuda.is_available()`` is true.
    """
    print("Loading model...")
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    lm_model = RobertaForMaskedLM.from_pretrained(
        'roberta-base', attn_implementation="eager"
    )

    tokenizer.model_max_length = 512
    tokenizer.max_len = 512
    if hasattr(lm_model.config, 'max_position_embeddings'):
        lm_model.config.max_position_embeddings = 512

    lm_model.eval()
    if torch.cuda.is_available():
        lm_model = lm_model.cuda()
        print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
    else:
        print("Model loaded on CPU")

    return tokenizer, lm_model


# Results of the most recent process_text_wrapper call (kept as a module
# global for code that reads it directly).
sampling_results = []


def process_text_wrapper(text, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, batch_size=32, max_length=512, similarity_context_mode='whole'):
    """Watermark ``text`` line by line and track which words changed.

    Args:
        text: Full input text; also passed to the sentence processor as the
            context for every line.
        tokenizer, lm_model: Objects returned by ``load_model``.
        Top_K, threshold, secret_key, m, c, h, alpha, batch_size, max_length,
        similarity_context_mode: Forwarded unchanged to
            ``whole_context_process_sentence``.

    Returns:
        A 5-tuple ``(watermarked_text, total_randomized_words, total_words,
        changed_words, sampling_results)``. ``total_words`` is the NLTK
        ``word_tokenize`` count of ``text``; ``changed_words`` is a list of
        ``(original, replacement, position)`` for every replacement whose
        replacement differs from the original.
    """
    global sampling_results
    # Work through a local alias: the global is rebound on every call, so
    # re-reading it by name mid-call could pick up another call's list.
    collected = []
    sampling_results = collected

    final_text = []
    total_randomized_words = 0
    total_words = len(word_tokenize(text))
    changed_words = []  # (original, replacement, position)

    for line in text.splitlines(keepends=True):
        core = line.strip()
        if not core:
            # Blank / whitespace-only lines are copied through untouched.
            final_text.append(line)
            continue

        replacements, line_samples = whole_context_process_sentence(
            text, core, tokenizer, lm_model, Top_K, threshold, secret_key,
            m, c, h, alpha, "output",
            batch_size=batch_size,
            max_length=max_length,
            similarity_context_mode=similarity_context_mode
        )
        collected.extend(line_samples)

        if not replacements:
            final_text.append(line)
            continue

        # The processor only saw the stripped line, so its positions refer to
        # that text. Apply them to the same stripped text and then restore the
        # original leading/trailing whitespace; applying them to the raw line
        # would misalign positions on indented lines.
        # NOTE(review): assumes positions are spaCy token indices, as
        # apply_replacements expects - confirm against utils_final.
        leading = line[:len(line) - len(line.lstrip())]
        trailing = line[len(line.rstrip()):]
        final_text.append(leading + apply_replacements(core, replacements) + trailing)

        # Track ONLY actual changes (where original != replacement)
        for position, original, replacement in replacements:
            if original != replacement:
                changed_words.append((original, replacement, position))
                total_randomized_words += 1

    return "".join(final_text), total_randomized_words, total_words, changed_words, collected


def create_html_with_highlights(original_text, watermarked_text, changed_words_info, sampling_results):
    """Build an HTML fragment of the watermarked text with changes marked.

    Args:
        original_text: Unused; kept for interface compatibility.
        watermarked_text: Text to render; tokenized with spaCy.
        changed_words_info: Iterable of ``(original, replacement, position)``.
            Only pairs that differ case-insensitively are highlighted.
        sampling_results: Unused; kept for interface compatibility.

    Returns:
        HTML-escaped text in which the first occurrence of each replacement
        word is wrapped in ``<mark>`` with a tooltip naming the original
        word. Matching is by lower-cased word, not by position, so later
        occurrences of the same word are left unmarked. The caller adds any
        outer container.
    """
    # Lower-cased replacement word -> original word, for real changes only.
    replacement_to_original = {}
    for original, replacement, _ in changed_words_info:
        if original.lower() != replacement.lower():
            replacement_to_original[replacement.lower()] = original

    pieces = []
    already_marked = set()  # each replacement word is highlighted once

    for token in nlp(watermarked_text):
        raw = token.text_with_ws

        # Split the token into leading punctuation, the word itself, and the
        # remainder (trailing punctuation plus whitespace). Slicing by
        # len(word) alone is wrong when punctuation was trimmed on the left.
        without_lead = token.text.lstrip(_TRIM_PUNCT)
        lead_len = len(token.text) - len(without_lead)
        word = without_lead.rstrip(_TRIM_PUNCT)
        key = word.lower()

        if key not in replacement_to_original or key in already_marked:
            # Escape: this text originates from user input and is rendered
            # with unsafe_allow_html=True.
            pieces.append(html.escape(raw, quote=False))
            continue

        prefix = raw[:lead_len]
        suffix = raw[lead_len + len(word):]
        tooltip = f"Original: {replacement_to_original[key]} → New: {word}"
        pieces.append(
            html.escape(prefix, quote=False)
            + f'<mark title="{html.escape(tooltip, quote=True)}">'
            + html.escape(word, quote=False)
            + '</mark>'
            + html.escape(suffix, quote=False)
        )
        already_marked.add(key)

    # Return just the inner content without the outer div (added by caller)
    return "".join(pieces)


# Streamlit UI
def main():
    """Render the Streamlit page: sidebar parameters, input and output panes."""
    st.set_page_config(
        page_title="Watermarked Text Generator",
        page_icon="🔒",
        layout="wide"
    )

    # Centered and styled title
    st.markdown(
        """

🔒 SafeSeal Watermark

Content-Preserving Watermarking for Large Language Model Deployments.
""",
        unsafe_allow_html=True
    )

    # Add a nice separator
    st.markdown("---")

    # Sidebar for hyperparameters
    with st.sidebar:
        st.markdown("### ⚙️ Hyperparameters")
        st.caption("Configure the watermarking algorithm")

        # Main inputs
        secret_key = st.text_input(
            "🔑 Secret Key",
            value="My_Secret_Key",
            help="Secret key for deterministic randomization"
        )
        threshold = st.slider(
            "📊 Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.98,
            step=0.01,
            help="BERTScore similarity threshold (higher = more similar replacements)"
        )

        st.divider()

        # Tournament Sampling parameters
        st.markdown("### 🏆 Tournament Sampling")
        st.caption("Control the randomization process")

        # Hidden Top_K parameter (default 6)
        Top_K = 6

        m = st.number_input(
            "m (Tournament Rounds)",
            min_value=1,
            max_value=20,
            value=10,
            help="Number of tournament rounds"
        )
        c = st.number_input(
            "c (Competitors per Round)",
            min_value=2,
            max_value=10,
            value=2,
            help="Number of competitors per tournament match"
        )
        h = st.number_input(
            "h (Context Size)",
            min_value=1,
            max_value=20,
            value=6,
            help="Number of left context tokens to consider"
        )
        alpha = st.slider(
            "Alpha (Temperature)",
            min_value=0.1,
            max_value=5.0,
            value=1.1,
            step=0.1,
            help="Temperature scaling factor for softmax"
        )

    # Main content area
    col1, col2 = st.columns(2)

    # Check if model is loaded
    if 'tokenizer' not in st.session_state:
        with st.spinner("Loading model... This may take a minute"):
            tokenizer, lm_model = load_model()
            st.session_state.tokenizer = tokenizer
            st.session_state.lm_model = lm_model

    with col1:
        st.markdown("### 📝 Input Text")
        input_text = st.text_area(
            "Enter text to watermark",
            height=400,
            placeholder="Paste your text here to generate a watermarked version...",
            label_visibility="collapsed"
        )

        # Process button at the bottom of input column
        if st.button("🚀 Generate Watermark", type="primary", use_container_width=True):
            if not input_text or not input_text.strip():
                st.warning("Please enter some text to watermark.")
            else:
                with st.spinner("Generating watermarked text... This may take a few moments"):
                    try:
                        # Process the text
                        (
                            watermarked_text,
                            total_randomized_words,
                            total_words,
                            changed_words,
                            run_samples,
                        ) = process_text_wrapper(
                            input_text,
                            st.session_state.tokenizer,
                            st.session_state.lm_model,
                            Top_K=int(Top_K),
                            threshold=float(threshold),
                            secret_key=secret_key,
                            m=int(m),
                            c=int(c),
                            h=int(h),
                            alpha=float(alpha),
                            batch_size=32,
                            max_length=512,
                            similarity_context_mode='whole'
                        )

                        # Store results in session state so they survive
                        # the next script rerun.
                        st.session_state.watermarked_text = watermarked_text
                        st.session_state.changed_words = changed_words
                        st.session_state.sampling_results = run_samples
                        st.session_state.total_randomized = total_randomized_words
                        st.session_state.total_words = total_words

                        st.success(f"Watermark generated! Changed {total_randomized_words} out of {total_words} words ({100*total_randomized_words/max(total_words,1):.1f}%)")
                    except Exception as e:
                        # Top-level UI boundary: show the failure on the page
                        # rather than crashing the app.
                        st.error(f"Error generating watermark: {str(e)}")
                        import traceback
                        st.code(traceback.format_exc())

    with col2:
        st.markdown("### 🔒 Watermarked Text")

        # Display watermarked text with highlights
        if 'watermarked_text' in st.session_state:
            highlight_html = create_html_with_highlights(
                input_text,
                st.session_state.watermarked_text,
                st.session_state.changed_words,
                st.session_state.sampling_results
            )

            # Show highlighted version with border - wrap the complete HTML
            full_html = f"""
{highlight_html}
"""
            st.markdown(full_html, unsafe_allow_html=True)
        else:
            st.info("👈 Enter text in the left panel and click 'Generate Watermark' to start")

    # Footer
    st.divider()
    st.caption("🔒 Secure AI Watermarking Tool | Built with SafeSeal")

    # Demo warning at the bottom
    st.markdown(
        """
⚠️ Demo Version: This is a demonstration using a light model to showcase the watermarking pipeline. Results may not be perfect and are intended for testing purposes only.
""",
        unsafe_allow_html=True
    )


if __name__ == "__main__":
    # Run the app
    main()