Spaces:
Sleeping
Sleeping
import html
import os
import subprocess
import sys

import nltk
import spacy
import streamlit as st
import torch
from nltk.tokenize import word_tokenize
from transformers import RobertaTokenizer, RobertaForMaskedLM

from utils_final import extract_entities_and_pos, whole_context_process_sentence
| # Download NLTK data if not available | |
def setup_nltk():
    """Download the NLTK data packages the app needs, best-effort.

    Each download is attempted independently and failures are ignored:
    if the data is already cached locally or the network is unavailable,
    the app keeps running with whatever data is present.
    """
    resources = (
        'punkt_tab',                        # tokenizer data for word_tokenize
        'averaged_perceptron_tagger_eng',   # POS tagger data
        'wordnet',                          # WordNet corpus
        'omw-1.4',                          # Open Multilingual WordNet
    )
    for resource in resources:
        try:
            nltk.download(resource, quiet=True)
        except Exception:
            # Deliberately best-effort: never let a download failure crash
            # startup. (Was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit.)
            pass
# Fetch required NLTK corpora before any tokenization happens.
setup_nltk()

# Set environment
# NOTE(review): cache_dir is assigned but never used below — presumably a
# leftover from a cluster deployment; confirm before removing.
cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm'

# Load spaCy model - download if not available
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    # Model not installed: fetch it once via the spaCy CLI, then retry.
    print("Downloading spaCy model...")
    subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")
| # Define apply_replacements function (from Safeseal_gen_final.py) | |
def apply_replacements(sentence, replacements):
    """
    Apply replacements to the sentence while preserving original formatting,
    spacing, and punctuation.

    Args:
        sentence: Original sentence text.
        replacements: Iterable of (position, target, replacement) tuples.
            A replacement is applied only when the spaCy token at
            ``position`` equals ``target`` after stripping whitespace.

    Returns:
        str: The reassembled sentence with replacements applied.
    """
    doc = nlp(sentence)  # Tokenize the sentence
    # text_with_ws keeps each token's trailing whitespace so the sentence
    # can be reassembled without re-inferring spacing.
    tokens = [token.text_with_ws for token in doc]
    # Apply replacements based on token positions
    for position, target, replacement in replacements:
        if position < len(tokens) and tokens[position].strip() == target:
            # Re-attach the token's ORIGINAL trailing whitespace. The previous
            # code appended a single space only when the token ended with " ",
            # which dropped newlines, tabs, and runs of multiple spaces.
            trailing_ws = tokens[position][len(tokens[position].rstrip()):]
            tokens[position] = replacement + trailing_ws
    # Reassemble the sentence
    return "".join(tokens)
| # Initialize session state for model caching | |
def load_model():
    """Load the RoBERTa tokenizer and masked-LM used for watermarking.

    The caller is expected to cache the returned pair (e.g. in Streamlit
    session state) so weights are not reloaded on every rerun.

    Returns:
        tuple: ``(tokenizer, model)`` — the tokenizer capped at 512 tokens
        and the masked-LM in eval mode, moved to GPU when CUDA is available.
    """
    print("Loading model...")

    tok = RobertaTokenizer.from_pretrained('roberta-base')
    # Cap the sequence length at RoBERTa's limit; max_len is the legacy
    # alias kept for older transformers code paths.
    tok.model_max_length = 512
    tok.max_len = 512

    model = RobertaForMaskedLM.from_pretrained('roberta-base', attn_implementation="eager")
    if hasattr(model.config, 'max_position_embeddings'):
        model.config.max_position_embeddings = 512
    model.eval()  # inference only — disable dropout etc.

    if not torch.cuda.is_available():
        print("Model loaded on CPU")
        return tok, model

    model = model.cuda()
    print(f"Model loaded on GPU: {torch.cuda.get_device_name()}")
    return tok, model
# Module-level holder for per-run sampling diagnostics; rebuilt on every
# call to process_text_wrapper and also returned to the caller.
sampling_results = []


