dishitanagi committed
Commit cfcc8ce · verified · 1 Parent(s): fc81394

Upload demo_watermark.py

Files changed (1)
  1. demo_watermark.py +975 -0
demo_watermark.py ADDED
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from pprint import pprint
from functools import partial

import gc

import numpy  # for gradio hot reload
import gradio as gr

import torch

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

from transformers import (AutoTokenizer,
                          AutoModelForSeq2SeqLM,
                          AutoModelForCausalLM,
                          LogitsProcessorList)

# from local_tokenizers.tokenization_llama import LLaMATokenizer

from transformers import GPT2TokenizerFast
OPT_TOKENIZER = GPT2TokenizerFast

from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector


# ALPACA_MODEL_NAME = "alpaca"
# ALPACA_MODEL_TOKENIZER = LLaMATokenizer
# ALPACA_TOKENIZER_PATH = "/cmlscratch/jkirchen/llama"

# FIXME correct lengths for all models
API_MODEL_MAP = {
    "google/flan-ul2": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
    "google/flan-t5-xxl": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
    "EleutherAI/gpt-neox-20b": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
    # "bigscience/bloom": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
    # "bigscience/bloomz": {"max_length": 1000, "gamma": 0.5, "delta": 2.0},
}

def str2bool(v):
    """Util function for user-friendly boolean flag args"""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_args():
    """Command line argument specification"""

    parser = argparse.ArgumentParser(description="A minimal working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")

    parser.add_argument(
        "--run_gradio",
        type=str2bool,
        default=True,
        help="Whether to launch as a gradio demo. Set to False if gradio is not installed and you just want to run the stdout version.",
    )
    parser.add_argument(
        "--demo_public",
        type=str2bool,
        default=False,
        help="Whether to expose the gradio demo to the internet.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="facebook/opt-6.7b",
        help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--prompt_max_length",
        type=int,
        default=None,
        help="Truncation length for the prompt, overrides the model config's max length field.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=200,
        help="Maximum number of new tokens to generate.",
    )
    parser.add_argument(
        "--generation_seed",
        type=int,
        default=123,
        help="Seed for setting the torch global rng prior to generation.",
    )
    parser.add_argument(
        "--use_sampling",
        type=str2bool,
        default=True,
        help="Whether to generate using multinomial sampling.",
    )
    parser.add_argument(
        "--sampling_temp",
        type=float,
        default=0.7,
        help="Sampling temperature to use when generating using multinomial sampling.",
    )
    parser.add_argument(
        "--n_beams",
        type=int,
        default=1,
        help="Number of beams to use for beam search. 1 is normal greedy decoding.",
    )
    parser.add_argument(
        "--use_gpu",
        type=str2bool,
        default=True,
        help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
    )
    parser.add_argument(
        "--seeding_scheme",
        type=str,
        default="simple_1",
        help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.25,
        help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=2.0,
        help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
    )
    parser.add_argument(
        "--normalizers",
        type=str,
        default="",
        help="Single or comma separated list of the preprocessor/normalizer names to use when performing watermark detection.",
    )
    parser.add_argument(
        "--ignore_repeated_bigrams",
        type=str2bool,
        default=False,
        help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.",
    )
    parser.add_argument(
        "--detection_z_threshold",
        type=float,
        default=4.0,
        help="The test statistic threshold for the detection hypothesis test.",
    )
    parser.add_argument(
        "--select_green_tokens",
        type=str2bool,
        default=True,
        help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy is (False) to pick the complement/reds first.",
    )
    parser.add_argument(
        "--skip_model_load",
        type=str2bool,
        default=False,
        help="Skip the model loading to debug the interface.",
    )
    parser.add_argument(
        "--seed_separately",
        type=str2bool,
        default=True,
        help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
    )
    parser.add_argument(
        "--load_fp16",
        type=str2bool,
        default=False,
        help="Whether to run the model in float16 precision.",
    )
    parser.add_argument(
        "--load_bf16",
        type=str2bool,
        default=False,
        help="Whether to run the model in bfloat16 precision.",
    )
    args = parser.parse_args()
    return args

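# Example invocation (illustrative only; this comment is not part of the uploaded file, and the
# flag values below are just one plausible configuration built from the arguments defined above):
#
#   python demo_watermark.py --model_name_or_path facebook/opt-6.7b --max_new_tokens 200 \
#       --use_sampling True --sampling_temp 0.7 --gamma 0.25 --delta 2.0 \
#       --detection_z_threshold 4.0 --run_gradio True
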
def load_model(args):
    """Load and return the model and tokenizer"""

    args.is_seq2seq_model = any([(model_type in args.model_name_or_path.lower()) for model_type in ["t5", "t0"]])
    args.is_decoder_only_model = any([(model_type in args.model_name_or_path.lower()) for model_type in ["gpt", "opt", "bloom", "llama", "qwen"]])
    if args.is_seq2seq_model:
        model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
    elif args.is_decoder_only_model:
        if args.load_fp16:
            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16, device_map='auto')
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16)
        elif args.load_bf16:
            # model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.bfloat16, device_map='auto')
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch.bfloat16)
        else:
            model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(f"Unknown model type: {args.model_name_or_path}")

    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # if args.load_fp16 or args.load_bf16:
        #     pass
        # else:
        model = model.to(device)
    else:
        device = "cpu"

    if args.load_bf16:
        model = model.to(torch.bfloat16)
    if args.load_fp16:
        model = model.to(torch.float16)

    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    return model, tokenizer, device


from text_generation import InferenceAPIClient
from requests.exceptions import ReadTimeout

def generate_with_api(prompt, args):
    hf_api_key = os.environ.get("HF_API_KEY")
    if hf_api_key is None:
        raise ValueError("HF_API_KEY environment variable not set, cannot use HF API to generate text.")

    client = InferenceAPIClient(args.model_name_or_path, token=hf_api_key, timeout=60)

    assert args.n_beams == 1, "HF API models do not support beam search."
    generation_params = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.use_sampling,
    }
    if args.use_sampling:
        generation_params["temperature"] = args.sampling_temp
        generation_params["seed"] = args.generation_seed

    timeout_msg = "[Model API timeout error. Try reducing the max_new_tokens parameter or the prompt length.]"
    try:
        generation_params["watermark"] = False
        without_watermark_iterator = client.generate_stream(prompt, **generation_params)
    except ReadTimeout as e:
        print(e)
        without_watermark_iterator = (char for char in timeout_msg)
    try:
        generation_params["watermark"] = True
        with_watermark_iterator = client.generate_stream(prompt, **generation_params)
    except ReadTimeout as e:
        print(e)
        with_watermark_iterator = (char for char in timeout_msg)

    all_without_words, all_with_words = "", ""
    for without_word, with_word in zip(without_watermark_iterator, with_watermark_iterator):
        all_without_words += without_word.token.text
        all_with_words += with_word.token.text
        yield all_without_words, all_with_words


def check_prompt(prompt, args, tokenizer, model=None, device=None):

    # This applies to both the local and API model scenarios
    if args.model_name_or_path in API_MODEL_MAP:
        args.prompt_max_length = API_MODEL_MAP[args.model_name_or_path]["max_length"]
    elif hasattr(model.config, "max_position_embeddings"):
        args.prompt_max_length = model.config.max_position_embeddings - args.max_new_tokens
    else:
        args.prompt_max_length = 2048 - args.max_new_tokens

    tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
    truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
    redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]

    return (redecoded_input,
            int(truncation_warning),
            args)


def generate(prompt, args, tokenizer, model=None, device=None):
    """Instantiate the WatermarkLogitsProcessor according to the watermark parameters
    and generate watermarked text by passing it to the generate method of the model
    as a logits processor."""

    print(f"Generating with {args}")
    print(f"Prompt: {prompt}")

    if args.model_name_or_path in API_MODEL_MAP:
        api_outputs = generate_with_api(prompt, args)
        yield from api_outputs
    else:
        tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)

        watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                                       gamma=args.gamma,
                                                       delta=args.delta,
                                                       seeding_scheme=args.seeding_scheme,
                                                       select_green_tokens=args.select_green_tokens)

        gen_kwargs = dict(max_new_tokens=args.max_new_tokens)

        if args.use_sampling:
            gen_kwargs.update(dict(
                do_sample=True,
                top_k=0,
                temperature=args.sampling_temp
            ))
        else:
            gen_kwargs.update(dict(
                num_beams=args.n_beams
            ))

        generate_without_watermark = partial(
            model.generate,
            **gen_kwargs
        )
        generate_with_watermark = partial(
            model.generate,
            logits_processor=LogitsProcessorList([watermark_processor]),
            **gen_kwargs
        )

        torch.manual_seed(args.generation_seed)
        output_without_watermark = generate_without_watermark(**tokd_input)

        # optional to seed before second generation, but will not be the same again generally, unless delta==0.0, no-op watermark
        if args.seed_separately:
            torch.manual_seed(args.generation_seed)
        output_with_watermark = generate_with_watermark(**tokd_input)

        if args.is_decoder_only_model:
            # need to isolate the newly generated tokens
            output_without_watermark = output_without_watermark[:, tokd_input["input_ids"].shape[-1]:]
            output_with_watermark = output_with_watermark[:, tokd_input["input_ids"].shape[-1]:]

        decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
        decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]

        # mocking the API outputs in a whitespace split generator style
        all_without_words, all_with_words = "", ""
        for without_word, with_word in zip(decoded_output_without_watermark.split(), decoded_output_with_watermark.split()):
            all_without_words += without_word + " "
            all_with_words += with_word + " "
            yield all_without_words, all_with_words


def format_names(s):
    """Format names for the gradio demo interface"""
    s = s.replace("num_tokens_scored", "Tokens Counted (T)")
    s = s.replace("num_green_tokens", "# Tokens in Greenlist")
    s = s.replace("green_fraction", "Fraction of T in Greenlist")
    s = s.replace("z_score", "z-score")
    s = s.replace("p_value", "p value")
    s = s.replace("prediction", "Prediction")
    s = s.replace("confidence", "Confidence")
    return s

def list_format_scores(score_dict, detection_threshold):
    """Format the detection metrics into a gradio dataframe input format"""
    lst_2d = []
    for k, v in score_dict.items():
        if k == 'green_fraction':
            lst_2d.append([format_names(k), f"{v:.1%}"])
        elif k == 'confidence':
            lst_2d.append([format_names(k), f"{v:.3%}"])
        elif isinstance(v, float):
            lst_2d.append([format_names(k), f"{v:.3g}"])
        elif isinstance(v, bool):
            lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
        else:
            lst_2d.append([format_names(k), f"{v}"])
    if "confidence" in score_dict:
        lst_2d.insert(-2, ["z-score Threshold", f"{detection_threshold}"])
    else:
        lst_2d.insert(-1, ["z-score Threshold", f"{detection_threshold}"])
    return lst_2d

def detect(input_text, args, tokenizer, device=None, return_green_token_mask=True):
    """Instantiate the WatermarkDetector object and call detect on
    the input text, returning the scores and outcome of the test"""

    print(f"Detecting with {args}")
    print(f"Detection Tokenizer: {type(tokenizer)}")

    watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
                                           gamma=args.gamma,
                                           seeding_scheme=args.seeding_scheme,
                                           device=device,
                                           tokenizer=tokenizer,
                                           z_threshold=args.detection_z_threshold,
                                           normalizers=args.normalizers,
                                           ignore_repeated_bigrams=args.ignore_repeated_bigrams,
                                           select_green_tokens=args.select_green_tokens)
    # for now, just don't display the green token mask
    # if we're using normalizers or ignore_repeated_bigrams
    if args.normalizers != [] or args.ignore_repeated_bigrams:
        return_green_token_mask = False

    error = False
    green_token_mask = None
    if input_text == "":
        error = True
    else:
        try:
            score_dict = watermark_detector.detect(input_text, return_green_token_mask=return_green_token_mask)
            green_token_mask = score_dict.pop("green_token_mask", None)
            output = list_format_scores(score_dict, watermark_detector.z_threshold)
        except ValueError as e:
            print(e)
            error = True
    if error:
        output = [["Error", "string too short to compute metrics"]]
        output += [["", ""] for _ in range(6)]

    html_output = "[No highlight markup generated]"

    if green_token_mask is None:
        html_output = "[Visualizing masks with ignore_repeated_bigrams enabled is not supported, toggle off to see the mask for this text. The mask is the same in both cases - only counting/stats are affected.]"

    if green_token_mask is not None:
        # hack bc we need a fast tokenizer with charspan support
        if "opt" in args.model_name_or_path:
            tokenizer = OPT_TOKENIZER.from_pretrained(args.model_name_or_path)

        tokens = tokenizer(input_text)
        if tokens["input_ids"][0] == tokenizer.bos_token_id:
            tokens["input_ids"] = tokens["input_ids"][1:]  # ignore attention mask
        skip = watermark_detector.min_prefix_len
        charspans = [tokens.token_to_chars(i) for i in range(skip, len(tokens["input_ids"]))]
        charspans = [cs for cs in charspans if cs is not None]  # remove the special token spans

        assert len(charspans) == len(green_token_mask)

        tags = [(f'<span class="green">{input_text[cs.start:cs.end]}</span>' if m else f'<span class="red">{input_text[cs.start:cs.end]}</span>') for cs, m in zip(charspans, green_token_mask)]
        html_output = f'<p>{" ".join(tags)}</p>'

    return output, args, tokenizer, html_output

def run_gradio(args, model=None, device=None, tokenizer=None):
    """Define and launch the gradio demo interface"""

    css = """
    .green { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ccffcc; border-radius:0.5rem;}
    .red { color: black!important;line-height:1.9em; padding: 0.2em 0.2em; background: #ffad99; border-radius:0.5rem;}
    """

    with gr.Blocks(css=css) as demo:
        # Top section, greeting and instructions
        with gr.Row():
            with gr.Column(scale=9):
                gr.Markdown(
                    """
                    ## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍
                    """
                )
            with gr.Column(scale=1):
                gr.Markdown(
                    """
                    [![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)
                    """
                )
        # if the model_name_or_path at startup is not one of the API models, add it to the dropdown;
        # fall back to that behavior when no explicit `all_models` list was attached to args
        # all_models = [args.model_name_or_path]
        all_models = getattr(args, "all_models", sorted(set(list(API_MODEL_MAP.keys()) + [args.model_name_or_path])))
        model_selector = gr.Dropdown(
            all_models,
            value=args.model_name_or_path,
            label="Language Model",
        )

        # Construct state for parameters, define updates and toggles
        default_prompt = args.__dict__.pop("default_prompt")
        session_args = gr.State(value=args)
        # note that the state obj automatically calls value if it's a callable, want to avoid calling the tokenizer at startup
        session_tokenizer = gr.State(value=lambda: tokenizer)

        check_prompt_partial = partial(check_prompt, model=model, device=device)
        generate_partial = partial(generate, model=model, device=device)
        detect_partial = partial(detect, device=device)

        with gr.Tab("Welcome"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown(
                        """
                        Potential harms of large language models can be mitigated by *watermarking* a model's output.
                        *Watermarks* are signals embedded in the generated text that are invisible to humans but algorithmically
                        detectable, allowing *anyone* to later check whether a given span of text
                        was likely to have been generated by a model that uses the watermark.

                        This space showcases a watermarking approach that can be applied to _any_ generative language model.
                        For demonstration purposes, the space demos a relatively small open-source language model,
                        which is less powerful than proprietary commercial tools like ChatGPT, Claude, or Gemini.
                        Generally, prompts that entail a short, low-entropy response, such as the few-word answer to a factual trivia question,
                        will not exhibit a strong watermark presence, while longer watermarked outputs will produce higher detection statistics.
                        """
                    )
                    gr.Markdown(
                        """
                        **[Generate & Detect]**: The first tab shows that the watermark can be embedded with
                        negligible impact on text quality. You can try any prompt and compare the quality of
                        normal text (*Output Without Watermark*) to the watermarked text (*Output With Watermark*) below it.
                        You can also "see" the watermark by looking at the **Highlighted** tab where the tokens are
                        colored green or red depending on which list they are in.
                        Metrics on the right show that the watermark can be reliably detected given a reasonably small number of tokens (25-50).
                        Detection is very efficient and does not use the language model or its parameters.

                        **[Detector Only]**: You can also copy-paste the watermarked text (or any other text)
                        into the second tab. This can be used to see how many sentences you could remove and still detect the watermark.
                        You can also verify here that the detection has, by design, a low false-positive rate;
                        this means that human-generated text that you copy into this detector will not be marked as machine-generated.

                        You can find more details about how this watermark functions in our paper ["A Watermark for Large Language Models"](https://arxiv.org/abs/2301.10226), presented at ICML 2023.
                        Additionally, read about our study on the reliability of this watermarking style in ["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634), presented at ICLR 2024.
                        """
                    )

                with gr.Column(scale=1):
                    gr.Markdown(
                        """
                        ![](https://drive.google.com/uc?export=view&id=1yVLPcjm-xvaCjQyc3FGLsWIU84v1QRoC)
                        """
                    )

        with gr.Tab("Generate & Detect"):

            with gr.Row():
                prompt = gr.Textbox(label=f"Prompt", interactive=True, lines=10, max_lines=10, value=default_prompt)
            with gr.Row():
                generate_btn = gr.Button("Generate")
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("Output Without Watermark (Raw Text)"):
                        output_without_watermark = gr.Textbox(interactive=False, lines=14, max_lines=14)
                    with gr.Tab("Highlighted"):
                        html_without_watermark = gr.HTML(elem_id="html-without-watermark")
                with gr.Column(scale=1):
                    without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("Output With Watermark (Raw Text)"):
                        output_with_watermark = gr.Textbox(interactive=False, lines=14, max_lines=14)
                    with gr.Tab("Highlighted"):
                        html_with_watermark = gr.HTML(elem_id="html-with-watermark")
                with gr.Column(scale=1):
                    with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)

            redecoded_input = gr.Textbox(visible=False)
            truncation_warning = gr.Number(visible=False)

            def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
                if truncation_warning:
                    return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
                else:
                    return orig_prompt, args

        with gr.Tab("Detector Only"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("Text to Analyze"):
                        detection_input = gr.Textbox(interactive=True, lines=14, max_lines=14)
                    with gr.Tab("Highlighted"):
                        html_detection_input = gr.HTML(elem_id="html-detection-input")
                with gr.Column(scale=1):
                    detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False, row_count=7, col_count=2)
            with gr.Row():
                detect_btn = gr.Button("Detect")

        # Parameter selection group
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(f"#### Generation Parameters")
                    with gr.Row():
                        decoding = gr.Radio(label="Decoding Method", choices=["multinomial", "greedy"], value=("multinomial" if args.use_sampling else "greedy"))
                    with gr.Row():
                        sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1, value=args.sampling_temp, visible=True)
                    with gr.Row():
                        generation_seed = gr.Number(label="Generation Seed", value=args.generation_seed, interactive=True)
                    with gr.Row():
                        n_beams = gr.Dropdown(label="Number of Beams", choices=list(range(1, 11, 1)), value=args.n_beams, visible=((not args.use_sampling) and (not args.model_name_or_path in API_MODEL_MAP)))
                    with gr.Row():
                        max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)

                with gr.Column(scale=1):
                    gr.Markdown(f"#### Watermark Parameters")
                    with gr.Row():
                        gamma = gr.Slider(label="gamma", minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
                    with gr.Row():
                        delta = gr.Slider(label="delta", minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
                    gr.Markdown(f"#### Detector Parameters")
                    with gr.Row():
                        detection_z_threshold = gr.Slider(label="z-score threshold", minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold)
                    with gr.Row():
                        ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
                    with gr.Row():
                        normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers)
            with gr.Row():
                gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._")
            with gr.Row():
                current_parameters = gr.Textbox(label="Current Parameters", value=args)
        with gr.Accordion("Legacy Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately)
                with gr.Column(scale=1):
                    select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens)


        with gr.Accordion("What do the settings do?", open=False):
            gr.Markdown(
                """
                #### Generation Parameters:

                - **Decoding Method** : We can generate tokens from the model using either multinomial sampling or greedy decoding.
                - **Sampling Temperature** : If using multinomial sampling we can set the temperature of the sampling distribution.
                0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution.
                0.7 strikes a nice balance between faithfulness to the model's estimate of top candidates and adding variety. Does not apply to greedy decoding.
                - **Generation Seed** : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy
                outputs reproducible. Does not apply to greedy decoding.
                - **Number of Beams** : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search.
                This is not implemented/excluded from the paper for multinomial sampling but may be added in the future.
                - **Max Generated Tokens** : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens.
                Note that the model is free to generate fewer tokens depending on the prompt.
                Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`,
                and inputs will be truncated accordingly.

                #### Watermark Parameters:

                - **gamma** : The fraction of the vocabulary to be partitioned into the greenlist at each generation step.
                Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve
                a greater differentiation from human/unwatermarked text, because it is preferentially sampling
                from a smaller green set, making those tokens less likely to occur by chance.
                - **delta** : The amount of positive bias to add to the logits of every token in the greenlist
                at each generation step before sampling/choosing the next token. Higher delta values
                mean that the greenlist tokens are more heavily preferred by the watermarked model,
                and as the bias becomes very large the watermark transitions from "soft" to "hard".
                For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on
                generation quality, especially when there is not a lot of flexibility in the distribution.

                #### Detector Parameters:

                - **z-score threshold** : the z-score cutoff for the hypothesis test. Higher thresholds (such as 4.0) make
                _false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely,
                as genuine human text with a significant number of tokens will almost never achieve
                that high of a z-score. Lower thresholds will capture more _true positives_, as some watermarked
                texts will contain fewer green tokens and achieve a lower z-score, but still pass the lower bar and
                be flagged as "watermarked". However, a lower threshold will increase the chance that human text
                that contains a slightly higher than average number of green tokens is erroneously flagged.
                4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text.
                - **Ignore Bigram Repeats** : This alternate detection algorithm only considers the unique bigrams in the text during detection,
                computing the greenlists based on the first token in each pair and checking whether the second falls within the list.
                This means that `T` is now the number of unique bigrams in the text, which becomes less than the total
                number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion.
                - **Normalizations** : we implement a few basic normalizations to defend against various adversarial perturbations of the
                text analyzed during detection. Currently we support converting all characters to unicode,
                replacing homoglyphs with a canonical form, and standardizing the capitalization.
                See the paper for a detailed discussion of input normalization.
                """
            )

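        # Illustration of the two watermark knobs described above (a hedged sketch; this comment
        # is not part of the uploaded file): at each step the greenlist covers a gamma fraction of
        # the vocabulary, and every greenlist token's logit is raised by delta before sampling.
        # For example, with delta=2.0 a greenlist token whose raw logit is 1.3 is scored as 3.3,
        # so the soft watermark mostly changes the outcome only when the model is not already
        # strongly committed to a particular red token.
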
        with gr.Accordion("What do the output metrics mean?", open=False):
            gr.Markdown(
                """
                - `z-score threshold` : The cutoff for the hypothesis test
                - `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm.
                The first token is omitted in the simple, single token seeding scheme since there is no way to generate
                a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm,
                described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition.
                - `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist
                - `Fraction of T in Greenlist` : The `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text.
                - `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold`
                we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked
                - `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of
                observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists.
                If this is extremely _small_ we are confident that this many green tokens was not chosen by random chance.
                - `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold`
                - `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1-`p value` to represent
                the confidence of the detection based on the unlikeliness of this `z-score` observation.
                """
            )

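        # Worked example of the detection statistic referenced above (a sketch assuming the
        # one-proportion z-test from the paper; this comment is not part of the uploaded file):
        #     z = (num_green - gamma * T) / sqrt(T * gamma * (1 - gamma))
        # e.g. with T = 100 scored tokens, gamma = 0.25, and 50 observed green tokens,
        #     z = (50 - 25) / sqrt(100 * 0.25 * 0.75) ≈ 25 / 4.33 ≈ 5.8,
        # which clears the default z-score threshold of 4.0 and is reported as "Watermarked".
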
        gr.HTML("""
                <p>For faster inference without waiting in the queue, you may duplicate the space and upgrade to a GPU in settings.
                Follow the github link at the top and host the demo on your own GPU hardware to test out larger models.
                <br/>
                <a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
                <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
                </p>
                """)

        # Register the main generation tab click: output the generations as well as the encoded+redecoded
        # (and potentially truncated) prompt and truncation flag, then call detection
        generate_btn.click(fn=check_prompt_partial, inputs=[prompt,session_args,session_tokenizer], outputs=[redecoded_input, truncation_warning, session_args]).success(
            fn=generate_partial, inputs=[redecoded_input,session_args,session_tokenizer], outputs=[output_without_watermark, output_with_watermark]).success(
            fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark]).success(
            fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
        # Show truncated version of the prompt if truncation occurred
        redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
        # Register the main detection tab click
        detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer,html_detection_input], api_name="detection")

        # State management logic
        # define update callbacks that change the state dict
        def update_model_state(session_state, value): session_state.model_name_or_path = value; return session_state
        def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
        def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
        def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
        def update_delta(session_state, value): session_state.delta = float(value); return session_state
        def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state
        def update_decoding(session_state, value):
            if value == "multinomial":
                session_state.use_sampling = True
            elif value == "greedy":
                session_state.use_sampling = False
            return session_state
        def toggle_sampling_vis(value):
            if value == "multinomial":
                return gr.update(visible=True)
            elif value == "greedy":
                return gr.update(visible=False)
        def toggle_sampling_vis_inv(value):
            if value == "multinomial":
                return gr.update(visible=False)
            elif value == "greedy":
                return gr.update(visible=True)
        # if the model name is in the list of api models, set the num beams parameter to 1 and hide n_beams
        def toggle_vis_for_api_model(value):
            if value in API_MODEL_MAP:
                return gr.update(visible=False)
            else:
                return gr.update(visible=True)
        def toggle_beams_for_api_model(value, orig_n_beams):
            if value in API_MODEL_MAP:
                return gr.update(value=1)
            else:
                return gr.update(value=orig_n_beams)
        # if the model name is in the list of api models, set the interactive parameter to false
        def toggle_interactive_for_api_model(value):
            if value in API_MODEL_MAP:
                return gr.update(interactive=False)
            else:
                return gr.update(interactive=True)
        # if the model name is in the list of api models, set gamma and delta based on the API map
        def toggle_gamma_for_api_model(value, orig_gamma):
            if value in API_MODEL_MAP:
                return gr.update(value=API_MODEL_MAP[value]["gamma"])
            else:
                return gr.update(value=orig_gamma)
        def toggle_delta_for_api_model(value, orig_delta):
            if value in API_MODEL_MAP:
                return gr.update(value=API_MODEL_MAP[value]["delta"])
            else:
                return gr.update(value=orig_delta)

        def update_n_beams(session_state, value): session_state.n_beams = value; return session_state
        def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
        def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
        def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
        def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
        def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
        def update_tokenizer(model_name_or_path):
            # if model_name_or_path == ALPACA_MODEL_NAME:
            #     return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
            # else:
            return AutoTokenizer.from_pretrained(model_name_or_path)

        def update_model(state, old_model):
            del old_model
            torch.cuda.empty_cache()
            gc.collect()
            model, _, _ = load_model(state)
            return model

        def check_model(value): return value if (value != "" and value is not None) else args.model_name_or_path
        # enforce the constraint that the model cannot be null or empty
        # then attach the model callbacks in particular
        model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
            toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams]
        ).then(
            toggle_beams_for_api_model, inputs=[model_selector, n_beams], outputs=[n_beams]
        ).then(
            toggle_interactive_for_api_model, inputs=[model_selector], outputs=[gamma]
        ).then(
            toggle_interactive_for_api_model, inputs=[model_selector], outputs=[delta]
        ).then(
            toggle_gamma_for_api_model, inputs=[model_selector, gamma], outputs=[gamma]
        ).then(
            toggle_delta_for_api_model, inputs=[model_selector, delta], outputs=[delta]
        ).then(
            update_model_state, inputs=[session_args, model_selector], outputs=[session_args]
        ).then(
            update_tokenizer, inputs=[model_selector], outputs=[session_tokenizer]
        ).then(
            lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
        )
        # registering callbacks for toggling the visibility of certain parameters based on the values of others
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[sampling_temp])
        decoding.change(toggle_sampling_vis, inputs=[decoding], outputs=[generation_seed])
        decoding.change(toggle_sampling_vis_inv, inputs=[decoding], outputs=[n_beams])
        decoding.change(toggle_vis_for_api_model, inputs=[model_selector], outputs=[n_beams])
        # registering all state update callbacks
        decoding.change(update_decoding, inputs=[session_args, decoding], outputs=[session_args])
        sampling_temp.change(update_sampling_temp, inputs=[session_args, sampling_temp], outputs=[session_args])
        generation_seed.change(update_generation_seed, inputs=[session_args, generation_seed], outputs=[session_args])
        n_beams.change(update_n_beams, inputs=[session_args, n_beams], outputs=[session_args])
        max_new_tokens.change(update_max_new_tokens, inputs=[session_args, max_new_tokens], outputs=[session_args])
        gamma.change(update_gamma, inputs=[session_args, gamma], outputs=[session_args])
        delta.change(update_delta, inputs=[session_args, delta], outputs=[session_args])
        detection_z_threshold.change(update_detection_z_threshold, inputs=[session_args, detection_z_threshold], outputs=[session_args])
        ignore_repeated_bigrams.change(update_ignore_repeated_bigrams, inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args])
        normalizers.change(update_normalizers, inputs=[session_args, normalizers], outputs=[session_args])
        seed_separately.change(update_seed_separately, inputs=[session_args, seed_separately], outputs=[session_args])
        select_green_tokens.change(update_select_green_tokens, inputs=[session_args, select_green_tokens], outputs=[session_args])
        # register an additional callback on button clicks that updates the shown parameters window
        generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        # When the parameters change, display the update and also fire detection, since some detection params don't change the model output.
        delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark])
        gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
        gamma.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])
        detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark])
        detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
        detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])
        ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
        ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])
        normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark])
        normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
        normalizers.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])
        select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
        select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args,session_tokenizer], outputs=[without_watermark_detection_result,session_args,session_tokenizer,html_without_watermark])
        select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args,session_tokenizer], outputs=[with_watermark_detection_result,session_args,session_tokenizer,html_with_watermark])
        select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result,session_args,session_tokenizer,html_detection_input])

    demo.queue()

    if args.demo_public:
        demo.launch(share=True)  # exposes app to the internet via randomly generated link
    else:
        demo.launch()

def main(args):
    """Run a command line version of the generation and detection operations
    and optionally launch and serve the gradio demo"""
    # Initial arg processing and log
    args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
    print(args)

    if not args.skip_model_load:
        model, tokenizer, device = load_model(args)
    else:
        model, tokenizer, device = None, None, None
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        if args.use_gpu:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            device = "cpu"

    # terrapin example
    input_text = (
        "The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
        "species of turtle native to the brackish coastal tidal marshes of the "
        "Northeastern and southern United States, and in Bermuda.[6] It belongs "
        "to the monotypic genus Malaclemys. It has one of the largest ranges of "
        "all turtles in North America, stretching as far south as the Florida Keys "
        "and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
        "Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
        "British English and American English. The name originally was used by "
        "early European settlers in North America to describe these brackish-water "
        "turtles that inhabited neither freshwater habitats nor the sea. It retains "
        "this primary meaning in American English.[8] In British English, however, "
        "other semi-aquatic turtle species, such as the red-eared slider, might "
        "also be called terrapins. The common name refers to the diamond pattern "
        "on top of its shell (carapace), but the overall pattern and coloration "
        "vary greatly. The shell is usually wider at the back than in the front, "
        "and from above it appears wedge-shaped. The shell coloring can vary "
        "from brown to grey, and its body color can be grey, brown, yellow, "
        "or white. All have a unique pattern of wiggly, black markings or spots "
        "on their body and head. The diamondback terrapin has large webbed "
        "feet.[9] The species is"
    )

    args.default_prompt = input_text

    # Generate and detect, report to stdout
    if not args.skip_model_load:

        term_width = 80
        print("#"*term_width)
        print("Prompt:")
        print(input_text)

        # a generator that yields (without_watermark, with_watermark) pairs
        generator_outputs = generate(input_text,
                                     args,
                                     model=model,
                                     device=device,
                                     tokenizer=tokenizer)
        # we need to iterate over it,
        # but we only want the last output in this case
        for out in generator_outputs:
            decoded_output_without_watermark = out[0]
            decoded_output_with_watermark = out[1]

        without_watermark_detection_result = detect(decoded_output_without_watermark,
                                                    args,
                                                    device=device,
                                                    tokenizer=tokenizer,
                                                    return_green_token_mask=False)
        with_watermark_detection_result = detect(decoded_output_with_watermark,
                                                 args,
                                                 device=device,
                                                 tokenizer=tokenizer,
                                                 return_green_token_mask=False)

        print("#"*term_width)
        print("Output without watermark:")
        print(decoded_output_without_watermark)
        print("-"*term_width)
        print(f"Detection result @ {args.detection_z_threshold}:")
        pprint(without_watermark_detection_result)
        print("-"*term_width)

        print("#"*term_width)
        print("Output with watermark:")
        print(decoded_output_with_watermark)
        print("-"*term_width)
        print(f"Detection result @ {args.detection_z_threshold}:")
        pprint(with_watermark_detection_result)
        print("-"*term_width)

    # Launch the app to generate and detect interactively (implements the hf space demo)
    if args.run_gradio:
        run_gradio(args, model=model, tokenizer=tokenizer, device=device)

    return

if __name__ == "__main__":

    args = parse_args()
    print(args)

    main(args)