# SafeSeal / app.py — synced from the Hugging Face Space (commit fc6dcab, "Sync SafeSeal app", by kirudang).
# Standard library
import html
import os
import subprocess
import sys

# Third-party
import nltk
import spacy
import streamlit as st
import torch
from nltk.tokenize import word_tokenize
from transformers import RobertaTokenizer, RobertaForMaskedLM

# Project-local
from utils_final import extract_entities_and_pos, whole_context_process_sentence
# Download NLTK data if not available
def setup_nltk():
    """Download the NLTK data packages the app needs, best-effort.

    Each package is fetched independently and failures are ignored: a missing
    network connection or an already-present corpus must not prevent the app
    from starting. Uses ``except Exception`` (not a bare ``except``) so
    ``KeyboardInterrupt``/``SystemExit`` still propagate.
    """
    packages = (
        'punkt_tab',
        'averaged_perceptron_tagger_eng',
        'wordnet',
        'omw-1.4',
    )
    for package in packages:
        try:
            nltk.download(package, quiet=True)
        except Exception:
            # Best-effort: keep startup resilient if one download fails.
            pass
# One-time startup work (runs at module import time).
setup_nltk()

# Set environment
# NOTE(review): cache_dir points at a cluster-specific path and is not read
# anywhere in this file — confirm whether it is still needed.
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Load spaCy model - download if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Model not installed: download it with the current interpreter, then retry.
    print("Downloading spaCy model...")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")
# Define apply_replacements function (from Safeseal_gen_final.py)
def apply_replacements(sentence, replacements):
    """
    Apply replacements to the sentence while preserving original formatting,
    spacing, and punctuation.

    Parameters
    ----------
    sentence : str
        The text to modify; tokenized with the module-level spaCy pipeline.
    replacements : iterable of (int, str, str)
        (token position, expected original token, replacement) tuples. A
        replacement is applied only when the token at that position matches
        the expected original, guarding against tokenization drift.

    Returns
    -------
    str
        The sentence with replacements applied and whitespace preserved.
    """
    doc = nlp(sentence)  # Tokenize with the same pipeline used upstream
    tokens = [token.text_with_ws for token in doc]  # Keep each token's trailing whitespace
    # Apply replacements based on token positions
    for position, target, replacement in replacements:
        if position < len(tokens) and tokens[position].strip() == target:
            # Preserve the token's actual trailing whitespace (space, newline,
            # tab, ...). The previous version only re-attached a plain space,
            # silently dropping newlines/tabs from text_with_ws.
            trailing_ws = tokens[position][len(tokens[position].rstrip()):]
            tokens[position] = replacement + trailing_ws
    # Reassemble the sentence
    return "".join(tokens)
# Model loading, cached process-wide by Streamlit so reruns reuse one copy.
@st.cache_resource
def load_model():
    """Load the RoBERTa tokenizer and masked-LM pair.

    Returns a (tokenizer, lm_model) tuple with the model in eval mode,
    moved to the GPU when one is available. st.cache_resource ensures this
    runs only once per process.
    """
    print("Loading model...")
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    lm_model = RobertaForMaskedLM.from_pretrained('roberta-base', attn_implementation="eager")
    # Cap sequence length at RoBERTa's positional-embedding limit.
    for length_attr in ('model_max_length', 'max_len'):
        setattr(tokenizer, length_attr, 512)
    if hasattr(lm_model.config, 'max_position_embeddings'):
        lm_model.config.max_position_embeddings = 512
    lm_model.eval()
    if not torch.cuda.is_available():
        print("Model loaded on CPU")
    else:
        lm_model = lm_model.cuda()
        print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
    return tokenizer, lm_model