def process_text_wrapper(text, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, batch_size=32, max_length=512, similarity_context_mode='whole'):
    """
    Wrapper function to process text and return watermarked output with tracking of changes.

    Processes *text* line by line: each non-blank line is passed to
    ``whole_context_process_sentence`` (with the full text as context) and
    the proposed replacements are applied in place.

    Args:
        text: Full input text to watermark.
        tokenizer: Tokenizer for the masked LM.
        lm_model: Masked-LM used to propose replacements.
        Top_K: Number of candidate replacements considered per word.
        threshold: Similarity threshold a candidate must meet.
        secret_key: Key seeding the deterministic randomization.
        m: Number of tournament rounds.
        c: Competitors per tournament match.
        h: Left-context size in tokens.
        alpha: Softmax temperature scaling factor.
        batch_size: Batch size forwarded to the sentence processor.
        max_length: Max token length forwarded to the sentence processor.
        similarity_context_mode: Context mode for similarity scoring.

    Returns:
        tuple: (watermarked_text, total_randomized_words, total_words,
        changed_words, sampling_results) where ``changed_words`` is a list
        of (original, replacement, position) tuples for actual changes.
    """
    global sampling_results
    # Reset the module-level diagnostics for this run.
    sampling_results = []
    # keepends=True so line terminators survive reassembly via "".join().
    lines = text.splitlines(keepends=True)
    final_text = []
    total_randomized_words = 0
    total_words = len(word_tokenize(text))
    # Track changed words and their positions
    changed_words = []  # List of (original, replacement, position)
    for line in lines:
        if line.strip():
            # The whole original text is passed as context; only the
            # stripped line is actually rewritten.
            replacements, sampling_results_line = whole_context_process_sentence(
                text,
                line.strip(),
                tokenizer, lm_model, Top_K, threshold,
                secret_key, m, c, h, alpha, "output",
                batch_size=batch_size, max_length=max_length, similarity_context_mode=similarity_context_mode
            )
            sampling_results.extend(sampling_results_line)
            if replacements:
                randomized_line = apply_replacements(line, replacements)
                final_text.append(randomized_line)
                # Track ONLY actual changes (where original != replacement)
                for position, original, replacement in replacements:
                    if original != replacement:
                        changed_words.append((original, replacement, position))
                        total_randomized_words += 1
            else:
                # No replacements proposed — keep the line verbatim.
                final_text.append(line)
        else:
            # Preserve blank lines so the original formatting survives.
            final_text.append(line)
    return "".join(final_text), total_randomized_words, total_words, changed_words, sampling_results
