import streamlit as st import os import torch from transformers import RobertaTokenizer, RobertaForMaskedLM import spacy import subprocess import sys import nltk from nltk.tokenize import word_tokenize from utils_final import extract_entities_and_pos, whole_context_process_sentence # Download NLTK data if not available def setup_nltk(): """Setup NLTK data with error handling.""" try: nltk.download('punkt_tab', quiet=True) except: pass try: nltk.download('averaged_perceptron_tagger_eng', quiet=True) except: pass try: nltk.download('wordnet', quiet=True) except: pass try: nltk.download('omw-1.4', quiet=True) except: pass setup_nltk() # Set environment cache_dir = '/network/rit/lab/Lai_ReSecureAI/kiel/wmm' # Load spaCy model - download if not available try: nlp = spacy.load("en_core_web_sm") except OSError: print("Downloading spaCy model...") subprocess.check_call([sys.executable, "-m", "spacy", "download", "en_core_web_sm"]) nlp = spacy.load("en_core_web_sm") # Define apply_replacements function (from Safeseal_gen_final.py) def apply_replacements(sentence, replacements): """ Apply replacements to the sentence while preserving original formatting, spacing, and punctuation. """ doc = nlp(sentence) # Tokenize the sentence tokens = [token.text_with_ws for token in doc] # Preserve original whitespace with tokens # Apply replacements based on token positions for position, target, replacement in replacements: if position < len(tokens) and tokens[position].strip() == target: tokens[position] = replacement + (" " if tokens[position].endswith(" ") else "") # Reassemble the sentence return "".join(tokens) # Initialize session state for model caching @st.cache_resource def load_model(): """Load the model and tokenizer (cached to avoid reloading on every run)""" print("Loading model...") tokenizer = RobertaTokenizer.from_pretrained('roberta-base') lm_model = RobertaForMaskedLM.from_pretrained('roberta-base', attn_implementation="eager") tokenizer.model_max_length = 512 tokenizer.max_len = 512 if hasattr(lm_model.config, 'max_position_embeddings'): lm_model.config.max_position_embeddings = 512 lm_model.eval() if torch.cuda.is_available(): lm_model = lm_model.cuda() print(f"Model loaded on GPU: {torch.cuda.get_device_name()}") else: print("Model loaded on CPU") return tokenizer, lm_model sampling_results = [] def process_text_wrapper(text, tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, batch_size=32, max_length=512, similarity_context_mode='whole'): """ Wrapper function to process text and return watermarked output with tracking of changes. """ global sampling_results sampling_results = [] lines = text.splitlines(keepends=True) final_text = [] total_randomized_words = 0 total_words = len(word_tokenize(text)) # Track changed words and their positions changed_words = [] # List of (original, replacement, position) for line in lines: if line.strip(): replacements, sampling_results_line = whole_context_process_sentence( text, line.strip(), tokenizer, lm_model, Top_K, threshold, secret_key, m, c, h, alpha, "output", batch_size=batch_size, max_length=max_length, similarity_context_mode=similarity_context_mode ) sampling_results.extend(sampling_results_line) if replacements: randomized_line = apply_replacements(line, replacements) final_text.append(randomized_line) # Track ONLY actual changes (where original != replacement) for position, original, replacement in replacements: if original != replacement: changed_words.append((original, replacement, position)) total_randomized_words += 1 else: final_text.append(line) else: final_text.append(line) return "".join(final_text), total_randomized_words, total_words, changed_words, sampling_results def create_html_with_highlights(original_text, watermarked_text, changed_words_info, sampling_results): """ Create HTML with highlighted changed words using spaCy tokenization. """ # Create a set of replacement words that were actually changed (not same as original) actual_replacements = set() replacement_to_original = {} for original, replacement, _ in changed_words_info: if original.lower() != replacement.lower(): # Only map actual changes actual_replacements.add(replacement.lower()) replacement_to_original[replacement.lower()] = original # Parse watermarked text with spaCy doc_watermarked = nlp(watermarked_text) # Build HTML by processing the watermarked text result_html = [] words_highlighted = set() # Track which words we've highlighted (to avoid duplicates) for token in doc_watermarked: text = token.text_with_ws text_clean = token.text.strip('.,!?;:') text_lower = text_clean.lower() # Only highlight if this word is in our actual replacements set # and we haven't already highlighted this exact word if text_lower in actual_replacements and text_lower not in words_highlighted: original_word = replacement_to_original.get(text_lower, text_clean) # Only highlight if actually different from original if original_word.lower() != text_lower: tooltip = f"Original: {original_word} → New: {text_clean}" # Enhanced highlighting with better colors highlighted_text = f"{text_clean}" # Preserve trailing whitespace and punctuation if text != text_clean: highlighted_text += text[len(text_clean):] result_html.append(highlighted_text) words_highlighted.add(text_lower) # Mark as highlighted else: result_html.append(text) else: result_html.append(text) # Return just the inner content without the outer div (added by caller) return "".join(result_html) # Streamlit UI def main(): st.set_page_config( page_title="Watermarked Text Generator", page_icon="🔒", layout="wide" ) # Centered and styled title st.markdown( """