# Module-level accumulator for per-token sampling diagnostics; cleared and
# rebuilt on every call to process_text_wrapper (also returned to the caller).
sampling_results = []
def process_text_wrapper(text, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, batch_size=32, max_length=512, similarity_context_mode='whole'):
    """
    Wrapper function to process text and return watermarked output with tracking of changes.

    The input is processed line by line; blank lines pass through untouched
    so document structure is preserved.

    Parameters
    ----------
    text : str
        Full input document.
    tokenizer, lm_model
        RoBERTa tokenizer and masked LM used for candidate scoring.
    Top_K : int
        Number of replacement candidates considered per token (forwarded).
    threshold : float
        Similarity threshold candidates must meet (forwarded).
    secret_key : str
        Key seeding the deterministic tournament sampling (forwarded).
    m, c, h : int
        Tournament rounds, competitors per match, and left-context size
        (forwarded to whole_context_process_sentence).
    alpha : float
        Softmax temperature (forwarded).
    batch_size, max_length : int
        Batching / truncation limits (forwarded).
    similarity_context_mode : str
        Similarity context mode (forwarded).

    Returns
    -------
    tuple
        (watermarked_text, total_randomized_words, total_words,
        changed_words, sampling_results) where changed_words is a list of
        (original, replacement, position) tuples for actual changes only.
    """
    global sampling_results
    sampling_results = []
    lines = text.splitlines(keepends=True)
    final_text = []
    total_randomized_words = 0
    # Word count via NLTK; may differ slightly from spaCy's token count.
    total_words = len(word_tokenize(text))
    # Track changed words and their positions
    changed_words = []  # List of (original, replacement, position)
    for line in lines:
        if line.strip():
            # NOTE(review): replacements are computed on line.strip() but
            # applied to the raw line below; token positions assume both
            # tokenize identically — confirm for lines with leading whitespace.
            replacements, sampling_results_line = whole_context_process_sentence(
                text,
                line.strip(),
                tokenizer, lm_model, Top_K, threshold,
                secret_key, m, c, h, alpha, "output",
                batch_size=batch_size, max_length=max_length, similarity_context_mode=similarity_context_mode
            )
            sampling_results.extend(sampling_results_line)
            if replacements:
                randomized_line = apply_replacements(line, replacements)
                final_text.append(randomized_line)
                # Track ONLY actual changes (where original != replacement)
                for position, original, replacement in replacements:
                    if original != replacement:
                        changed_words.append((original, replacement, position))
                        total_randomized_words += 1
            else:
                final_text.append(line)
        else:
            # Blank line: keep verbatim to preserve paragraph breaks.
            final_text.append(line)
    return "".join(final_text), total_randomized_words, total_words, changed_words, sampling_results
def create_html_with_highlights(original_text, watermarked_text, changed_words_info, sampling_results):
    """
    Create HTML with highlighted changed words using spaCy tokenization.

    Parameters
    ----------
    original_text : str
        Kept for interface compatibility; not used by this implementation.
    watermarked_text : str
        The watermarked text to render.
    changed_words_info : list of (str, str, int)
        (original, replacement, position) tuples from process_text_wrapper.
    sampling_results : list
        Kept for interface compatibility; not used by this implementation.

    Returns
    -------
    str
        Inner HTML (no outer wrapper div — the caller adds it) where each
        changed word is wrapped in a styled <mark> with an
        "Original -> New" tooltip.
    """
    # Map each replacement word (lowercased) back to its original, keeping
    # only actual changes (original differs from replacement).
    actual_replacements = set()
    replacement_to_original = {}
    for original, replacement, _ in changed_words_info:
        if original.lower() != replacement.lower():  # Only map actual changes
            actual_replacements.add(replacement.lower())
            replacement_to_original[replacement.lower()] = original
    # Parse watermarked text with spaCy so tokenization matches the
    # replacement pipeline.
    doc_watermarked = nlp(watermarked_text)
    result_html = []
    words_highlighted = set()  # Highlight each replacement word at most once
    for token in doc_watermarked:
        text = token.text_with_ws
        text_clean = token.text.strip('.,!?;:')
        text_lower = text_clean.lower()
        # Only highlight if this word is an actual replacement and we have
        # not already highlighted this exact word.
        if text_lower in actual_replacements and text_lower not in words_highlighted:
            original_word = replacement_to_original.get(text_lower, text_clean)
            # Only highlight if actually different from original
            if original_word.lower() != text_lower:
                # Escape user-derived words so quotes/angle brackets cannot
                # break out of the title attribute or inject markup.
                tooltip = html.escape(f"Original: {original_word} β†’ New: {text_clean}", quote=True)
                safe_word = html.escape(text_clean)
                highlighted_text = f"<mark style='background: linear-gradient(120deg, #84fab0 0%, #8fd3f4 100%); padding: 2px 6px; border-radius: 4px; font-weight: 500; box-shadow: 0 1px 2px rgba(0,0,0,0.1);' title='{tooltip}'>{safe_word}</mark>"
                # Re-attach punctuation/whitespace stripped from text_clean.
                # strip() removes characters from BOTH sides, so reconstruct
                # the leading prefix and trailing suffix around the match
                # (the previous version assumed only a trailing remainder).
                start = text.find(text_clean) if text_clean else 0
                highlighted_text = text[:start] + highlighted_text + text[start + len(text_clean):]
                result_html.append(highlighted_text)
                words_highlighted.add(text_lower)  # Mark as highlighted
            else:
                result_html.append(text)
        else:
            result_html.append(text)
    # Return just the inner content without the outer div (added by caller)
    return "".join(result_html)