def create_html_with_highlights(original_text, watermarked_text, changed_words_info, sampling_results):
    """
    Create HTML with highlighted changed words using spaCy tokenization.

    Args:
        original_text: The user's input text (unused; kept for interface
            compatibility).
        watermarked_text: Text produced by the watermarking pass.
        changed_words_info: List of (original, replacement, position)
            tuples describing substitutions.
        sampling_results: Sampling diagnostics (unused; kept for interface
            compatibility).

    Returns:
        str: Inner HTML (no outer <div> — added by the caller) in which the
        first occurrence of each replacement word is wrapped in a styled
        <mark> whose tooltip shows the original word. All user text is
        HTML-escaped, since the caller renders with unsafe_allow_html=True.
    """
    # Create a set of replacement words that were actually changed (not same as original)
    actual_replacements = set()
    replacement_to_original = {}
    for original, replacement, _ in changed_words_info:
        if original.lower() != replacement.lower():  # Only map actual changes
            actual_replacements.add(replacement.lower())
            replacement_to_original[replacement.lower()] = original
    # Parse watermarked text with spaCy
    doc_watermarked = nlp(watermarked_text)
    # Build HTML by processing the watermarked text
    result_html = []
    words_highlighted = set()  # Track which words we've highlighted (to avoid duplicates)
    for token in doc_watermarked:
        text = token.text_with_ws
        text_clean = token.text.strip('.,!?;:')
        text_lower = text_clean.lower()
        # Only highlight if this word is in our actual replacements set
        # and we haven't already highlighted this exact word
        if text_lower in actual_replacements and text_lower not in words_highlighted:
            original_word = replacement_to_original.get(text_lower, text_clean)
            # Only highlight if actually different from original
            if original_word.lower() != text_lower:
                # quote=True also escapes quotes so an apostrophe in the text
                # cannot break the single-quoted title attribute, and user
                # text cannot inject markup into the page.
                # ("→" restores the arrow that had been mojibake-garbled.)
                tooltip = html.escape(f"Original: {original_word} → New: {text_clean}", quote=True)
                # Enhanced highlighting with better colors
                highlighted_text = (
                    "<mark style='background: linear-gradient(120deg, #84fab0 0%, #8fd3f4 100%); "
                    "padding: 2px 6px; border-radius: 4px; font-weight: 500; "
                    "box-shadow: 0 1px 2px rgba(0,0,0,0.1);' "
                    f"title='{tooltip}'>{html.escape(text_clean)}</mark>"
                )
                # Preserve trailing whitespace and punctuation
                if text != text_clean:
                    highlighted_text += html.escape(text[len(text_clean):])
                result_html.append(highlighted_text)
                words_highlighted.add(text_lower)  # Mark as highlighted
            else:
                result_html.append(html.escape(text))
        else:
            result_html.append(html.escape(text))
    # Return just the inner content without the outer div (added by caller)
    return "".join(result_html)
