Upload 6 files

- app.py +57 -0
- models/__pycache__/watermark_faster.cpython-39.pyc +0 -0
- models/watermark_faster.py +465 -0
- models/watermark_original.py +368 -0
- options.py +14 -0
- requirements.txt +12 -0
app.py
ADDED
@@ -0,0 +1,57 @@
import gradio as gr
from models.watermark_faster import watermark_model
from options import get_parser_main_model

opts = get_parser_main_model().parse_args()
model = watermark_model(language=opts.language, mode=opts.mode, tau_word=opts.tau_word, lamda=opts.lamda)

def watermark_embed_demo(raw):
    watermarked_text = model.embed(raw)
    return watermarked_text

def watermark_extract(raw):
    is_watermark, p_value, n, ones, z_value = model.watermark_detector_fast(raw)
    confidence = (1 - p_value) * 100
    return f"{confidence:.2f}%"

def precise_watermark_detect(raw):
    is_watermark, p_value, n, ones, z_value = model.watermark_detector_precise(raw)
    confidence = (1 - p_value) * 100
    return f"{confidence:.2f}%"

demo = gr.Blocks()
with demo:
    with gr.Column():
        gr.Markdown("# Watermarking Text Generated by Black-Box Language Models")

        inputs = gr.TextArea(label="Input text", placeholder="Copy your text here...")
        output = gr.Textbox(label="Watermarked Text")
        analysis_button = gr.Button("Inject Watermark")
        inputs_embed = [inputs]
        analysis_button.click(fn=watermark_embed_demo, inputs=inputs_embed, outputs=output)

        inputs_w = gr.TextArea(label="Text to Analyze", placeholder="Copy your watermarked text here...")

        mode = gr.Dropdown(
            label="Detection Mode", choices=["Fast", "Precise"], value="Fast"  # Gradio 3.x Dropdown takes value= for the initial choice
        )
        output_detect = gr.Textbox(label="Confidence (the likelihood of the text containing a watermark)")
        detect_button = gr.Button("Detect")

        def detect_watermark(inputs_w, mode):
            if mode == "Fast":
                return watermark_extract(inputs_w)
            else:
                return precise_watermark_detect(inputs_w)

        detect_button.click(fn=detect_watermark, inputs=[inputs_w, mode], outputs=output_detect)

if __name__ == "__main__":
    gr.close_all()
    demo.title = "Watermarking Text Generated by Black-Box Language Models"
    demo.launch(share=True, server_port=8899)
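For reference, the same embed/detect round trip can be driven without the UI. A minimal sketch, assuming the English defaults from options.py and that the Hugging Face checkpoints and GloVe vectors download on first use (this snippet is not part of the uploaded files):

    from models.watermark_faster import watermark_model

    model = watermark_model(language="English", mode="embed", tau_word=0.8, lamda=0.83)
    watermarked = model.embed("The quick brown fox jumps over the lazy dog. It was a bright day.")
    is_wm, p_value, n, ones, z = model.watermark_detector_fast(watermarked)
    print(f"watermarked={is_wm}, confidence={(1 - p_value) * 100:.2f}%")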
models/__pycache__/watermark_faster.cpython-39.pyc
ADDED
Binary file (15.9 kB)
models/watermark_faster.py
ADDED
@@ -0,0 +1,465 @@
import nltk
from nltk.corpus import stopwords
from nltk import word_tokenize, pos_tag
import torch
import torch.nn.functional as F
from torch import nn
import hashlib
from scipy.stats import norm
import gensim
import pdb
from transformers import BertForMaskedLM as WoBertForMaskedLM
from wobert import WoBertTokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from transformers import BertForMaskedLM, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
import gensim.downloader as api
import Levenshtein
import string
import spacy
import paddle
from jieba import posseg
paddle.enable_static()
import re

def cut_sent(para):
    # Split Chinese text into sentences on terminal punctuation, keeping closing quotes attached.
    para = re.sub('([。!?\?])([^”’])', r'\1\n\2', para)
    para = re.sub('([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)
    para = re.sub('(\.{6}|\…{2})([^”’\n])', r'\1\n\2', para)
    para = re.sub('([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)
    para = re.sub('([。!?\?][”’])$', r'\1\n', para)
    para = para.rstrip()
    return para.split("\n")

def is_subword(token: str):
    return token.startswith('##')

def binary_encoding_function(token):
    # Deterministic 0/1 bit derived from the SHA-256 hash of the token string.
    hash_value = int(hashlib.sha256(token.encode('utf-8')).hexdigest(), 16)
    random_bit = hash_value % 2
    return random_bit

def is_similar(x, y, threshold=0.5):
    # Normalized Levenshtein distance below the threshold counts as "similar".
    distance = Levenshtein.distance(x, y)
    if distance / max(len(x), len(y)) < threshold:
        return True
    return False

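# Illustrative note (added; not in the original upload): the encoding bit for
# position i is binary_encoding_function(tokens[i-1] + tokens[i]) -- the SHA-256
# hash of the previous token concatenated with the current one -- so it is
# deterministic yet roughly uniform over token pairs, e.g.:
#
#   binary_encoding_function('the' + 'cat')   # always the same 0 or 1 for this pair
#
# Embedding below substitutes words so these pair bits skew toward 1, which is
# exactly the skew the detectors' one-sided z-test measures.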
class watermark_model:
    def __init__(self, language, mode, tau_word, lamda):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.language = language
        self.mode = mode
        self.tau_word = tau_word
        self.tau_sent = 0.8
        self.lamda = lamda
        self.cn_tag_black_list = set(['','x','u','j','k','zg','y','eng','uv','uj','ud','nr','nrfg','nrt','nw','nz','ns','nt','m','mq','r','w','PER','LOC','ORG'])  # alt: set(['','f','u','nr','nw','nz','m','r','p','c','w','PER','LOC','ORG'])
        self.en_tag_white_list = set(['MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS'])
        if language == 'Chinese':
            self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
            self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
            self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
            self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
            self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
        elif language == 'English':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
            self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
            self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
            self.w2v_model = api.load("glove-wiki-gigaword-100")
            nltk.download('stopwords')
            self.stop_words = set(stopwords.words('english'))
            self.nlp = spacy.load('en_core_web_sm')

    def cut(self, ori_text, text_len):
        if self.language == 'Chinese':
            if len(ori_text) > text_len + 5:
                ori_text = ori_text[:text_len + 5]
            if len(ori_text) < text_len - 5:
                return 'Short'
            return ori_text
        elif self.language == 'English':
            tokens = self.tokenizer.tokenize(ori_text)
            if len(tokens) > text_len + 5:
                ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len + 5])
            if len(tokens) < text_len - 5:
                return 'Short'
            return ori_text
        else:
            print(f'Unsupported Language:{self.language}')
            raise NotImplementedError

    def sent_tokenize(self, ori_text):
        if self.language == 'Chinese':
            return cut_sent(ori_text)
        elif self.language == 'English':
            return nltk.sent_tokenize(ori_text)

    def pos_filter(self, tokens, masked_token_index, input_text):
        # Decide whether the token at masked_token_index is eligible for substitution.
        if self.language == 'Chinese':
            pairs = posseg.lcut(input_text)
            pos_dict = {word: pos for word, pos in pairs}
            pos_list_input = [pos for _, pos in pairs]
            pos = pos_dict.get(tokens[masked_token_index], '')
            if pos in self.cn_tag_black_list:
                return False
            else:
                return True
        elif self.language == 'English':
            pos_tags = pos_tag(tokens)
            pos = pos_tags[masked_token_index][1]
            if pos not in self.en_tag_white_list:
                return False
            if is_subword(tokens[masked_token_index]) or is_subword(tokens[masked_token_index + 1]) or (tokens[masked_token_index] in self.stop_words or tokens[masked_token_index] in string.punctuation):
                return False
            return True

    def filter_special_candidate(self, top_n_tokens, tokens, masked_token_index, input_text):
        if self.language == 'English':
            filtered_tokens = [tok for tok in top_n_tokens if tok not in self.stop_words and tok not in string.punctuation and pos_tag([tok])[0][1] in self.en_tag_white_list and not is_subword(tok)]

            base_word = tokens[masked_token_index]

            processed_tokens = [tok for tok in filtered_tokens if not is_similar(tok, base_word)]
            return processed_tokens
        elif self.language == 'Chinese':
            pairs = posseg.lcut(input_text)
            pos_dict = {word: pos for word, pos in pairs}
            pos_list_input = [pos for _, pos in pairs]
            pos = pos_dict.get(tokens[masked_token_index], '')
            filtered_tokens = []
            for tok in top_n_tokens:
                watermarked_text_segtest = self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [tok] + tokens[masked_token_index + 1:-1])
                watermarked_text_segtest = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text_segtest)
                pairs_tok = posseg.lcut(watermarked_text_segtest)
                pos_dict_tok = {word: pos for word, pos in pairs_tok}
                flag = pos_dict_tok.get(tok, '')
                if flag not in self.cn_tag_black_list and flag == pos:
                    filtered_tokens.append(tok)
            processed_tokens = filtered_tokens
            return processed_tokens

    def global_word_sim(self, word, ori_word):
        try:
            global_score = self.w2v_model.similarity(word, ori_word)
        except KeyError:
            global_score = 0
        return global_score

    def context_word_sim(self, init_candidates_list, tokens, index_space, input_text):
        original_input_tensor = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)

        all_cos_sims = []

        for init_candidates, masked_token_index in zip(init_candidates_list, index_space):
            batch_input_ids = [
                [self.tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1] + ['[SEP]'])] for token in
                init_candidates]
            batch_input_tensors = torch.tensor(batch_input_ids).squeeze(1).to(self.device)

            batch_input_tensors = torch.cat((batch_input_tensors, original_input_tensor), dim=0)

            with torch.no_grad():
                outputs = self.model(batch_input_tensors)
                cos_sims = torch.zeros([len(init_candidates)]).to(self.device)
                num_layers = len(outputs[1])
                N = 8
                i = masked_token_index
                # We want to calculate similarity for the last N layers
                hidden_states = outputs[1][-N:]

                # Shape of hidden_states: [N, batch_size, sequence_length, hidden_size]
                hidden_states = torch.stack(hidden_states)

                # Separate the source and candidate hidden states
                source_hidden_states = hidden_states[:, len(init_candidates):, i, :]
                candidate_hidden_states = hidden_states[:, :len(init_candidates), i, :]

                # Calculate cosine similarities across all layers and sum
                cos_sim_sum = F.cosine_similarity(source_hidden_states.unsqueeze(2), candidate_hidden_states.unsqueeze(1), dim=-1).sum(dim=0)

                cos_sim_avg = cos_sim_sum / N
                cos_sims += cos_sim_avg.squeeze()

            all_cos_sims.append(cos_sims.tolist())

        return all_cos_sims

    def sentence_sim(self, init_candidates_list, tokens, index_space, input_text):

        batch_size = 128
        all_batch_sentences = []
        all_index_lengths = []
        for init_candidates, masked_token_index in zip(init_candidates_list, index_space):
            if self.language == 'Chinese':
                batch_sents = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
                batch_sentences = [re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', sent) for sent in batch_sents]
                all_batch_sentences.extend([input_text + '[SEP]' + s for s in batch_sentences])
            elif self.language == 'English':
                batch_sentences = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
                all_batch_sentences.extend([input_text + '</s></s>' + s for s in batch_sentences])

            all_index_lengths.append(len(init_candidates))

        all_relatedness_scores = []
        start_index = 0
        for i in range(0, len(all_batch_sentences), batch_size):
            batch_sentences = all_batch_sentences[i: i + batch_size]
            encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
                batch_sentences,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt')

            input_ids = encoded_dict['input_ids'].to(self.device)
            attention_masks = encoded_dict['attention_mask'].to(self.device)

            with torch.no_grad():
                outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
                logits = outputs[0]
                probs = torch.softmax(logits, dim=1)
            if self.language == 'Chinese':
                relatedness_scores = probs[:, 1]  # .tolist()
            elif self.language == 'English':
                relatedness_scores = probs[:, 2]  # .tolist()
            all_relatedness_scores.extend(relatedness_scores)

        all_relatedness_scores_split = []
        for length in all_index_lengths:
            all_relatedness_scores_split.append(all_relatedness_scores[start_index:start_index + length])
            start_index += length

        return all_relatedness_scores_split

    def candidates_gen(self, tokens, index_space, input_text, topk=64, dropout_prob=0.3):
        input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
        new_index_space = []
        masked_text = self.tokenizer.convert_tokens_to_string(tokens)
        # Create a tensor of input IDs
        input_tensor = torch.tensor([input_ids_bert]).to(self.device)

        with torch.no_grad():
            embeddings = self.model.bert.embeddings(input_tensor.repeat(len(index_space), 1))

        dropout = nn.Dropout2d(p=dropout_prob)

        masked_indices = torch.tensor(index_space).to(self.device)
        embeddings[torch.arange(len(index_space)), masked_indices] = dropout(embeddings[torch.arange(len(index_space)), masked_indices])

        with torch.no_grad():
            outputs = self.model(inputs_embeds=embeddings)

        all_processed_tokens = []
        for i, masked_token_index in enumerate(index_space):
            predicted_logits = outputs[0][i][masked_token_index]
            # Set the number of top predictions to return
            n = topk
            # Get the top n predicted tokens and their probabilities
            probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
            top_n_probs, top_n_indices = torch.topk(probs, n)
            top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
            processed_tokens = self.filter_special_candidate(top_n_tokens, tokens, masked_token_index, input_text)

            if tokens[masked_token_index] not in processed_tokens:
                processed_tokens = [tokens[masked_token_index]] + processed_tokens
            all_processed_tokens.append(processed_tokens)
            new_index_space.append(masked_token_index)

        return all_processed_tokens, new_index_space

    def filter_candidates(self, init_candidates_list, tokens, index_space, input_text):

        all_context_word_similarity_scores = self.context_word_sim(init_candidates_list, tokens, index_space, input_text)

        all_sentence_similarity_scores = self.sentence_sim(init_candidates_list, tokens, index_space, input_text)

        all_filtered_candidates = []
        new_index_space = []

        for init_candidates, context_word_similarity_scores, sentence_similarity_scores, masked_token_index in zip(init_candidates_list, all_context_word_similarity_scores, all_sentence_similarity_scores, index_space):
            filtered_candidates = []
            for idx, candidate in enumerate(init_candidates):
                global_word_similarity_score = self.global_word_sim(tokens[masked_token_index], candidate)
                word_similarity_score = self.lamda * context_word_similarity_scores[idx] + (1 - self.lamda) * global_word_similarity_score
                if word_similarity_score >= self.tau_word and sentence_similarity_scores[idx] >= self.tau_sent:
                    filtered_candidates.append((candidate, word_similarity_score))

            if len(filtered_candidates) >= 1:
                all_filtered_candidates.append(filtered_candidates)
                new_index_space.append(masked_token_index)
        return all_filtered_candidates, new_index_space

    def get_candidate_encodings(self, tokens, enhanced_candidates, index_space):
        best_candidates = []
        new_index_space = []

        for init_candidates, masked_token_index in zip(enhanced_candidates, index_space):
            filtered_candidates = []

            for idx, candidate in enumerate(init_candidates):
                if masked_token_index - 1 in new_index_space:
                    # The previous position was also replaced, so hash against the word chosen there.
                    bit = binary_encoding_function(best_candidates[-1] + candidate[0])
                else:
                    bit = binary_encoding_function(tokens[masked_token_index - 1] + candidate[0])

                if bit == 1:
                    filtered_candidates.append(candidate)

            # Sort the candidates based on their scores
            filtered_candidates = sorted(filtered_candidates, key=lambda x: x[1], reverse=True)

            if len(filtered_candidates) >= 1:
                best_candidates.append(filtered_candidates[0][0])
                new_index_space.append(masked_token_index)

        return best_candidates, new_index_space

    def watermark_embed(self, text):
        input_text = text
        # Tokenize the input text
        tokens = self.tokenizer.tokenize(input_text)
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        masked_tokens = tokens.copy()
        start_index = 1
        end_index = len(tokens) - 1

        index_space = []

        for masked_token_index in range(start_index + 1, end_index - 1):
            binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
            if binary_encoding == 1 and masked_token_index - 1 not in index_space:
                continue
            if not self.pos_filter(tokens, masked_token_index, input_text):
                continue
            index_space.append(masked_token_index)

        if len(index_space) == 0:
            return text
        init_candidates, new_index_space = self.candidates_gen(tokens, index_space, input_text, 8, 0)
        if len(new_index_space) == 0:
            return text
        enhanced_candidates, new_index_space = self.filter_candidates(init_candidates, tokens, new_index_space, input_text)

        enhanced_candidates, new_index_space = self.get_candidate_encodings(tokens, enhanced_candidates, new_index_space)

        for init_candidate, masked_token_index in zip(enhanced_candidates, new_index_space):
            tokens[masked_token_index] = init_candidate
        watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])

        if self.language == 'Chinese':
            watermarked_text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text)
        return watermarked_text

    def embed(self, ori_text):
        sents = self.sent_tokenize(ori_text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        watermarked_text = ''

        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]
            # keywords = jieba.analyse.extract_tags(sent_pair, topK=5, withWeight=False)
            if len(watermarked_text) == 0:
                watermarked_text = self.watermark_embed(sent_pair)
            else:
                watermarked_text = watermarked_text + self.watermark_embed(sent_pair)
        if len(self.get_encodings_fast(ori_text)) == 0:
            # print(ori_text)
            return ''
        return watermarked_text

    def get_encodings_fast(self, text):
        sents = self.sent_tokenize(text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        encodings = []
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]
            tokens = self.tokenizer.tokenize(sent_pair)

            for index in range(1, len(tokens) - 1):
                if not self.pos_filter(tokens, index, text):
                    continue
                bit = binary_encoding_function(tokens[index - 1] + tokens[index])
                encodings.append(bit)
        return encodings

    def watermark_detector_fast(self, text, alpha=0.05):
        p = 0.5
        encodings = self.get_encodings_fast(text)
        n = len(encodings)
        ones = sum(encodings)
        if n == 0:
            z = 0
        else:
            z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
        threshold = norm.ppf(1 - alpha, loc=0, scale=1)
        p_value = norm.sf(z)
        # p_value = norm.sf(abs(z)) * 2
        is_watermark = z >= threshold
        return is_watermark, p_value, n, ones, z

    def get_encodings_precise(self, text):
        # pdb.set_trace()
        sents = self.sent_tokenize(text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        encodings = []
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]

            tokens = self.tokenizer.tokenize(sent_pair)

            tokens = ['[CLS]'] + tokens + ['[SEP]']

            masked_tokens = tokens.copy()

            start_index = 1
            end_index = len(tokens) - 1

            index_space = []
            for masked_token_index in range(start_index + 1, end_index - 1):
                if not self.pos_filter(tokens, masked_token_index, sent_pair):
                    continue
                index_space.append(masked_token_index)
            if len(index_space) == 0:
                continue

            init_candidates, new_index_space = self.candidates_gen(tokens, index_space, sent_pair, 8, 0)
            enhanced_candidates, new_index_space = self.filter_candidates(init_candidates, tokens, new_index_space, sent_pair)

            # pdb.set_trace()
            for j, idx in enumerate(new_index_space):
                if len(enhanced_candidates[j]) > 1:
                    bit = binary_encoding_function(tokens[idx - 1] + tokens[idx])
                    encodings.append(bit)
        return encodings

    def watermark_detector_precise(self, text, alpha=0.05):
        p = 0.5
        encodings = self.get_encodings_precise(text)
        n = len(encodings)
        ones = sum(encodings)
        if n == 0:
            z = 0
        else:
            z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
        threshold = norm.ppf(1 - alpha, loc=0, scale=1)
        p_value = norm.sf(z)
        is_watermark = z >= threshold
        return is_watermark, p_value, n, ones, z
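Both detectors reduce to the same statistic: under the null hypothesis of unwatermarked text, each encoded bit is Bernoulli(0.5), and the count of ones is checked with a one-sided z-test. A self-contained sketch of that math, with made-up bits (illustrative only, not part of the uploaded files):

    from scipy.stats import norm

    def z_test_ones(encodings, alpha=0.05, p=0.5):
        # Mirrors watermark_detector_fast / watermark_detector_precise above.
        n = len(encodings)
        ones = sum(encodings)
        z = 0 if n == 0 else (ones - p * n) / (n * p * (1 - p)) ** 0.5
        p_value = norm.sf(z)  # P(Z >= z) under the null
        return z >= norm.ppf(1 - alpha), p_value, z

    # 30 bits with 24 ones: z ~= 3.29, p ~= 0.0005 -> flagged as watermarked
    print(z_test_ones([1] * 24 + [0] * 6))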
models/watermark_original.py
ADDED
@@ -0,0 +1,368 @@
import nltk
from nltk.corpus import stopwords
from nltk import word_tokenize, pos_tag
import torch
import torch.nn.functional as F
from torch import nn
import hashlib
from scipy.stats import norm
import gensim
import pdb
from transformers import BertForMaskedLM as WoBertForMaskedLM
from wobert import WoBertTokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from transformers import BertForMaskedLM, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
import gensim.downloader as api
import Levenshtein
import string
import spacy
import paddle
from jieba import posseg

paddle.enable_static()
import re

def cut_sent(para):
    para = re.sub('([。!?\?])([^”’])', r'\1\n\2', para)
    para = re.sub('([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)
    para = re.sub('(\.{6}|\…{2})([^”’\n])', r'\1\n\2', para)
    para = re.sub('([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)
    para = re.sub('([。!?\?][”’])$', r'\1\n', para)
    para = para.rstrip()
    return para.split("\n")

def is_subword(token: str):
    return token.startswith('##')

def binary_encoding_function(token):
    hash_value = int(hashlib.sha256(token.encode('utf-8')).hexdigest(), 16)
    random_bit = hash_value % 2
    return random_bit

def is_similar(x, y, threshold=0.5):
    distance = Levenshtein.distance(x, y)
    if distance / max(len(x), len(y)) < threshold:
        return True
    return False

class watermark_model:
    def __init__(self, language, mode, tau_word, lamda):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.language = language
        self.mode = mode
        self.tau_word = tau_word
        self.tau_sent = 0.8
        self.lamda = lamda
        self.cn_tag_black_list = set(['','x','u','j','k','zg','y','eng','uv','uj','ud','nr','nrfg','nrt','nw','nz','ns','nt','m','mq','r','w','PER','LOC','ORG'])  # alt: set(['','f','u','nr','nw','nz','m','r','p','c','w','PER','LOC','ORG'])
        self.en_tag_white_list = set(['MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS'])
        if language == 'Chinese':
            self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
            self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
            self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
            self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
            self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
        elif language == 'English':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
            self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
            self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
            self.w2v_model = api.load("glove-wiki-gigaword-100")
            nltk.download('stopwords')
            self.stop_words = set(stopwords.words('english'))
            self.nlp = spacy.load('en_core_web_sm')

    def cut(self, ori_text, text_len):
        if self.language == 'Chinese':
            if len(ori_text) > text_len + 5:
                ori_text = ori_text[:text_len + 5]
            if len(ori_text) < text_len - 5:
                return 'Short'
            return ori_text
        elif self.language == 'English':
            tokens = self.tokenizer.tokenize(ori_text)
            if len(tokens) > text_len + 5:
                ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len + 5])
            if len(tokens) < text_len - 5:
                return 'Short'
            return ori_text
        else:
            print(f'Unsupported Language:{self.language}')
            raise NotImplementedError

    def sent_tokenize(self, ori_text):
        if self.language == 'Chinese':
            return cut_sent(ori_text)
        elif self.language == 'English':
            return nltk.sent_tokenize(ori_text)

    def pos_filter(self, tokens, masked_token_index, input_text):
        if self.language == 'Chinese':
            pairs = posseg.lcut(input_text)
            pos_dict = {word: pos for word, pos in pairs}
            pos_list_input = [pos for _, pos in pairs]
            pos = pos_dict.get(tokens[masked_token_index], '')
            if pos in self.cn_tag_black_list:
                return False
            else:
                return True
        elif self.language == 'English':
            pos_tags = pos_tag(tokens)
            pos = pos_tags[masked_token_index][1]
            if pos not in self.en_tag_white_list:
                return False
            if is_subword(tokens[masked_token_index]) or is_subword(tokens[masked_token_index + 1]) or (tokens[masked_token_index] in self.stop_words or tokens[masked_token_index] in string.punctuation):
                return False
            return True

    def filter_special_candidate(self, top_n_tokens, tokens, masked_token_index, input_text):
        if self.language == 'English':
            filtered_tokens = [tok for tok in top_n_tokens if tok not in self.stop_words and tok not in string.punctuation and pos_tag([tok])[0][1] in self.en_tag_white_list and not is_subword(tok)]

            lemmatized_tokens = []
            # for token in filtered_tokens:
            #     doc = self.nlp(token)
            #     lemma = doc[0].lemma_ if doc[0].lemma_ != "-PRON-" else token
            #     lemmatized_tokens.append(lemma)

            base_word = tokens[masked_token_index]
            base_word_lemma = self.nlp(base_word)[0].lemma_
            processed_tokens = [base_word] + [tok for tok in filtered_tokens if self.nlp(tok)[0].lemma_ != base_word_lemma]
            return processed_tokens
        elif self.language == 'Chinese':
            pairs = posseg.lcut(input_text)
            pos_dict = {word: pos for word, pos in pairs}
            pos_list_input = [pos for _, pos in pairs]
            pos = pos_dict.get(tokens[masked_token_index], '')
            filtered_tokens = []
            for tok in top_n_tokens:
                watermarked_text_segtest = self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [tok] + tokens[masked_token_index + 1:-1])
                watermarked_text_segtest = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text_segtest)
                pairs_tok = posseg.lcut(watermarked_text_segtest)
                pos_dict_tok = {word: pos for word, pos in pairs_tok}
                flag = pos_dict_tok.get(tok, '')
                if flag not in self.cn_tag_black_list and flag == pos:
                    filtered_tokens.append(tok)
            processed_tokens = filtered_tokens
            return processed_tokens

    def global_word_sim(self, word, ori_word):
        try:
            global_score = self.w2v_model.similarity(word, ori_word)
        except KeyError:
            global_score = 0
        return global_score

    def context_word_sim(self, init_candidates, tokens, masked_token_index, input_text):
        original_input_tensor = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)
        batch_input_ids = [[self.tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1] + ['[SEP]'])] for token in init_candidates]
        batch_input_tensors = torch.tensor(batch_input_ids).squeeze().to(self.device)
        batch_input_tensors = torch.cat((batch_input_tensors, original_input_tensor), dim=0)
        with torch.no_grad():
            outputs = self.model(batch_input_tensors)
            cos_sims = torch.zeros([len(init_candidates)]).to(self.device)
            num_layers = len(outputs[1])
            N = 8
            i = masked_token_index
            cos_sim_sum = 0
            # Average cosine similarity over the last N hidden layers at the masked position.
            for layer in range(num_layers - N, num_layers):
                ls_hidden_states = outputs[1][layer][0:len(init_candidates), i, :]
                source_hidden_state = outputs[1][layer][len(init_candidates), i, :]
                cos_sim_sum += F.cosine_similarity(source_hidden_state, ls_hidden_states, dim=1)
            cos_sim_avg = cos_sim_sum / N

            cos_sims += cos_sim_avg
        return cos_sims.tolist()

    def sentence_sim(self, init_candidates, tokens, masked_token_index, input_text):
        if self.language == 'Chinese':
            batch_sents = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
            batch_sentences = [re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', sent) for sent in batch_sents]
            roberta_inputs = [input_text + '[SEP]' + s for s in batch_sentences]
        elif self.language == 'English':
            batch_sentences = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
            roberta_inputs = [input_text + '</s></s>' + s for s in batch_sentences]

        encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
            roberta_inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt')
        # Extract input_ids and attention_masks
        input_ids = encoded_dict['input_ids'].to(self.device)
        attention_masks = encoded_dict['attention_mask'].to(self.device)
        with torch.no_grad():
            outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
            logits = outputs[0]
            probs = torch.softmax(logits, dim=1)
        if self.language == 'Chinese':
            relatedness_scores = probs[:, 1].tolist()
        elif self.language == 'English':
            relatedness_scores = probs[:, 2].tolist()

        return relatedness_scores

    def candidates_gen(self, tokens, masked_token_index, input_text, topk=64, dropout_prob=0.3):
        input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
        if not self.pos_filter(tokens, masked_token_index, input_text):
            return []
        masked_text = self.tokenizer.convert_tokens_to_string(tokens)
        # Create a tensor of input IDs
        input_tensor = torch.tensor([input_ids_bert]).to(self.device)

        with torch.no_grad():
            embeddings = self.model.bert.embeddings(input_tensor)
        dropout = nn.Dropout2d(p=dropout_prob)
        # Get the predicted logits
        embeddings[:, masked_token_index, :] = dropout(embeddings[:, masked_token_index, :])
        with torch.no_grad():
            outputs = self.model(inputs_embeds=embeddings)

        predicted_logits = outputs[0][0][masked_token_index]

        # Set the number of top predictions to return
        n = topk
        # Get the top n predicted tokens and their probabilities
        probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
        top_n_probs, top_n_indices = torch.topk(probs, n)
        top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
        processed_tokens = self.filter_special_candidate(top_n_tokens, tokens, masked_token_index, input_text)

        return processed_tokens

    def filter_candidates(self, init_candidates, tokens, masked_token_index, input_text):
        context_word_similarity_scores = self.context_word_sim(init_candidates, tokens, masked_token_index, input_text)
        sentence_similarity_scores = self.sentence_sim(init_candidates, tokens, masked_token_index, input_text)
        filtered_candidates = []
        for idx, candidate in enumerate(init_candidates):
            global_word_similarity_score = self.global_word_sim(tokens[masked_token_index], candidate)
            word_similarity_score = self.lamda * context_word_similarity_scores[idx] + (1 - self.lamda) * global_word_similarity_score
            if word_similarity_score >= self.tau_word and sentence_similarity_scores[idx] >= self.tau_sent:
                filtered_candidates.append((candidate, word_similarity_score))  # , sentence_similarity_scores[idx]))
        return filtered_candidates

    def watermark_embed(self, text):
        input_text = text
        # Tokenize the input text
        tokens = self.tokenizer.tokenize(input_text)
        tokens = ['[CLS]'] + tokens + ['[SEP]']
        masked_tokens = tokens.copy()
        start_index = 1
        end_index = len(tokens) - 1
        for masked_token_index in range(start_index + 1, end_index - 1):
            # pdb.set_trace()
            binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
            if binary_encoding == 1:
                continue
            init_candidates = self.candidates_gen(tokens, masked_token_index, input_text, 32, 0.3)
            if len(init_candidates) <= 1:
                continue
            enhanced_candidates = self.filter_candidates(init_candidates, tokens, masked_token_index, input_text)
            hash_top_tokens = enhanced_candidates.copy()
            for i, tok in enumerate(enhanced_candidates):
                binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tok[0])
                if binary_encoding != 1 or (is_similar(tok[0], tokens[masked_token_index])) or (tokens[masked_token_index - 1] in tok or tokens[masked_token_index + 1] in tok):
                    hash_top_tokens.remove(tok)
            hash_top_tokens.sort(key=lambda x: x[1], reverse=True)
            if len(hash_top_tokens) > 0:
                selected_token = hash_top_tokens[0][0]
            else:
                selected_token = tokens[masked_token_index]

            tokens[masked_token_index] = selected_token
        watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])
        if self.language == 'Chinese':
            watermarked_text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text)

        return watermarked_text

    def embed(self, ori_text):
        sents = self.sent_tokenize(ori_text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        watermarked_text = ''
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]
            if len(watermarked_text) == 0:
                watermarked_text = self.watermark_embed(sent_pair)
            else:
                watermarked_text = watermarked_text + self.watermark_embed(sent_pair)
        if len(self.get_encodings_fast(ori_text)) == 0:
            return ''
        return watermarked_text

    def get_encodings_fast(self, text):
        sents = self.sent_tokenize(text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        encodings = []
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]
            tokens = self.tokenizer.tokenize(sent_pair)

            for index in range(1, len(tokens) - 1):
                if not self.pos_filter(tokens, index, text):
                    continue
                bit = binary_encoding_function(tokens[index - 1] + tokens[index])
                encodings.append(bit)
        return encodings

    def watermark_detector_fast(self, text, alpha=0.05):
        p = 0.5
        encodings = self.get_encodings_fast(text)
        n = len(encodings)
        ones = sum(encodings)
        if n == 0:
            z = 0
        else:
            z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
        threshold = norm.ppf(1 - alpha, loc=0, scale=1)
        p_value = norm.sf(z)
        is_watermark = z >= threshold
        return is_watermark, p_value, n, ones, z

    def get_encodings_precise(self, text):
        sents = self.sent_tokenize(text)
        sents = [s for s in sents if s.strip()]
        num_sents = len(sents)
        encodings = []
        for i in range(0, num_sents, 2):
            if i + 1 < num_sents:
                sent_pair = sents[i] + sents[i + 1]
            else:
                sent_pair = sents[i]

            tokens = self.tokenizer.tokenize(sent_pair)

            tokens = ['[CLS]'] + tokens + ['[SEP]']

            masked_tokens = tokens.copy()

            start_index = 1
            end_index = len(tokens) - 1

            for masked_token_index in range(start_index + 1, end_index - 1):
                init_candidates = self.candidates_gen(tokens, masked_token_index, sent_pair, 8, 0)
                if len(init_candidates) <= 1:
                    continue
                enhanced_candidates = self.filter_candidates(init_candidates, tokens, masked_token_index, sent_pair)
                if len(enhanced_candidates) > 1:
                    bit = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
                    encodings.append(bit)
        return encodings

    def watermark_detector_precise(self, text, alpha=0.05):
        p = 0.5
        encodings = self.get_encodings_precise(text)
        n = len(encodings)
        ones = sum(encodings)
        if n == 0:
            z = 0
        else:
            z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
        threshold = norm.ppf(1 - alpha, loc=0, scale=1)
        p_value = norm.sf(z)
        is_watermark = z >= threshold
        return is_watermark, p_value, n, ones, z
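One shared detail worth seeing in isolation: both model files reject substitution candidates that are near-duplicates of the original word, using a normalized Levenshtein distance with a 0.5 threshold. A quick illustration (the example words are made up):

    import Levenshtein

    def is_similar(x, y, threshold=0.5):
        # Same rule as in both watermark files above.
        return Levenshtein.distance(x, y) / max(len(x), len(y)) < threshold

    print(is_similar("running", "runner"))  # True: distance 3 over length 7 ~= 0.43
    print(is_similar("cat", "dog"))         # False: distance 3 over length 3 = 1.0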
options.py
ADDED
@@ -0,0 +1,14 @@
import argparse
# TODO: add help for the parameters

def get_parser_main_model():
    parser = argparse.ArgumentParser()
    # TODO: basic parameters training related

    # for embed
    parser.add_argument('--language', type=str, default='English', help='text language')
    parser.add_argument('--mode', type=str, choices=['embed', 'fast_detect', 'precise_detect'], default='embed', help='Mode options: embed (default), fast_detect, precise_detect')
    parser.add_argument('--tau_word', type=float, default=0.8, help='word-level similarity threshold')
    parser.add_argument('--lamda', type=float, default=0.83, help='word-level similarity weight')

    return parser
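A quick sanity check of the defaults (illustrative; not part of the uploaded files):

    from options import get_parser_main_model

    opts = get_parser_main_model().parse_args([])  # empty argv -> defaults
    print(opts.language, opts.mode, opts.tau_word, opts.lamda)
    # English embed 0.8 0.83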
requirements.txt
ADDED
@@ -0,0 +1,12 @@
gensim==4.3.0
gradio==3.30.0
jieba==0.42.1
nltk==3.8.1
paddle==1.0.2
paddlepaddle==2.4.2
python_Levenshtein==0.21.0
scipy==1.7.3
spacy==3.5.0
torch==1.11.0
transformers==4.26.1
wobert==0.0.1