yasserrmd committed on
Commit
4e7fa9f
·
verified ·
1 Parent(s): 602cd0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -28
app.py CHANGED
@@ -29,45 +29,58 @@ def check_plagiarism(text):
29
 
30
  # Tokenize and process the input text
31
  inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
32
-
 
33
  # Generate output with model, capturing scores (logits)
34
  with torch.no_grad():
35
  outputs = model.generate(
36
  inputs['input_ids'],
37
- max_length=inputs['input_ids'].shape[1] + 50, # Generate up to 50 additional tokens
38
  output_scores=True,
39
  return_dict_in_generate=True
40
  )
41
 
42
- # Initialize empty string to store highlighted output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  highlighted_text = ""
44
- is_watermarked = False
45
-
46
- try:
47
- # Extract generated tokens and their scores
48
- generated_tokens = outputs.sequences[0]
49
- token_scores = outputs.scores
50
-
51
- # Loop through each generated token and its corresponding score
52
- for token_id, score in zip(generated_tokens, token_scores):
53
- # Apply SynthIDLogitsProcessor to each score by calling it with 'scores=score'
54
- processed_score = logits_processor(scores=score)
55
- token_text = tokenizer.decode(token_id.unsqueeze(0)) # Decode token_id for individual token text
56
-
57
- # If processed score indicates watermark, highlight this token
58
- if processed_score.mean().item() > 0.5:
59
- is_watermarked = True
60
- highlighted_text += f"<mark>{token_text}</mark>" # Highlight AI-generated content
61
- else:
62
- highlighted_text += token_text
63
-
64
- if is_watermarked:
65
- return f"Flagged as AI-generated content (Academic Integrity Warning): {highlighted_text}"
66
  else:
67
- return "Content appears to be human-generated."
 
 
 
 
 
 
68
 
69
- except Exception as e:
70
- return f"Error in detection process: {e}"
71
 
72
  # Define the Gradio interface
73
  def create_plagiarism_checker():
 
29
 
30
  # Tokenize and process the input text
31
  inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
32
+ inputs_len = inputs['input_ids'].shape[1]
33
+
34
  # Generate output with model, capturing scores (logits)
35
  with torch.no_grad():
36
  outputs = model.generate(
37
  inputs['input_ids'],
38
+ max_length=inputs_len + 50, # Generate up to 50 additional tokens
39
  output_scores=True,
40
  return_dict_in_generate=True
41
  )
42
 
43
+ # Extract the generated tokens from the model's predictions
44
+ generated_tokens = outputs.sequences[:, inputs_len:]
45
+
46
+ # Compute masks for watermark detection
47
+ eos_token_mask = logits_processor.compute_eos_token_mask(
48
+ input_ids=generated_tokens,
49
+ eos_token_id=tokenizer.eos_token_id,
50
+ )[:, CONFIG['ngram_len'] - 1 :]
51
+
52
+ context_repetition_mask = logits_processor.compute_context_repetition_mask(
53
+ input_ids=generated_tokens
54
+ )
55
+
56
+ # Combine the masks
57
+ combined_mask = context_repetition_mask * eos_token_mask
58
+
59
+ # Compute G values for the generated text
60
+ g_values = logits_processor.compute_g_values(input_ids=generated_tokens)
61
+
62
+ # Score the G values with the combined mask
63
+ score = detector_mean.mean_score(g_values.cpu().numpy(), combined_mask.cpu().numpy())
64
+
65
+ # Initialize string to store highlighted output
66
  highlighted_text = ""
67
+
68
+ # Loop through each token and apply highlighting if it meets the watermark criteria
69
+ for token_id, g_val, mask in zip(generated_tokens[0], g_values[0], combined_mask[0]):
70
+ token_text = tokenizer.decode(token_id.unsqueeze(0))
71
+
72
+ # If the token is part of the watermark (based on mask and g_value), highlight it
73
+ if mask.item() and g_val.item() > 0.5:
74
+ highlighted_text += f"<mark>{token_text}</mark>" # Highlight watermarked content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  else:
76
+ highlighted_text += token_text
77
+
78
+ # Return the highlighted text and overall watermark score
79
+ if score > 0.5:
80
+ return f"Flagged as AI-generated content (Academic Integrity Warning): {highlighted_text}"
81
+ else:
82
+ return f"Content appears to be human-generated. {highlighted_text}"
83
 
 
 
84
 
85
  # Define the Gradio interface
86
  def create_plagiarism_checker():