| # Streamlit UI | |
def main():
    """Render the SafeSeal watermarking demo as a two-column Streamlit app.

    Left column collects input text; right column shows the watermarked
    output with changed words highlighted. Hyperparameters live in the
    sidebar. The tokenizer/model pair is loaded once and cached in
    ``st.session_state`` so Streamlit reruns do not reload the weights.
    """
    st.set_page_config(
        page_title="Watermarked Text Generator",
        # NOTE(review): "π" (and the other single-letter icons below) look
        # mojibake-garbled — presumably emoji originally; confirm against the
        # source repository before shipping.
        page_icon="π",
        layout="wide"
    )
    # Centered and styled title
    st.markdown(
        """
        <div style="text-align: center; margin-bottom: 10px;">
            <h1 style="color: #4A90E2; font-size: 2.5rem; font-weight: bold; margin: 0;">
                π SafeSeal Watermark
            </h1>
        </div>
        <div style="text-align: center; margin-bottom: 20px; color: #666; font-size: 1.1rem;">
            Content-Preserving Watermarking for Large Language Model Deployments.
        </div>
        """,
        unsafe_allow_html=True
    )
    # Add a nice separator
    st.markdown("---")

    # Sidebar for hyperparameters
    with st.sidebar:
        st.markdown("### βοΈ Hyperparameters")
        st.caption("Configure the watermarking algorithm")
        # Main inputs
        secret_key = st.text_input(
            "π Secret Key",
            value="My_Secret_Key",
            help="Secret key for deterministic randomization"
        )
        threshold = st.slider(
            "π Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.98,
            step=0.01,
            help="BERTScore similarity threshold (higher = more similar replacements)"
        )
        st.divider()
        # Tournament Sampling parameters
        st.markdown("### π Tournament Sampling")
        st.caption("Control the randomization process")
        # Hidden Top_K parameter (default 6) — not exposed in the UI.
        Top_K = 6
        m = st.number_input(
            "m (Tournament Rounds)",
            min_value=1,
            max_value=20,
            value=10,
            help="Number of tournament rounds"
        )
        c = st.number_input(
            "c (Competitors per Round)",
            min_value=2,
            max_value=10,
            value=2,
            help="Number of competitors per tournament match"
        )
        h = st.number_input(
            "h (Context Size)",
            min_value=1,
            max_value=20,
            value=6,
            help="Number of left context tokens to consider"
        )
        alpha = st.slider(
            "Alpha (Temperature)",
            min_value=0.1,
            max_value=5.0,
            value=1.1,
            step=0.1,
            help="Temperature scaling factor for softmax"
        )

    # Main content area
    col1, col2 = st.columns(2)

    # Check if model is loaded — load once, then cache in session state.
    if 'tokenizer' not in st.session_state:
        with st.spinner("Loading model... This may take a minute"):
            tokenizer, lm_model = load_model()
            st.session_state.tokenizer = tokenizer
            st.session_state.lm_model = lm_model

    with col1:
        st.markdown("### π Input Text")
        input_text = st.text_area(
            "Enter text to watermark",
            height=400,
            placeholder="Paste your text here to generate a watermarked version...",
            label_visibility="collapsed"
        )
        # Process button at the bottom of input column
        if st.button("π Generate Watermark", type="primary", use_container_width=True):
            if not input_text or len(input_text.strip()) == 0:
                st.warning("Please enter some text to watermark.")
            else:
                with st.spinner("Generating watermarked text... This may take a few moments"):
                    try:
                        # Process the text
                        watermarked_text, total_randomized_words, total_words, changed_words, sampling_results = process_text_wrapper(
                            input_text,
                            st.session_state.tokenizer,
                            st.session_state.lm_model,
                            Top_K=int(Top_K),
                            threshold=float(threshold),
                            secret_key=secret_key,
                            m=int(m),
                            c=int(c),
                            h=int(h),
                            alpha=float(alpha),
                            batch_size=32,
                            max_length=512,
                            similarity_context_mode='whole'
                        )
                        # Store results in session state so they survive reruns.
                        st.session_state.watermarked_text = watermarked_text
                        st.session_state.changed_words = changed_words
                        st.session_state.sampling_results = sampling_results
                        st.session_state.total_randomized = total_randomized_words
                        st.session_state.total_words = total_words
                        # max(total_words, 1) guards the percentage against division by zero.
                        st.success(f"Watermark generated! Changed {total_randomized_words} out of {total_words} words ({100*total_randomized_words/max(total_words,1):.1f}%)")
                    except Exception as e:
                        # Surface the error plus a traceback in the UI for debugging.
                        st.error(f"Error generating watermark: {str(e)}")
                        import traceback
                        st.code(traceback.format_exc())

    with col2:
        st.markdown("### π Watermarked Text")
        # Display watermarked text with highlights
        if 'watermarked_text' in st.session_state:
            highlight_html = create_html_with_highlights(
                input_text,
                st.session_state.watermarked_text,
                st.session_state.changed_words,
                st.session_state.sampling_results
            )
            # Show highlighted version with border - wrap the complete HTML
            full_html = f"""
            <div style='padding: 15px; background-color: #f8f9fa; border-radius: 8px; border: 1px solid #e0e0e0; min-height: 400px; max-height: 400px; overflow-y: auto; line-height: 1.8; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; font-size: 15px; white-space: pre-wrap; word-wrap: break-word;'>
            {highlight_html}
            </div>
            """
            st.markdown(full_html, unsafe_allow_html=True)
        else:
            st.info("π Enter text in the left panel and click 'Generate Watermark' to start")

    # Footer
    st.divider()
    st.caption("π Secure AI Watermarking Tool | Built with SafeSeal")
    # Demo warning at the bottom
    st.markdown(
        """
        <div style="text-align: center; margin-top: 20px; padding: 10px; font-size: 0.85rem; color: #666;">
            β οΈ <strong>Demo Version</strong>: This is a demonstration using a light model to showcase the watermarking pipeline.
            Results may not be perfect and are intended for testing purposes only.
        </div>
        """,
        unsafe_allow_html=True
    )


if __name__ == "__main__":
    # Run the app
    main()