yasserrmd committed on
Commit
69c8fc6
·
verified ·
1 Parent(s): 43b9fdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -22,23 +22,29 @@ CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
22
  # Function to check for AI-generated content using SynthID
23
  @spaces.GPU
24
  def check_plagiarism(text):
25
- # Tokenize and process the input text
26
- tokens = tokenizer.encode_plus(text, return_tensors="pt", truncation=True, padding=True)
27
- tokens = tokens.to(DEVICE)
28
-
29
  # Logits processor for SynthID
30
  logits_processor = logits_processing.SynthIDLogitsProcessor(
31
  **CONFIG, top_k=40, temperature=0.5
32
  )
33
 
34
 
35
- # Use SynthID's bayesian detector to check for AI generation likelihood
 
 
 
 
 
 
 
 
 
 
 
 
36
  try:
37
- # Assuming the logits processor can be used to score watermarked content
38
- logits_scores = logits_processor(tokens['input_ids'])
39
-
40
- # Simple threshold: assuming logits indicate watermark presence
41
- is_watermarked = logits_scores.mean().item() > 0.5
42
 
43
  if is_watermarked:
44
  return "Flagged as AI-generated content (Academic Integrity Warning)."
 
22
  # Function to check for AI-generated content using SynthID
23
  @spaces.GPU
24
  def check_plagiarism(text):
 
 
 
 
25
  # Logits processor for SynthID
26
  logits_processor = logits_processing.SynthIDLogitsProcessor(
27
  **CONFIG, top_k=40, temperature=0.5
28
  )
29
 
30
 
31
+ # Tokenize and process the input text
32
+ inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
33
+
34
+ # Generate output with model, capturing scores (logits)
35
+ with torch.no_grad():
36
+ outputs = model.generate(
37
+ inputs['input_ids'],
38
+ max_length=inputs['input_ids'].shape[1] + 50, # Generate up to 50 additional tokens
39
+ output_scores=True,
40
+ return_dict_in_generate=True
41
+ )
42
+
43
+ # Process logits through SynthID to check for watermark presence
44
  try:
45
+ # Pass logits (scores) to the SynthIDLogitsProcessor
46
+ logits = outputs.scores # Extract logits from the generation output
47
+ is_watermarked = logits_processor(inputs['input_ids'], logits=logits).mean().item() > 0.5
 
 
48
 
49
  if is_watermarked:
50
  return "Flagged as AI-generated content (Academic Integrity Warning)."