yasserrmd committed on
Commit
af005d6
·
verified ·
1 Parent(s): a9a3049

Update app.py

Browse files
Files changed (1)
  1. app.py +24 -36
app.py CHANGED
@@ -32,55 +32,43 @@ CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
32
  def check_plagiarism(text):
33
  # Logits processor for SynthID
34
  logits_processor = logits_processing.SynthIDLogitsProcessor(
35
- **CONFIG, top_k=TOP_K, temperature=TEMPERATURE
36
  )
37
 
38
  # Tokenize and process the input text
39
- inputs = tokenizer(text, return_tensors="pt", padding=True).to(DEVICE)
40
- inputs_len = inputs['input_ids'].shape[1]
41
-
42
- # Generate output with model, capturing scores (logits)
43
- with torch.no_grad():
44
- outputs = model.generate(
45
- **inputs,
46
- do_sample=True,
47
- max_length=1024,
48
- temperature=TEMPERATURE,
49
- top_k=TOP_K,
50
- top_p=TOP_P,
51
- )
52
 
53
-
54
- # Extract the generated tokens from the model's predictions
55
- generated_tokens = outputs[:, inputs_len:]
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Compute masks for watermark detection
58
- eos_token_mask = logits_processor.compute_eos_token_mask(
59
- input_ids=generated_tokens,
60
- eos_token_id=tokenizer.eos_token_id,
61
- )[:, CONFIG['ngram_len'] - 1 :]
62
-
63
- context_repetition_mask = logits_processor.compute_context_repetition_mask(
64
- input_ids=generated_tokens
65
- )
66
-
67
- # Combine the masks
68
- combined_mask = context_repetition_mask * eos_token_mask
69
-
70
- # Compute G values for the generated text
71
- g_values = logits_processor.compute_g_values(input_ids=generated_tokens)
72
-
73
  # Score the G values with the combined mask
74
  score = mean_score(g_values.cpu().numpy(), combined_mask.cpu().numpy())
75
 
76
  # Initialize string to store highlighted output
77
  highlighted_text = ""
78
 
79
- for token_id, g_val, mask in zip(generated_tokens[0], g_values[0], combined_mask[0]):
 
80
  token_text = tokenizer.decode(token_id.unsqueeze(0))
81
 
82
- # If the token is part of the watermark, use a mean or max threshold on g_val if it's multi-element
83
- if mask.item() and g_val.float().mean().item() > 0.5: # Use .mean() to get a scalar value
84
  highlighted_text += f"<mark>{token_text}</mark>" # Highlight watermarked content
85
  else:
86
  highlighted_text += token_text
 
32
  def check_plagiarism(text):
33
  # Logits processor for SynthID
34
  logits_processor = logits_processing.SynthIDLogitsProcessor(
35
+ **CONFIG, top_k=40, temperature=0.5
36
  )
37
 
38
  # Tokenize and process the input text
39
+ inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
40
+
41
+ # Extract token IDs for the input text only
42
+ input_ids = inputs['input_ids']
 
 
 
 
 
 
 
 
 
43
 
44
+ # Compute masks for watermark detection
45
+ eos_token_mask = logits_processor.compute_eos_token_mask(
46
+ input_ids=input_ids,
47
+ eos_token_id=tokenizer.eos_token_id,
48
+ )[:, CONFIG['ngram_len'] - 1:]
49
+
50
+ context_repetition_mask = logits_processor.compute_context_repetition_mask(
51
+ input_ids=input_ids
52
+ )
53
+
54
+ # Combine the masks
55
+ combined_mask = context_repetition_mask * eos_token_mask
56
+
57
+ # Compute G values for the input text
58
+ g_values = logits_processor.compute_g_values(input_ids=input_ids)
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Score the G values with the combined mask
61
  score = mean_score(g_values.cpu().numpy(), combined_mask.cpu().numpy())
62
 
63
  # Initialize string to store highlighted output
64
  highlighted_text = ""
65
 
66
+ # Loop through each token in the input text and apply highlighting if it meets the watermark criteria
67
+ for token_id, g_val, mask in zip(input_ids[0], g_values[0], combined_mask[0]):
68
  token_text = tokenizer.decode(token_id.unsqueeze(0))
69
 
70
+ # Convert g_val to float and highlight if it meets the threshold
71
+ if mask.item() and g_val.float().mean().item() > 0.5:
72
  highlighted_text += f"<mark>{token_text}</mark>" # Highlight watermarked content
73
  else:
74
  highlighted_text += token_text