quentinL52 committed on
Commit
ae5d0c1
·
1 Parent(s): ef16617

perplexity update

Browse files
Files changed (1) hide show
  1. src/services/nlp_service.py +38 -36
src/services/nlp_service.py CHANGED
@@ -35,47 +35,49 @@ class NLPService:
35
  MAX_PERPLEXITY_CHARS = 50000
36
 
37
  def calculate_perplexity(self, text: str) -> float:
38
- """
39
- Calculate perplexity of the text using a small GPT-2 model.
40
- Lower perplexity = more likely to be generated by AI (or very standard human text).
41
- """
42
- if not text or len(text.strip()) < 10:
43
- return 0.0
44
-
45
- # Truncate to avoid memory overflow on very long inputs
46
- if len(text) > self.MAX_PERPLEXITY_CHARS:
47
- text = text[:self.MAX_PERPLEXITY_CHARS]
48
-
49
- self._load_model()
50
-
51
- encodings = self._perplex_tokenizer(text, return_tensors='pt')
52
- max_length = self._perplex_model.config.n_positions
53
- stride = 512
54
- seq_len = encodings.input_ids.size(1)
 
 
 
 
 
 
 
 
 
 
55
 
56
- nlls = []
57
- prev_end_loc = 0
58
- for begin_loc in range(0, seq_len, stride):
59
- end_loc = min(begin_loc + max_length, seq_len)
60
- trg_len = end_loc - prev_end_loc # may be different from stride on last loop
61
- input_ids = encodings.input_ids[:, begin_loc:end_loc]
62
- target_ids = input_ids.clone()
63
- target_ids[:, :-trg_len] = -100
64
 
65
- with torch.no_grad():
66
- outputs = self._perplex_model(input_ids, labels=target_ids)
67
- neg_log_likelihood = outputs.loss
 
68
 
69
- nlls.append(neg_log_likelihood)
70
- prev_end_loc = end_loc
71
- if end_loc == seq_len:
72
- break
73
 
74
- if not nlls:
75
- return 0.0
76
 
77
- ppl = torch.exp(torch.stack(nlls).mean())
78
- return float(ppl)
79
 
80
  def analyze_sentiment(self, text: str) -> dict:
81
  """
 
35
  MAX_PERPLEXITY_CHARS = 50000
36
 
37
  def calculate_perplexity(self, text: str) -> float:
38
+ if not text or len(text.strip()) < 10:
39
+ return 0.0
40
+
41
+ if len(text) > self.MAX_PERPLEXITY_CHARS:
42
+ text = text[:self.MAX_PERPLEXITY_CHARS]
43
+
44
+ self._load_model()
45
+
46
+ encodings = self._perplex_tokenizer(text, return_tensors='pt', truncation=True, max_length=self.MAX_PERPLEXITY_CHARS)
47
+
48
+ max_length = self._perplex_model.config.n_positions
49
+ stride = 512
50
+ seq_len = encodings.input_ids.size(1)
51
+
52
+ nlls = []
53
+ prev_end_loc = 0
54
+
55
+ for begin_loc in range(0, seq_len, stride):
56
+ end_loc = min(begin_loc + max_length, seq_len)
57
+ trg_len = end_loc - prev_end_loc
58
+
59
+ input_ids = encodings.input_ids[:, begin_loc:end_loc]
60
+ if input_ids.size(1) > max_length:
61
+ input_ids = input_ids[:, :max_length]
62
+
63
+ target_ids = input_ids.clone()
64
+ target_ids[:, :-trg_len] = -100
65
 
66
+ with torch.no_grad():
67
+ outputs = self._perplex_model(input_ids, labels=target_ids)
68
+ neg_log_likelihood = outputs.loss
 
 
 
 
 
69
 
70
+ nlls.append(neg_log_likelihood)
71
+ prev_end_loc = end_loc
72
+ if end_loc == seq_len:
73
+ break
74
 
75
+ if not nlls:
76
+ return 0.0
 
 
77
 
 
 
78
 
79
+ ppl = torch.exp(torch.stack(nlls).mean())
80
+ return round(float(ppl), 2)
81
 
82
  def analyze_sentiment(self, text: str) -> dict:
83
  """