dishitanagi committed on
Commit
03cf094
·
verified ·
1 Parent(s): 6621c09

Upload watermark_processor.py

Browse files
Files changed (1) hide show
  1. watermark_processor.py +280 -0
watermark_processor.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+ import collections
19
+ from math import sqrt
20
+
21
+ import scipy.stats
22
+
23
+ import torch
24
+ from torch import Tensor
25
+ from tokenizers import Tokenizer
26
+ from transformers import LogitsProcessor
27
+
28
+ from nltk.util import ngrams
29
+
30
+ from normalizers import normalization_strategy_lookup
31
+
32
class WatermarkBase:
    """Shared parameters and green-list derivation for watermarking.

    Implements the vocabulary-partitioning scheme of "A Watermark for Large
    Language Models" (https://arxiv.org/abs/2301.10226): a pseudo-random
    permutation of the vocabulary, seeded from the preceding token, is split
    into a "green" subset of size ``gamma * vocab_size``.
    """

    def __init__(
        self,
        vocab: list[int] = None,
        gamma: float = 0.5,
        delta: float = 2.0,
        seeding_scheme: str = "simple_1",  # mostly unused/always default
        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
        select_green_tokens: bool = True,
    ):
        # Watermarking parameters.
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.gamma = gamma
        self.delta = delta
        self.seeding_scheme = seeding_scheme
        # The generator is created lazily by subclasses so it can be
        # colocated with the model's device.
        self.rng = None
        self.hash_key = hash_key
        self.select_green_tokens = select_green_tokens

    def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
        """Seed ``self.rng`` from the token prefix per the seeding scheme.

        The scheme can optionally be overridden per call; the instance
        attribute is used by default.
        """
        scheme = self.seeding_scheme if seeding_scheme is None else seeding_scheme
        if scheme != "simple_1":
            raise NotImplementedError(f"Unexpected seeding_scheme: {scheme}")
        assert input_ids.shape[-1] >= 1, f"seeding_scheme={scheme} requires at least a 1 token prefix sequence to seed rng"
        # "simple_1": the seed depends only on the immediately preceding token.
        self.rng.manual_seed(self.hash_key * input_ids[-1].item())

    def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
        """Return the green-list token ids induced by the given prefix."""
        # Seed the rng from the prefix according to the seeding scheme.
        self._seed_rng(input_ids)

        greenlist_size = int(self.vocab_size * self.gamma)
        shuffled_vocab = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
        if self.select_green_tokens:
            # Take the green tokens directly from the front of the permutation.
            return shuffled_vocab[:greenlist_size]
        # Legacy behavior: green tokens are the complement of the red prefix.
        return shuffled_vocab[(self.vocab_size - greenlist_size):]
79
+
80
+
81
class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
    """LogitsProcessor that boosts green-list token logits by ``delta``.

    Plugs into ``transformers`` generation: at each decoding step, a
    per-sequence green list is derived from that sequence's prefix and the
    corresponding logits are biased upward before sampling.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
        """Build a boolean mask over ``scores`` marking green-list positions."""
        # TODO lets see if we can lose this loop
        mask = torch.zeros_like(scores)
        for row, ids in enumerate(greenlist_token_ids):
            mask[row][ids] = 1
        return mask.bool()

    def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
        """Add the watermark bias to the masked logits (mutates ``scores``)."""
        scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
        return scores

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the green-list bias to a batch of logits."""
        # Lazily create the generator so it is colocated on the watermarked
        # model's device.
        if self.rng is None:
            self.rng = torch.Generator(device=input_ids.device)

        # NOTE, it would be nice to get rid of this batch loop, but currently,
        # the seed and partition operations are not tensor/vectorized, thus
        # each sequence in the batch needs to be treated separately.
        batched_greenlist_ids = [
            self._get_greenlist_ids(input_ids[b_idx]) for b_idx in range(input_ids.shape[0])
        ]

        green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
        return self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