# Streamlit UI
def main():
    """Render the SafeSeal Streamlit page.

    Layout: a sidebar with watermarking hyperparameters, a two-column main
    area (input text on the left, highlighted watermarked output on the
    right), and a footer with a demo disclaimer. Generated results are kept
    in st.session_state so they survive Streamlit reruns.
    """
    st.set_page_config(
        page_title="Watermarked Text Generator",
        page_icon="πŸ”’",
        layout="wide"
    )
    # Centered and styled title
    st.markdown(
        """
        <div style="text-align: center; margin-bottom: 10px;">
        <h1 style="color: #4A90E2; font-size: 2.5rem; font-weight: bold; margin: 0;">
        πŸ”’ SafeSeal Watermark
        </h1>
        </div>
        <div style="text-align: center; margin-bottom: 20px; color: #666; font-size: 1.1rem;">
        Content-Preserving Watermarking for Large Language Model Deployments.
        </div>
        """,
        unsafe_allow_html=True
    )
    # Add a nice separator
    st.markdown("---")
    # Sidebar for hyperparameters
    with st.sidebar:
        st.markdown("### βš™οΈ Hyperparameters")
        st.caption("Configure the watermarking algorithm")
        # Main inputs
        secret_key = st.text_input(
            "πŸ”‘ Secret Key",
            value="My_Secret_Key",
            help="Secret key for deterministic randomization"
        )
        threshold = st.slider(
            "πŸ“Š Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.98,
            step=0.01,
            help="BERTScore similarity threshold (higher = more similar replacements)"
        )
        st.divider()
        # Tournament Sampling parameters
        st.markdown("### πŸ† Tournament Sampling")
        st.caption("Control the randomization process")
        # Hidden Top_K parameter (default 6) — fixed, not exposed in the UI.
        Top_K = 6
        m = st.number_input(
            "m (Tournament Rounds)",
            min_value=1,
            max_value=20,
            value=10,
            help="Number of tournament rounds"
        )
        c = st.number_input(
            "c (Competitors per Round)",
            min_value=2,
            max_value=10,
            value=2,
            help="Number of competitors per tournament match"
        )
        h = st.number_input(
            "h (Context Size)",
            min_value=1,
            max_value=20,
            value=6,
            help="Number of left context tokens to consider"
        )
        alpha = st.slider(
            "Alpha (Temperature)",
            min_value=0.1,
            max_value=5.0,
            value=1.1,
            step=0.1,
            help="Temperature scaling factor for softmax"
        )
    # Main content area
    col1, col2 = st.columns(2)
    # Check if model is loaded; load_model() is cached (st.cache_resource),
    # so this only costs time on the first run of a session.
    if 'tokenizer' not in st.session_state:
        with st.spinner("Loading model... This may take a minute"):
            tokenizer, lm_model = load_model()
            st.session_state.tokenizer = tokenizer
            st.session_state.lm_model = lm_model
    with col1:
        st.markdown("### πŸ“ Input Text")
        input_text = st.text_area(
            "Enter text to watermark",
            height=400,
            placeholder="Paste your text here to generate a watermarked version...",
            label_visibility="collapsed"
        )
        # Process button at the bottom of input column
        if st.button("πŸš€ Generate Watermark", type="primary", use_container_width=True):
            if not input_text or len(input_text.strip()) == 0:
                st.warning("Please enter some text to watermark.")
            else:
                with st.spinner("Generating watermarked text... This may take a few moments"):
                    try:
                        # Process the text
                        watermarked_text, total_randomized_words, total_words, changed_words, sampling_results = process_text_wrapper(
                            input_text,
                            st.session_state.tokenizer,
                            st.session_state.lm_model,
                            Top_K=int(Top_K),
                            threshold=float(threshold),
                            secret_key=secret_key,
                            m=int(m),
                            c=int(c),
                            h=int(h),
                            alpha=float(alpha),
                            batch_size=32,
                            max_length=512,
                            similarity_context_mode='whole'
                        )
                        # Store results in session state so they survive
                        # Streamlit reruns.
                        st.session_state.watermarked_text = watermarked_text
                        st.session_state.changed_words = changed_words
                        st.session_state.sampling_results = sampling_results
                        st.session_state.total_randomized = total_randomized_words
                        st.session_state.total_words = total_words
                        # max(total_words, 1) guards the percentage against
                        # division by zero on empty tokenization.
                        st.success(f"Watermark generated! Changed {total_randomized_words} out of {total_words} words ({100*total_randomized_words/max(total_words,1):.1f}%)")
                    except Exception as e:
                        st.error(f"Error generating watermark: {str(e)}")
                        import traceback
                        st.code(traceback.format_exc())
    with col2:
        st.markdown("### πŸ”’ Watermarked Text")
        # Display watermarked text with highlights
        if 'watermarked_text' in st.session_state:
            # NOTE(review): input_text here is the textarea's *current*
            # contents, which may differ from the text that produced the
            # stored watermark; create_html_with_highlights does not use
            # this argument, so it is harmless today — confirm before
            # relying on it.
            highlight_html = create_html_with_highlights(
                input_text,
                st.session_state.watermarked_text,
                st.session_state.changed_words,
                st.session_state.sampling_results
            )
            # Show highlighted version with border - wrap the complete HTML
            full_html = f"""
            <div style='padding: 15px; background-color: #f8f9fa; border-radius: 8px; border: 1px solid #e0e0e0; min-height: 400px; max-height: 400px; overflow-y: auto; line-height: 1.8; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; font-size: 15px; white-space: pre-wrap; word-wrap: break-word;'>
            {highlight_html}
            </div>
            """
            st.markdown(full_html, unsafe_allow_html=True)
        else:
            st.info("πŸ‘ˆ Enter text in the left panel and click 'Generate Watermark' to start")
    # Footer
    st.divider()
    st.caption("πŸ”’ Secure AI Watermarking Tool | Built with SafeSeal")
    # Demo warning at the bottom
    st.markdown(
        """
        <div style="text-align: center; margin-top: 20px; padding: 10px; font-size: 0.85rem; color: #666;">
        ⚠️ <strong>Demo Version</strong>: This is a demonstration using a light model to showcase the watermarking pipeline.
        Results may not be perfect and are intended for testing purposes only.
        </div>
        """,
        unsafe_allow_html=True
    )
if __name__ == "__main__":
    # Entry point: launch the Streamlit UI when this module runs as a script.
    main()