Jay-Rajput commited on
Commit
a74afb3
·
1 Parent(s): 5391c7d

fixaidetector

Browse files
Files changed (1) hide show
  1. text_detector.py +5 -6
text_detector.py CHANGED
@@ -2,7 +2,7 @@ import math
2
  import statistics
3
  import numpy as np
4
  import torch
5
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
  from collections import Counter
7
 
8
 
@@ -17,6 +17,7 @@ class AITextDetector:
17
  def __init__(self, model_name="roberta-base-openai-detector", device=None):
18
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
19
  self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
20
 
21
  if device:
22
  self.device = device
@@ -24,19 +25,17 @@ class AITextDetector:
24
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
  self.model.to(self.device)
 
27
  self.model.eval()
28
 
29
  def _compute_perplexity(self, text: str) -> float:
30
  """
31
  Approximate perplexity using NLL from model.
32
  """
33
- encodings = self.tokenizer(text, return_tensors="pt", truncation=True)
34
- input_ids = encodings.input_ids.to(self.device)
35
-
36
  with torch.no_grad():
37
- outputs = self.model(input_ids, labels=input_ids)
38
  loss = outputs.loss.item()
39
-
40
  return math.exp(loss)
41
 
42
  def _compute_burstiness(self, text: str) -> float:
 
2
  import statistics
3
  import numpy as np
4
  import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
6
  from collections import Counter
7
 
8
 
 
17
  def __init__(self, model_name="roberta-base-openai-detector", device=None):
18
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
19
  self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
20
+ self.lm_model = AutoModelForCausalLM.from_pretrained("gpt2")
21
 
22
  if device:
23
  self.device = device
 
25
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  self.model.to(self.device)
28
+ self.lm_model.to(self.device)
29
  self.model.eval()
30
 
31
def _compute_perplexity(self, text: str) -> float:
    """Approximate the perplexity of *text* under the GPT-2 causal LM.

    Encodes the text, runs the LM with the input ids as their own labels,
    and returns ``exp(mean NLL)``. Lower values indicate more predictable
    text (a common heuristic for machine-generated content).

    Args:
        text: Raw input string; truncated to the LM tokenizer's max length.

    Returns:
        ``math.exp(loss)`` as a float, or ``float("inf")`` when the text is
        too short to score (fewer than two tokens).
    """
    # BUG FIX: the committed version encoded with ``self.tokenizer`` — the
    # detector's sequence-classification tokenizer (roberta-based, vocab
    # size 50265) — but scored with the GPT-2 LM (vocab size 50257). Ids
    # 50257-50264 crash GPT-2's embedding lookup, and even in-range ids map
    # to different tokens, so the NLL was meaningless. Use a tokenizer that
    # matches the LM, created lazily and cached on the instance so repeated
    # calls stay cheap and __init__'s signature is untouched.
    lm_tokenizer = getattr(self, "_lm_tokenizer", None)
    if lm_tokenizer is None:
        lm_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self._lm_tokenizer = lm_tokenizer

    encodings = lm_tokenizer(text, return_tensors="pt", truncation=True).to(self.device)
    input_ids = encodings.input_ids

    # The causal-LM loss shifts labels by one position, so < 2 tokens
    # leaves zero targets and yields a NaN loss. Report such input as
    # unscorable instead of propagating NaN.
    if input_ids.size(1) < 2:
        return float("inf")

    # NOTE(review): the visible __init__ only calls eval() on the detector
    # model, not on self.lm_model — ensure deterministic inference here
    # (idempotent; disables dropout).
    self.lm_model.eval()

    with torch.no_grad():
        outputs = self.lm_model(**encodings, labels=input_ids)
    loss = outputs.loss.item()

    return math.exp(loss)
40
 
41
  def _compute_burstiness(self, text: str) -> float: