# app.py — SynthID-based AI-content / academic-integrity checker (Gradio Space)
import gradio as gr
import torch
import transformers
import spaces
from synthid_text import synthid_mixin,logits_processing
from synthid_text.detector_mean import mean_score
# --- Model and decoding configuration -----------------------------------
MODEL_NAME = "google/gemma-7b-it"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
TOP_K = 40
TOP_P = 0.99
TEMPERATURE = 0.5

# The SynthID-patched Gemma model is only required for *generation*;
# watermark detection needs just the tokenizer and the logits processor,
# so the model load below is intentionally left disabled.
# model = synthid_mixin.SynthIDGemmaForCausalLM.from_pretrained(
#     MODEL_NAME,
#     device_map=DEVICE,
#     torch_dtype=torch.bfloat16,
# )

# Tokenizer configured for left-padded batching with EOS as pad token.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Default SynthID watermarking configuration (keys, ngram_len, etc.).
CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
# Function to check for AI-generated content using SynthID and highlight watermark
@spaces.GPU
def check_plagiarism(text):
    """Score *text* for a SynthID watermark and highlight suspect tokens.

    Args:
        text: Raw, untrusted user text pasted into the UI.

    Returns:
        An HTML string: a verdict sentence followed by the input text with
        suspected watermarked tokens wrapped in ``<mark>`` tags.
    """
    import html  # local stdlib import: escape untrusted text before gr.HTML

    # Logits processor for SynthID. Reuse the module-level decoding
    # constants so detection stays consistent with the configuration above
    # (previously 40 / 0.5 were re-hardcoded here).
    logits_processor = logits_processing.SynthIDLogitsProcessor(
        **CONFIG, top_k=TOP_K, temperature=TEMPERATURE
    )
    # Tokenize the input text.
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    input_ids = inputs['input_ids']
    # G-values only exist from position ngram_len - 1 onward, so the EOS
    # mask is trimmed to the same range before combining.
    ngram_offset = CONFIG['ngram_len'] - 1
    eos_token_mask = logits_processor.compute_eos_token_mask(
        input_ids=input_ids,
        eos_token_id=tokenizer.eos_token_id,
    )[:, ngram_offset:]
    context_repetition_mask = logits_processor.compute_context_repetition_mask(
        input_ids=input_ids
    )
    # Combine the masks.
    combined_mask = context_repetition_mask * eos_token_mask
    # Per-position G values and overall mean watermark score.
    g_values = logits_processor.compute_g_values(input_ids=input_ids)
    score = mean_score(g_values.cpu().numpy(), combined_mask.cpu().numpy())

    # Rebuild the text with highlighting. g_values/combined_mask correspond
    # to tokens starting at index ngram_offset, so the first ngram_offset
    # tokens are emitted unhighlighted (fixes the previous off-by-ngram
    # misalignment between tokens and their scores). Tokens are escaped so
    # user-supplied markup cannot inject HTML into the gr.HTML output.
    parts = []
    for token_id in input_ids[0][:ngram_offset]:
        parts.append(html.escape(tokenizer.decode(token_id.unsqueeze(0))))
    for token_id, g_val, mask in zip(
        input_ids[0][ngram_offset:], g_values[0], combined_mask[0]
    ):
        token_text = html.escape(tokenizer.decode(token_id.unsqueeze(0)))
        # Highlight tokens whose mean G value clears the threshold.
        if mask.item() and g_val.float().mean().item() > 0.55:
            parts.append(f"<mark>{token_text}</mark>")
        else:
            parts.append(token_text)
    highlighted_text = "".join(parts)

    # Overall verdict based on the mean score.
    if score > 0.5:
        return f"Flagged as AI-generated content (Academic Integrity Warning): {highlighted_text}"
    else:
        return f"Content appears to be human-generated. {highlighted_text}"
# Define the Gradio interface
def create_plagiarism_checker():
    """Assemble and return the Gradio Blocks UI for the integrity checker."""
    with gr.Blocks() as demo:
        # Inline stylesheet injected through an HTML component.
        gr.HTML("""
        <style>
        #text_input { font-size: 16px; border: 1px solid #ddd; padding: 8px; }
        #output { font-size: 16px; padding: 8px; border-radius: 5px; }
        #check_button { font-size: 16px; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; cursor: pointer; }
        #check_button:hover { background-color: #45a049; }
        </style>
        """)
        # Page header with usage instructions.
        gr.Markdown("""
        # πŸ“ Plagiarism and Academic Integrity Checker
        Use this tool to detect AI-generated content in your text using SynthID technology.
        Paste your text below to check if it contains AI-generated segments.
        ---
        """)
        # Text entry area.
        with gr.Row():
            user_text = gr.Textbox(
                label="Input Text",
                placeholder="Paste your text here...",
                lines=10,
                max_lines=20,
                elem_id="text_input",
            )
        gr.Markdown("---")  # visual divider between input and result
        # Result panel; HTML so the <mark> highlighting renders.
        result_panel = gr.HTML(label="Integrity Check Result", elem_id="output")
        # Action button wired to the watermark detector.
        run_button = gr.Button("πŸ” Check Text", elem_id="check_button")
        run_button.click(fn=check_plagiarism, inputs=user_text, outputs=result_panel)
    return demo
# Build the UI at import time (Spaces/templates can pick up the Blocks
# object), but only start the server when run as a script — an
# unconditional launch() would block anything that imports this module.
plagiarism_checker_app = create_plagiarism_checker()

if __name__ == "__main__":
    plagiarism_checker_app.launch()