JosephH committed
Commit a699258 · 1 Parent(s): 943f417

Update model.py

Files changed (1):
  1. model.py +297 -25
model.py CHANGED
@@ -1,4 +1,8 @@
 """
 This code a slight modification of perplexity by hugging face
 https://huggingface.co/docs/transformers/perplexity

@@ -6,42 +10,299 @@ Both this code and the orignal code are published under the MIT license.

 by Burhan Ul tayyab and Nicholas Chua
 """
-
 import torch
 import re
 from transformers import GPT2LMHeadModel, GPT2TokenizerFast
 from collections import OrderedDict


-class GPT2PPL:
-    def __init__(self, device="cpu", model_id="gpt2"):
         self.device = device
         self.model_id = model_id
         self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
         self.tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

         self.max_length = self.model.config.n_positions
-        self.stride = 512
-
-    def getResults(self, threshold):
-        if threshold < 60:
-            label = 0
-            return "The Text is generated by AI.", label
-        elif threshold < 80:
-            label = 0
-            return "The Text is most probably contain parts which are generated by AI. (require more text for better Judgement)", label
         else:
-            label = 1
-            return "The Text is written by Human.", label

-    def __call__(self, sentence):
         """
-        Takes in a sentence split by full stop
-        and print the perplexity of the total sentence

         split the lines based on full stop and find the perplexity of each sentence and print
         average perplexity
-
         Burstiness is the max perplexity of each sentence
         """
         results = OrderedDict()

@@ -49,13 +310,13 @@ class GPT2PPL:
         total_valid_char = re.findall("[a-zA-Z0-9]+", sentence)
         total_valid_char = sum([len(x) for x in total_valid_char]) # finds len of all the valid characters a sentence

-        if total_valid_char < 100:
-            return {"status": "Please input more text (min 100 characters)"}, "Please input more text (min 100 characters)"
-
         lines = re.split(r'(?<=[.?!][ \[\(])|(?<=\n)\s*',sentence)
         lines = list(filter(lambda x: (x is not None) and (len(x) > 0), lines))

-        ppl = self.getPPL(sentence)
         print(f"Perplexity {ppl}")
         results["Perplexity"] = ppl

@@ -75,7 +336,7 @@ class GPT2PPL:
             elif line[-1] == "[" or line[-1] == "(":
                 offset = line[-1]
                 line = line[:-1]
-            ppl = self.getPPL(line)
             Perplexity_per_line.append(ppl)
         print(f"Perplexity per line {sum(Perplexity_per_line)/len(Perplexity_per_line)}")
         results["Perplexity per line"] = sum(Perplexity_per_line)/len(Perplexity_per_line)

@@ -83,12 +344,12 @@ class GPT2PPL:
         print(f"Burstiness {max(Perplexity_per_line)}")
         results["Burstiness"] = max(Perplexity_per_line)

-        out, label = self.getResults(results["Perplexity per line"])
         results["label"] = label

         return results, out

-    def getPPL(self,sentence):
         encodings = self.tokenizer(sentence, return_tensors="pt")
         seq_len = encodings.input_ids.size(1)

@@ -114,3 +375,14 @@ class GPT2PPL:
                 break
         ppl = int(torch.exp(torch.stack(nlls).sum() / end_loc))
         return ppl
model.py (new version introduced by this commit):

#!/usr/bin/env python3

"""
T5

This code is a slight modification of the perplexity example by Hugging Face
https://huggingface.co/docs/transformers/perplexity

Both this code and the original code are published under the MIT license.

by Burhan Ul tayyab and Nicholas Chua
"""
import time
import torch
import itertools
import math
import numpy as np
import random
import re
import transformers
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from transformers import pipeline
from transformers import T5Tokenizer
from transformers import AutoTokenizer, BartForConditionalGeneration

from collections import OrderedDict

from scipy.stats import norm
from difflib import SequenceMatcher
from multiprocessing.pool import ThreadPool

def similar(a, b):
    return SequenceMatcher(None, a, b).ratio()

def normCdf(x):
    return norm.cdf(x)

def likelihoodRatio(x, y):
    return normCdf(x)/normCdf(y)

torch.manual_seed(0)
np.random.seed(0)

# TODO: find a better way to abstract the class
class GPT2PPLV2:
    def __init__(self, device="cpu", model_id="gpt2-medium"):
        self.device = device
        self.model_id = model_id
        self.model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

        self.max_length = self.model.config.n_positions
        self.stride = 51
        self.threshold = 0.7

        self.t5_model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-large").to(device).half()
        self.t5_tokenizer = T5Tokenizer.from_pretrained("t5-large", model_max_length=512)

    def apply_extracted_fills(self, masked_texts, extracted_fills):
        texts = []
        for idx, (text, fills) in enumerate(zip(masked_texts, extracted_fills)):
            tokens = list(re.finditer(r"<extra_id_\d+>", text))
            if len(fills) < len(tokens):
                continue

            offset = 0
            for fill_idx in range(len(tokens)):
                start, end = tokens[fill_idx].span()
                text = text[:start+offset] + fills[fill_idx] + text[end+offset:]
                offset = offset - (end - start) + len(fills[fill_idx])
            texts.append(text)

        return texts
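
    # Fill the <extra_id_N> sentinels in a batch of masked texts with T5:
    # generation stops at the sentinel that follows the last mask, the decoded
    # output is split on the sentinels to recover one fill per mask, and
    # apply_extracted_fills() splices the fills back into the masked texts.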
    def unmasker(self, text, num_of_masks):
        num_of_masks = max(num_of_masks)
        stop_id = self.t5_tokenizer.encode(f"<extra_id_{num_of_masks}>")[0]
        tokens = self.t5_tokenizer(text, return_tensors="pt", padding=True)
        for key in tokens:
            tokens[key] = tokens[key].to(self.device)

        output_sequences = self.t5_model.generate(**tokens, max_length=512, do_sample=True, top_p=0.96, num_return_sequences=1, eos_token_id=stop_id)
        results = self.t5_tokenizer.batch_decode(output_sequences, skip_special_tokens=False)

        texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in results]
        pattern = re.compile(r"<extra_id_\d+>")
        extracted_fills = [pattern.split(x)[1:-1] for x in texts]
        extracted_fills = [[y.strip() for y in x] for x in extracted_fills]

        perturbed_texts = self.apply_extracted_fills(text, extracted_fills)

        return perturbed_texts

    def __call__(self, *args):
        version = args[-1]
        sentence = args[0]
        if version == "v1.1":
            return self.call_1_1(sentence, args[1])
        elif version == "v1":
            return self.call_1(sentence)
        else:
            return "Model version not defined"

    ###############################################
    # Version 1.1 apis
    ###############################################

    def replaceMask(self, text, num_of_masks):
        with torch.no_grad():
            list_generated_texts = self.unmasker(text, num_of_masks)

        return list_generated_texts

    def isSame(self, text1, text2):
        return text1 == text2

    # code adapted from https://github.com/eric-mitchell/detect-gpt
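    # Replace n_spans random two-word spans with a placeholder, then rewrite each
    # placeholder as a T5 sentinel (<extra_id_0>, <extra_id_1>, ...) in order of appearance.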
    def maskRandomWord(self, text, ratio):
        span = 2
        tokens = text.split(' ')
        mask_string = '<<<mask>>>'

        n_spans = ratio//(span + 2)

        n_masks = 0
        while n_masks < n_spans:
            start = np.random.randint(0, len(tokens) - span)
            end = start + span
            search_start = max(0, start - 1)
            search_end = min(len(tokens), end + 1)
            if mask_string not in tokens[search_start:search_end]:
                tokens[start:end] = [mask_string]
                n_masks += 1

        # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
        num_filled = 0
        for idx, token in enumerate(tokens):
            if token == mask_string:
                tokens[idx] = f'<extra_id_{num_filled}>'
                num_filled += 1
        assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
        text = ' '.join(tokens)
        return text, n_masks

    def multiMaskRandomWord(self, text, ratio, n):
        mask_texts = []
        list_num_of_masks = []
        for i in range(n):
            mask_text, num_of_masks = self.maskRandomWord(text, ratio)
            mask_texts.append(mask_text)
            list_num_of_masks.append(num_of_masks)
        return mask_texts, list_num_of_masks

    def getGeneratedTexts(self, args):
        original_text = args[0]
        n = args[1]
        texts = list(re.finditer(r"[^\d\W]+", original_text))
        ratio = int(0.3 * len(texts))

        mask_texts, list_num_of_masks = self.multiMaskRandomWord(original_text, ratio, n)
        list_generated_sentences = self.replaceMask(mask_texts, list_num_of_masks)
        return list_generated_sentences

    def mask(self, original_text, text, n=2, remaining=100):
        """
        text: string representing the sentence
        n: top n mask-fillings to be chosen
        remaining: the remaining slots to be filled
        """

        if remaining <= 0:
            return []

        torch.manual_seed(0)
        np.random.seed(0)
        start_time = time.time()
        out_sentences = []
        pool = ThreadPool(remaining//n)
        out_sentences = pool.map(self.getGeneratedTexts, [(original_text, n) for _ in range(remaining//n)])
        out_sentences = list(itertools.chain.from_iterable(out_sentences))
        end_time = time.time()

        return out_sentences

    def getVerdict(self, score):
        if score < self.threshold:
            return "This text is most likely written by a Human"
        else:
            return "This text is most likely generated by an A.I."
    def getScore(self, sentence):
        original_sentence = sentence
        sentence_length = len(list(re.finditer(r"[^\d\W]+", sentence)))
        # remaining = int(min(max(100, sentence_length * 1/9), 200))
        remaining = 50
        sentences = self.mask(original_sentence, original_sentence, n=50, remaining=remaining)

        real_log_likelihood = self.getLogLikelihood(original_sentence)

        generated_log_likelihoods = []
        for sentence in sentences:
            generated_log_likelihoods.append(self.getLogLikelihood(sentence).cpu().detach().numpy())

        if len(generated_log_likelihoods) == 0:
            # no perturbations were generated; return a sentinel tuple so callers can skip this chunk
            return -1, 0.0, 0.0

        generated_log_likelihoods = np.asarray(generated_log_likelihoods)
        mean_generated_log_likelihood = np.mean(generated_log_likelihoods)
        std_generated_log_likelihood = np.std(generated_log_likelihoods)

        diff = real_log_likelihood - mean_generated_log_likelihood

        score = diff/(std_generated_log_likelihood)

        return float(score), float(diff), float(std_generated_log_likelihood)
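
    # v1.1 entry point: strip wiki-style [n] references, split the input into
    # chunks of roughly chunk_value words, score each chunk with getScore(), and
    # map the mean score through a normal CDF to report a probability and label
    # (label 0 = A.I., label 1 = Human).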
    def call_1_1(self, sentence, chunk_value):
        sentence = re.sub(r"\[[0-9]+\]", "", sentence)  # remove all the [numbers] left by wiki references

        words = re.split("[ \n]", sentence)

        # if len(words) < 100:
        #     return {"status": "Please input more text (min 100 words)"}, "Please input more text (min 100 characters)", None

        groups = len(words) // chunk_value + 1
        lines = []
        stride = len(words) // groups + 1
        for i in range(0, len(words), stride):
            start_pos = i
            end_pos = min(i+stride, len(words))

            selected_text = " ".join(words[start_pos:end_pos])
            selected_text = selected_text.strip()
            if selected_text == "":
                continue

            lines.append(selected_text)

        # sentence by sentence
        offset = ""
        scores = []
        probs = []
        final_lines = []
        labels = []
        for line in lines:
            if re.search("[a-zA-Z0-9]+", line) is None:
                continue
            score, diff, sd = self.getScore(line)
            if score == -1 or math.isnan(score):
                continue
            scores.append(score)

            final_lines.append(line)
            if score > self.threshold:
                labels.append(1)
                prob = "{:.2f}%\n(A.I.)".format(normCdf(abs(self.threshold - score)) * 100)
                probs.append(prob)
            else:
                labels.append(0)
                prob = "{:.2f}%\n(Human)".format(normCdf(abs(self.threshold - score)) * 100)
                probs.append(prob)

        mean_score = sum(scores)/len(scores)

        mean_prob = normCdf(abs(self.threshold - mean_score)) * 100
        label = 0 if mean_score > self.threshold else 1
        print(f"probability for {'A.I.' if label == 0 else 'Human'}:", "{:.2f}%".format(mean_prob))
        return {"prob": "{:.2f}%".format(mean_prob), "label": label}, self.getVerdict(mean_score)
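
    # Average token log-likelihood under GPT-2, computed with the same sliding-window
    # scheme as the Hugging Face perplexity example; returns -(total NLL) / end_loc,
    # so higher values mean the text is more probable under the model.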
    def getLogLikelihood(self, sentence):
        encodings = self.tokenizer(sentence, return_tensors="pt")
        seq_len = encodings.input_ids.size(1)

        nlls = []
        prev_end_loc = 0
        for begin_loc in range(0, seq_len, self.stride):
            end_loc = min(begin_loc + self.max_length, seq_len)
            trg_len = end_loc - prev_end_loc
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(self.device)
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = self.model(input_ids, labels=target_ids)

                neg_log_likelihood = outputs.loss * trg_len

            nlls.append(neg_log_likelihood)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break
        return -1 * torch.stack(nlls).sum() / end_loc

    ###############################################
    # Version 1 apis
    ###############################################

    def call_1(self, sentence):
        """
        Takes in a sentence split by full stop
        and prints the perplexity of the whole text;
        splits the lines on full stops, finds the perplexity of each sentence and prints the
        average perplexity

        Burstiness is the max perplexity of each sentence
        """
        results = OrderedDict()

        total_valid_char = re.findall("[a-zA-Z0-9]+", sentence)
        total_valid_char = sum([len(x) for x in total_valid_char])  # length of all the valid characters in the sentence

        # if total_valid_char < 100:
        #     return {"status": "Please input more text (min 100 characters)"}, "Please input more text (min 100 characters)"

        lines = re.split(r'(?<=[.?!][ \[\(])|(?<=\n)\s*', sentence)
        lines = list(filter(lambda x: (x is not None) and (len(x) > 0), lines))

        ppl = self.getPPL_1(sentence)
        print(f"Perplexity {ppl}")
        results["Perplexity"] = ppl

        # ... unchanged lines omitted in this diff ...

            elif line[-1] == "[" or line[-1] == "(":
                offset = line[-1]
                line = line[:-1]
            ppl = self.getPPL_1(line)
            Perplexity_per_line.append(ppl)
        print(f"Perplexity per line {sum(Perplexity_per_line)/len(Perplexity_per_line)}")
        results["Perplexity per line"] = sum(Perplexity_per_line)/len(Perplexity_per_line)

        print(f"Burstiness {max(Perplexity_per_line)}")
        results["Burstiness"] = max(Perplexity_per_line)

        out, label = self.getResults_1(results["Perplexity per line"])
        results["label"] = label

        return results, out

    def getPPL_1(self, sentence):
        encodings = self.tokenizer(sentence, return_tensors="pt")
        seq_len = encodings.input_ids.size(1)

        # ... unchanged lines omitted in this diff ...

                break
        ppl = int(torch.exp(torch.stack(nlls).sum() / end_loc))
        return ppl

    def getResults_1(self, threshold):
        if threshold < 60:
            label = 0
            return "The Text is generated by AI.", label
        elif threshold < 80:
            label = 0
            return "The Text most probably contains parts which are generated by AI. (requires more text for a better judgement)", label
        else:
            label = 1
            return "The Text is written by a Human.", label