117
+
118
+
119
class WatermarkDetector(WatermarkBase):
    """Detects the soft watermark in text or a token sequence.

    Rebuilds the green list for every prefix (using the same parameters as
    generation), counts how many observed tokens land in their green list,
    and scores the count with a one-proportion z-test against the null
    hypothesis of unwatermarked text (green fraction == gamma).
    """

    def __init__(
        self,
        *args,
        device: torch.device = None,
        tokenizer: Tokenizer = None,
        z_threshold: float = 4.0,
        normalizers: list[str] = ("unicode",),  # or also: ["unicode", "homoglyphs", "truecase"]
        ignore_repeated_bigrams: bool = False,
        **kwargs,
    ):
        """
        Args:
            device: device for the rng; must match the generation-time device.
            tokenizer: the tokenizer used at generation time (required for
                detection on raw strings).
            z_threshold: default z-score above which text is flagged as
                watermarked.
            normalizers: names of text normalization strategies to apply to
                raw input text before tokenization.
            ignore_repeated_bigrams: if True, score each unique token bigram
                only once (requires the "simple_1" seeding scheme).
        """
        super().__init__(*args, **kwargs)
        # also configure the metrics returned/preprocessing options
        assert device, "Must pass device"
        assert tokenizer, "Need an instance of the generating tokenizer to perform detection"

        self.tokenizer = tokenizer
        self.device = device
        self.z_threshold = z_threshold
        self.rng = torch.Generator(device=self.device)

        if self.seeding_scheme == "simple_1":
            self.min_prefix_len = 1
        else:
            raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")

        # NOTE: default is a tuple (not a list) to avoid the shared
        # mutable-default-argument pitfall.
        self.normalizers = [
            normalization_strategy_lookup(normalization_strategy)
            for normalization_strategy in normalizers
        ]

        self.ignore_repeated_bigrams = ignore_repeated_bigrams
        if self.ignore_repeated_bigrams:
            assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."

    def _compute_z_score(self, observed_count, T):
        """One-proportion z-score: ``observed_count`` green tokens out of ``T`` total."""
        expected_count = self.gamma
        numer = observed_count - expected_count * T
        denom = sqrt(T * expected_count * (1 - expected_count))
        z = numer / denom
        return z

    def _compute_p_value(self, z):
        """One-sided p-value (normal survival function) for the z-score."""
        p_value = scipy.stats.norm.sf(z)
        return p_value

    def _score_sequence(
        self,
        input_ids: Tensor,
        return_num_tokens_scored: bool = True,
        return_num_green_tokens: bool = True,
        return_green_fraction: bool = True,
        return_green_token_mask: bool = False,
        return_z_score: bool = True,
        return_p_value: bool = True,
    ):
        """Score a token sequence, returning the requested detection metrics."""
        if self.ignore_repeated_bigrams:
            # Method that only counts a green/red hit once per unique bigram.
            # New num total tokens scored (T) becomes the number unique bigrams.
            # We iterate over all unique token bigrams in the input, computing the greenlist
            # induced by the first token in each, and then checking whether the second
            # token falls in that greenlist.
            assert not return_green_token_mask, "Can't return the green/red mask when ignoring repeats."
            bigram_table = {}
            token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
            freq = collections.Counter(token_bigram_generator)
            num_tokens_scored = len(freq.keys())
            for idx, bigram in enumerate(freq.keys()):
                prefix = torch.tensor([bigram[0]], device=self.device)  # expects a 1-d prefix tensor on the randperm device
                greenlist_ids = self._get_greenlist_ids(prefix)
                bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
            green_token_count = sum(bigram_table.values())
        else:
            num_tokens_scored = len(input_ids) - self.min_prefix_len
            if num_tokens_scored < 1:
                raise ValueError((f"Must have at least {1} token to score after "
                                  f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."))
            # Standard method.
            # Since we generally need at least 1 token (for the simplest scheme)
            # we start the iteration over the token sequence with a minimum
            # num tokens as the first prefix for the seeding scheme,
            # and at each step, compute the greenlist induced by the
            # current prefix and check if the current token falls in the greenlist.
            green_token_count, green_token_mask = 0, []
            for idx in range(self.min_prefix_len, len(input_ids)):
                curr_token = input_ids[idx]
                greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
                if curr_token in greenlist_ids:
                    green_token_count += 1
                    green_token_mask.append(True)
                else:
                    green_token_mask.append(False)

        score_dict = dict()
        if return_num_tokens_scored:
            score_dict.update(dict(num_tokens_scored=num_tokens_scored))
        if return_num_green_tokens:
            score_dict.update(dict(num_green_tokens=green_token_count))
        if return_green_fraction:
            score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
        if return_z_score:
            score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
        if return_p_value:
            z_score = score_dict.get("z_score")
            if z_score is None:
                z_score = self._compute_z_score(green_token_count, num_tokens_scored)
            score_dict.update(dict(p_value=self._compute_p_value(z_score)))
        if return_green_token_mask:
            score_dict.update(dict(green_token_mask=green_token_mask))

        return score_dict

    def detect(
        self,
        text: str = None,
        tokenized_text: list[int] = None,
        return_prediction: bool = True,
        return_scores: bool = True,
        z_threshold: float = None,
        **kwargs,
    ) -> dict:
        """Run watermark detection on raw text or a pre-tokenized sequence.

        Exactly one of ``text`` / ``tokenized_text`` must be provided.
        Returns a dict of scores and (optionally) the prediction/confidence.
        """
        assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
        if return_prediction:
            kwargs["return_p_value"] = True  # to return the "confidence":=1-p of positive detections

        # Run optional normalizers on raw text only; previously this ran
        # unconditionally and crashed on `None` when tokenized_text was passed.
        if text is not None:
            for normalizer in self.normalizers:
                text = normalizer(text)
            if len(self.normalizers) > 0:
                print(f"Text after normalization:\n\n{text}\n")

        if tokenized_text is None:
            # FIX: the assert message was a 3-tuple (always truthy-looking,
            # printed as a tuple) — joined into a single string.
            assert self.tokenizer is not None, (
                "Watermark detection on raw string "
                "requires an instance of the tokenizer "
                "that was used at generation time."
            )
            tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
            if tokenized_text[0] == self.tokenizer.bos_token_id:
                tokenized_text = tokenized_text[1:]
        else:
            # try to remove the bos_tok at beginning if it's there
            if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
                tokenized_text = tokenized_text[1:]

        # call score method
        output_dict = {}
        score_dict = self._score_sequence(tokenized_text, **kwargs)
        if return_scores:
            output_dict.update(score_dict)
        # if passed return_prediction then perform the hypothesis test and return the outcome
        if return_prediction:
            # FIX: `is None` check — an explicit z_threshold of 0.0 was
            # previously silently replaced by the instance default.
            z_threshold = self.z_threshold if z_threshold is None else z_threshold
            assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
            output_dict["prediction"] = score_dict["z_score"] > z_threshold
            if output_dict["prediction"]:
                output_dict["confidence"] = 1 - score_dict["p_value"]

        return output_dict
280
+