dishitanagi commited on
Commit
b0f8f67
·
verified ·
1 Parent(s): 03cf094

Upload normalizers.py

Browse files
Files changed (1) hide show
  1. normalizers.py +202 -0
normalizers.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Text-based normalizers, used to mitigate simple attacks against watermarking.
2
+
3
+ This implementation is unlikely to be a complete list of all possible exploits within the unicode standard,
4
+ it represents our best effort at the time of writing.
5
+
6
+ These normalizers can be used as stand-alone normalizers. They could be made to conform to HF tokenizers standard, but that would
7
+ require messing with the limited rust interface of tokenizers.NormalizedString
8
+ """
9
+ from collections import defaultdict
10
+ from functools import cache
11
+
12
+ import re
13
+ import unicodedata
14
+ import homoglyphs as hg
15
+
16
+
17
+ def normalization_strategy_lookup(strategy_name: str) -> object:
18
+ if strategy_name == "unicode":
19
+ return UnicodeSanitizer()
20
+ elif strategy_name == "homoglyphs":
21
+ return HomoglyphCanonizer()
22
+ elif strategy_name == "truecase":
23
+ return TrueCaser()
24
+
25
+
26
+ class HomoglyphCanonizer:
27
+ """Attempts to detect homoglyph attacks and find a consistent canon.
28
+
29
+ This function does so on a per-ISO-category level. Language-level would also be possible (see commented code).
30
+ """
31
+
32
+ def __init__(self):
33
+ self.homoglyphs = None
34
+
35
+ def __call__(self, homoglyphed_str: str) -> str:
36
+ # find canon:
37
+ target_category, all_categories = self._categorize_text(homoglyphed_str)
38
+ homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
39
+ return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
40
+
41
+ def _categorize_text(self, text: str) -> dict:
42
+ iso_categories = defaultdict(int)
43
+ # self.iso_languages = defaultdict(int)
44
+
45
+ for char in text:
46
+ iso_categories[hg.Categories.detect(char)] += 1
47
+ # for lang in hg.Languages.detect(char):
48
+ # self.iso_languages[lang] += 1
49
+ target_category = max(iso_categories, key=iso_categories.get)
50
+ all_categories = tuple(iso_categories)
51
+ return target_category, all_categories
52
+
53
+ @cache
54
+ def _select_canon_category_and_load(self, target_category: str, all_categories: tuple[str]) -> dict:
55
+ homoglyph_table = hg.Homoglyphs(categories=(target_category, "COMMON")) # alphabet loaded here from file
56
+
57
+ source_alphabet = hg.Categories.get_alphabet(all_categories)
58
+ restricted_table = homoglyph_table.get_restricted_table(source_alphabet, homoglyph_table.alphabet) # table loaded here from file
59
+ return restricted_table
60
+
61
+ def _sanitize_text(self, target_category: str, homoglyph_table: dict, homoglyphed_str: str) -> str:
62
+ sanitized_text = ""
63
+ for char in homoglyphed_str:
64
+ # langs = hg.Languages.detect(char)
65
+ cat = hg.Categories.detect(char)
66
+ if target_category in cat or "COMMON" in cat or len(cat) == 0:
67
+ sanitized_text += char
68
+ else:
69
+ sanitized_text += list(homoglyph_table[char])[0]
70
+ return sanitized_text
71
+
72
+
73
+ class UnicodeSanitizer:
74
+ """Regex-based unicode sanitzer. Has different levels of granularity.
75
+
76
+ * ruleset="whitespaces" - attempts to remove only whitespace unicode characters
77
+ * ruleset="IDN.blacklist" - does its best to remove unusual unicode based on Network.IDN.blacklist characters
78
+ * ruleset="ascii" - brute-forces all text into ascii
79
+
80
+ This is unlikely to be a comprehensive list.
81
+
82
+ You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/
83
+ and https://www.unicode.org/faq/security.html
84
+ """
85
+
86
+ def __init__(self, ruleset="whitespaces"):
87
+ if ruleset == "whitespaces":
88
+
89
+ """Documentation:
90
+ \u00A0: Non-breaking space
91
+ \u1680: Ogham space mark
92
+ \u180E: Mongolian vowel separator
93
+ \u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
94
+ \u200C\u200D: Zero-width non-joiner and zero-width joiner
95
+ \u200E,\u200F: Left-to-right-mark, Right-to-left-mark
96
+ \u2060: Word joiner
97
+ \u2063: Invisible separator
98
+ \u202F: Narrow non-breaking space
99
+ \u205F: Medium mathematical space
100
+ \u3000: Ideographic space
101
+ \uFEFF: Zero-width non-breaking space
102
+ \uFFA0: Halfwidth hangul filler
103
+ \uFFF9\uFFFA\uFFFB: Interlinear annotation characters
104
+ \uFE00-\uFE0F: Variation selectors
105
+ \u202A-\u202F: Embedding characters
106
+ \u3164: Korean hangul filler.
107
+
108
+ Note that these characters are not always superfluous whitespace characters!
109
+ """
110
+
111
+ self.pattern = re.compile(
112
+ r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
113
+ r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
114
+ r"\u202E\u202F]"
115
+ )
116
+ elif ruleset == "IDN.blacklist":
117
+
118
+ """Documentation:
119
+ [\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
120
+ set that are included in the IDN blacklist.
121
+ \uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings.
122
+ These characters are not allowed in domain names.
123
+ \uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
124
+ set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
125
+ and the second part is in the range U+DC00 to U+DFFF.
126
+ \uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00
127
+ to U+DFFF, and is optional.
128
+ [\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
129
+ """
130
+
131
+ self.pattern = re.compile(
132
+ r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
133
+ r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
134
+ )
135
+ else:
136
+ """Documentation:
137
+ This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
138
+ """
139
+ self.pattern = re.compile(r"[^\x00-\x7F]+")
140
+
141
+ def __call__(self, text: str) -> str:
142
+ text = unicodedata.normalize("NFC", text) # canon forms
143
+ text = self.pattern.sub(" ", text) # pattern match
144
+ text = re.sub(" +", " ", text) # collapse whitespaces
145
+ text = "".join(c for c in text if unicodedata.category(c) != "Cc") # Remove any remaining non-printable characters
146
+ return text
147
+
148
+
149
+ class TrueCaser:
150
+ """True-casing, is a capitalization normalization that returns text to its original capitalization.
151
+
152
+ This defends against attacks that wRIte TeXt lIkE spOngBoB.
153
+
154
+ Here, a simple POS-tagger is used.
155
+ """
156
+
157
+ uppercase_pos = ["PROPN"] # Name POS tags that should be upper-cased
158
+
159
+ def __init__(self, backend="spacy"):
160
+ if backend == "spacy":
161
+ spacy_model = "en_core_web_sm"
162
+ try:
163
+ import spacy
164
+ self.nlp = spacy.load(spacy_model)
165
+ except:
166
+ import spacy.cli
167
+ spacy.cli.download(spacy_model)
168
+ import spacy
169
+ self.nlp = spacy.load(spacy_model)
170
+
171
+ self.normalize_fn = self._spacy_truecasing
172
+ else:
173
+ from nltk import pos_tag, word_tokenize # noqa
174
+ import nltk
175
+
176
+ nltk.download("punkt")
177
+ nltk.download("averaged_perceptron_tagger")
178
+ nltk.download("universal_tagset")
179
+ self.normalize_fn = self._nltk_truecasing
180
+
181
+ def __call__(self, random_capitalized_string: str) -> str:
182
+ truecased_str = self.normalize_fn(random_capitalized_string)
183
+ return truecased_str
184
+
185
+ def _spacy_truecasing(self, random_capitalized_string: str):
186
+ doc = self.nlp(random_capitalized_string.lower())
187
+ POS = self.uppercase_pos
188
+ truecased_str = "".join([w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws for w in doc])
189
+ return truecased_str
190
+
191
+ def _nltk_truecasing(self, random_capitalized_string: str):
192
+ from nltk import pos_tag, word_tokenize
193
+ import nltk
194
+
195
+ nltk.download("punkt")
196
+ nltk.download("averaged_perceptron_tagger")
197
+ nltk.download("universal_tagset")
198
+ POS = ["NNP", "NNPS"]
199
+
200
+ tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
201
+ truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
202
+ return truecased_str