veryfansome commited on
Commit
360e354
·
verified ·
1 Parent(s): 226443e

Initial upload

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ sp.vocab filter=lfs diff=lfs merge=lfs -text
dataset/o3-mini/20250218/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d78eaab378f462e45adbf026f13c3d5b9a289ddea75d399dbbf0c12b0c25c11e
3
+ size 40179808
dataset/o3-mini/20250218/dataset_info.json ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "citation": "",
3
+ "description": "",
4
+ "features": {
5
+ "text": {
6
+ "dtype": "string",
7
+ "_type": "Value"
8
+ },
9
+ "tokens": {
10
+ "feature": {
11
+ "dtype": "string",
12
+ "_type": "Value"
13
+ },
14
+ "_type": "Sequence"
15
+ },
16
+ "adj": {
17
+ "feature": {
18
+ "dtype": "string",
19
+ "_type": "Value"
20
+ },
21
+ "_type": "Sequence"
22
+ },
23
+ "adv": {
24
+ "feature": {
25
+ "dtype": "string",
26
+ "_type": "Value"
27
+ },
28
+ "_type": "Sequence"
29
+ },
30
+ "det": {
31
+ "feature": {
32
+ "dtype": "string",
33
+ "_type": "Value"
34
+ },
35
+ "_type": "Sequence"
36
+ },
37
+ "enc": {
38
+ "feature": {
39
+ "dtype": "string",
40
+ "_type": "Value"
41
+ },
42
+ "_type": "Sequence"
43
+ },
44
+ "func": {
45
+ "feature": {
46
+ "dtype": "string",
47
+ "_type": "Value"
48
+ },
49
+ "_type": "Sequence"
50
+ },
51
+ "misc": {
52
+ "feature": {
53
+ "dtype": "string",
54
+ "_type": "Value"
55
+ },
56
+ "_type": "Sequence"
57
+ },
58
+ "ner1": {
59
+ "feature": {
60
+ "dtype": "string",
61
+ "_type": "Value"
62
+ },
63
+ "_type": "Sequence"
64
+ },
65
+ "ner2": {
66
+ "feature": {
67
+ "dtype": "string",
68
+ "_type": "Value"
69
+ },
70
+ "_type": "Sequence"
71
+ },
72
+ "noun": {
73
+ "feature": {
74
+ "dtype": "string",
75
+ "_type": "Value"
76
+ },
77
+ "_type": "Sequence"
78
+ },
79
+ "pronoun": {
80
+ "feature": {
81
+ "dtype": "string",
82
+ "_type": "Value"
83
+ },
84
+ "_type": "Sequence"
85
+ },
86
+ "punct": {
87
+ "feature": {
88
+ "dtype": "string",
89
+ "_type": "Value"
90
+ },
91
+ "_type": "Sequence"
92
+ },
93
+ "verb": {
94
+ "feature": {
95
+ "dtype": "string",
96
+ "_type": "Value"
97
+ },
98
+ "_type": "Sequence"
99
+ },
100
+ "wh": {
101
+ "feature": {
102
+ "dtype": "string",
103
+ "_type": "Value"
104
+ },
105
+ "_type": "Sequence"
106
+ }
107
+ },
108
+ "homepage": "",
109
+ "license": ""
110
+ }
dataset/o3-mini/20250218/state.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "4a79c58b9023cf85",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": null
13
+ }
dataset_maker.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from asyncio import Task

from datasets import Dataset, load_dataset
from openai import AsyncOpenAI
from traceback import format_exc
from typing import Union
import asyncio
import itertools
import json
import logging.config
import random
import sentencepiece as spm

from utils import default_logging_config

client = AsyncOpenAI()
logger = logging.getLogger(__name__)

# SentencePiece model used to pre-tokenize every example before the LLM
# labels each piece.  The model file is expected next to this script.
sp = spm.SentencePieceProcessor()
sp.LoadFromFile("sp.model")  # fixed: was f"sp.model" — the f-prefix was a no-op

# Model used to generate the token-level labels.  Alternatives kept for
# quick switching during experiments.
#OPENAI_MODEL = "gpt-4o"
#OPENAI_MODEL = "o1"
OPENAI_MODEL = "o3-mini"
25
+
26
+
27
async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str]) -> list[str]:
    """Label every token in ``tokens`` with one structured-output chat call.

    :param prompt: short phrase describing what aspect to label
                   (spliced into the final user message).
    :param labels: mapping of label string -> human-readable description;
                   the model is told to answer only with these keys.
    :param tokens: token strings to classify, in order.
    :return: list of label strings parsed from the model's JSON reply
             (the schema asks for exactly ``len(tokens)`` entries, but the
             model is not hard-constrained to that length).
    :raises Exception: re-raises any API or JSON-parse failure after logging.
    """
    tok_len = len(tokens)
    # Render the tokens as a JSON-ish array literal for the user message.
    example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
    try:
        response = await client.chat.completions.create(
            model=OPENAI_MODEL,
            timeout=30,
            #temperature=0,
            #presence_penalty=0,
            reasoning_effort="low",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "Analyze the user provided sequence. Consider each string's semantic role "
                        f"in the given sequence, then return a list of {tok_len} label strings. "
                        f"Generate no more than {tok_len} labels. "
                        "When typos or out-of-order words are provided, infer the intended meaning."
                    ),
                },
                {
                    "role": "system",
                    "content": f"Labels: {labels}",
                },
                {
                    "role": "user",
                    "content": example,
                },
                {
                    "role": "user",
                    "content": (f"Replace each of the {tok_len} given strings with one of the following labels "
                                f"that best describes {prompt}: {sorted(labels.keys())}"),
                },
            ],
            # Structured output: force a {"labels": [...]} JSON object.
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "labels",
                    "strict": True,
                    "schema": {
                        "type": "object",
                        "properties": {
                            "labels": {
                                "type": "array",
                                "description": f"List of {tok_len} labels, one for each string from the user's sequence.",
                                "items": {
                                    "type": "string",
                                }
                            }
                        },
                        "additionalProperties": False,
                        "required": ["labels"]
                    }
                }
            },
        )
    except Exception:  # fixed: exception var was bound but unused
        logger.error(f"openai call failed: {format_exc()}")
        raise
    try:
        return json.loads(response.choices[0].message.content)["labels"]
    except Exception:  # fixed: exception var was bound but unused
        logger.error(f"response: {response.choices[0].message} {format_exc()}")
        raise
91
+
92
+
93
async def classify_with_retry(prompt, labels, tokens, retry=10):
    """Call :func:`classify_tokens`, retrying up to ``retry`` times.

    Uses linear backoff (attempt *i* sleeps ``i`` seconds; the first retry
    is immediate).  Fixed: the original fell off the end after exhausting
    all retries and silently returned ``None``, which would then be
    appended as a dataset column value downstream — now the last error is
    re-raised instead.
    """
    last_exc = None
    for i in range(retry):
        try:
            return await classify_tokens(prompt, labels, tokens)
        except Exception as e:
            last_exc = e
            logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
            await asyncio.sleep(i)  # linear backoff
    raise last_exc
100
+
101
+
102
async def generate_token_labels(case):
    """Tokenize ``case`` with SentencePiece, then label every token.

    The SentencePiece pieces are split on the "▁" word-boundary marker so
    the model labels word-like units.  All 13 classification heads are
    queried concurrently via ``asyncio.gather``.

    :param case: raw example text.
    :return: dict with "text", "tokens", and one list of label strings per
             head column (adj, adv, det, enc, func, misc, ner1, ner2,
             noun, pronoun, punct, verb, wh) — replacing 13 copy-pasted
             call stanzas with a single data-driven spec table.
    """
    tokens = list(itertools.chain.from_iterable([s.strip("▁").split("▁") for s in sp.EncodeAsPieces(case)]))
    # One (column, prompt, label -> description) spec per classification
    # head, in the column order of the resulting dataset.
    head_specs = [
        ("adj", "its semantic role", {
            "JJ": "adjective",
            "JJR": "comparative adjective",
            "JJS": "superlative adjective",
            "O": "out-of-scope",
        }),
        ("adv", "its semantic role", {
            "RB": "adverb",
            "RBR": "comparative adverb",
            "RBS": "superlative adverb",
            "O": "out-of-scope",
        }),
        ("det", "its semantic role", {
            "DT": "articles, demonstratives, and other determiners",
            "EX": "existential 'there'",
            "PDT": "predeterminer before a determiner to modify a noun phrase",
            "O": "out-of-scope",
        }),
        ("enc", "its sentence chunk classification", {
            "BRACKET": "in or contains bracket wrapped text",
            "QUOTE": "in or contains quote wrapped text",
            "TICK": "in or contains backtick wrapped text",
            "O": "out-of-scope",
        }),
        ("func", "its semantic role", {
            "CC": "coordinating conjunction",
            "IN": "preposition or subordinating conjunction",
            "RP": "particle",
            "TO": "to",
            "UH": "interjection",
            "O": "out-of-scope",
        }),
        ("misc", "its semantic role", {
            "$": "currency",
            "ADD": "address, URLs, usernames, or other non-lexical representations of places or entities",
            "CD": "cardinal numbers",
            "EMOJI": "emoji",
            "TIME": "date or time",
            "O": "out-of-scope",
        }),
        ("ner1", "its NER classification", {
            "B-GPE": "beginning of geopolitical entities",
            "I-GPE": "inside of geopolitical entities",
            "B-ORG": "beginning of organization",
            "I-ORG": "inside of organization",
            "B-PER": "beginning of person",
            "I-PER": "inside of person",
            "O": "out-of-scope",
        }),
        ("ner2", "its NER classification", {
            "B-EVENT": "beginning of event",
            "I-EVENT": "inside of event",
            "B-LOC": "beginning of location",
            "I-LOC": "inside of location",
            "O": "out-of-scope",
        }),
        ("noun", "its semantic role", {
            "NN": "common noun singular",
            "NNS": "common noun plural",
            "NNP": "proper noun singular",
            "NNPS": "proper noun plural",
            "O": "out-of-scope",
        }),
        ("pronoun", "its semantic role", {
            "POS": "possessive ending like the 's",
            "PRP$": "possessive pronoun",
            "PRP": "personal pronoun",
            "O": "out-of-scope",
        }),
        ("punct", "the punctuation classes it contains", {
            "COLON": "colon or semicolon",
            "COMMA": "comma",
            "EXCLAIM": "exclamation mark",
            "HYPH": "dash or hyphen",
            "LS": "list item marker",
            "PERIOD": "period",
            "QUESTION": "question mark",
            "SEP": "any section separator",
            "O": "out-of-scope",
        }),
        ("verb", "its semantic role", {
            "MD": "modal verb",
            "VB": "verb base form",
            "VBD": "verb past tense",
            "VBG": "present participle, gerund",
            "VBN": "past participle",
            "VBP": "non-3rd person singular present",
            "VBZ": "3rd person singular present",
            "O": "out-of-scope",
        }),
        ("wh", "its semantic role", {
            "WDT": "Wh-determiner",
            "WP$": "Wh-possessive pronoun",
            "WP": "Wh-pronoun",
            "WRB": "Wh-adverb",
            "O": "out-of-scope",
        }),
    ]
    # Fire all heads concurrently; gather preserves spec order.
    head_labels = await asyncio.gather(
        *(classify_with_retry(prompt, labels, tokens) for _, prompt, labels in head_specs)
    )
    example = {"text": case, "tokens": tokens}
    for (col, _, _), labels in zip(head_specs, head_labels):
        example[col] = labels
    return example
273
+
274
+
275
async def main(cases):
    """Label every case concurrently and save the result as a Dataset.

    Schedules up to ``max_concurrent_tasks`` :func:`generate_token_labels`
    tasks at once, drains finished tasks into ``ds_dict`` column lists, and
    checkpoints the partial dataset to disk every 10 minutes.

    Fixes vs. the original:
    - ``asyncio.wait_for`` on a bare Task CANCELS it on timeout, so the
      next drain pass would crash on ``CancelledError``; the task is now
      wrapped in ``asyncio.shield`` so a slow task survives the timeout.
    - The timeout warning claimed "10 seconds" while the timeout is 180.
    - If the drain loop died on an exception, the final wait loop would
      spin forever; it now notices and surfaces the error.
    - The checkpoint task could keep ``main`` alive for up to 600 s of
      leftover sleep; it is now cancelled once draining completes.
    """
    ds_dict = {
        "text": [],
        "tokens": [],
        "adj": [],
        "adv": [],
        "det": [],
        "enc": [],
        "func": [],
        "misc": [],
        "ner1": [],
        "ner2": [],
        "noun": [],
        "pronoun": [],
        "punct": [],
        "verb": [],
        "wh": [],
    }
    drain_completed = False
    max_concurrent_tasks = 15
    tasks: list[Union[Task, None]] = []

    def pending() -> int:
        # Number of scheduled-but-undrained labeling tasks.
        return len([t for t in tasks if t is not None])

    async def checkpoint_task():
        # Persist a snapshot every 10 minutes so a crash loses at most that much work.
        while not drain_completed:
            _ds = Dataset.from_dict(ds_dict)
            _ds.save_to_disk("./custom_checkpoint_data")
            logger.info(f"\n{_ds}")
            await asyncio.sleep(600)
    future_checkpoint_task_completion = asyncio.create_task(checkpoint_task())

    async def drain_tasks():
        # Poll finished tasks and append their columns to ds_dict.
        while not drain_completed:
            for idx, task in enumerate(tasks):
                if task is None:
                    continue
                try:
                    logger.info(f"attempting Example {idx}")
                    # shield() keeps a slow task alive across wait_for timeouts.
                    example = await asyncio.wait_for(asyncio.shield(task), timeout=180)
                    for col, labels in example.items():
                        logger.info(f" {col}: {labels}")
                        ds_dict[col].append(example[col])
                    tasks[idx] = None
                except asyncio.exceptions.TimeoutError:
                    logger.warning(f"attempt to wait_for Example {idx} timed out after 180 seconds.")
                except Exception:
                    logger.error(f"attempt to generate Example {idx} failed.\n{format_exc()}")
                    tasks[idx] = None
                    raise
            await asyncio.sleep(1)
    future_drain_completion = asyncio.create_task(drain_tasks())

    for case in cases:
        # Simple back-pressure: never more than max_concurrent_tasks in flight.
        while pending() >= max_concurrent_tasks:
            await asyncio.sleep(1)
        logger.info(f"scheduling case {case}")
        tasks.append(asyncio.create_task(generate_token_labels(case)))

    # Block until done
    while pending() > 0:
        if future_drain_completion.done():
            # The drain loop died (it re-raises on failure); surface its
            # exception here instead of spinning forever.
            await future_drain_completion
            break
        logger.info(f"waiting on {pending()} tasks")
        await asyncio.sleep(1)

    drain_completed = True
    await future_drain_completion
    # The checkpoint loop may be mid-sleep for up to 600s — cancel it;
    # the final save below supersedes any checkpoint.
    future_checkpoint_task_completion.cancel()
    try:
        await future_checkpoint_task_completion
    except asyncio.CancelledError:
        pass

    ds = Dataset.from_dict(ds_dict)
    ds.save_to_disk("./custom_final_data")
    logger.info(f"\n{ds}")
345
+
346
+
347
if __name__ == "__main__":
    logging.config.dictConfig(default_logging_config)

    all_examples = []

    # Universal Dependencies corpora to pull, as (config, splits) pairs —
    # replaces six near-identical load/extend stanzas.  en_pronouns and
    # en_pud ship only a test split.
    ud_sources = [
        ("en_ewt", ("train", "validation", "test")),
        ("en_gum", ("train", "validation", "test")),
        ("en_lines", ("train", "validation", "test")),
        ("en_partut", ("train", "validation", "test")),
        ("en_pronouns", ("test",)),
        ("en_pud", ("test",)),
    ]
    for ud_config, splits in ud_sources:
        ud_ds = load_dataset("universal_dependencies", ud_config)
        for split in splits:
            all_examples += ud_ds[split]["text"]

    logger.info(f"{len(all_examples)} UD examples")
    random.shuffle(all_examples)
    logger.info(f"{all_examples[:10]}")

    # Hand-written edge cases (URLs, emoji, profanity, markdown, typos, …).
    # Appended after the shuffle, as in the original, so they run last.
    all_examples += [
        "Hello world!",
        "127.0.0.1 is the localhost address.",
        "1/2 is equivalent to 0.5 or 50%",
        "John was running so fast, you can just tell he's a runner.",
        "He excels at math and competed in the Math Olympiad",
        "They're only $5!",
        "Where is your sense of adventure?",
        "I have only 3 cents.",
        "Watson was on his way to 221B Baker Street when the robbery occurred.",
        "That's uncopyrightable.",
        "She's full of incomprehensibilities.",
        "He's a total sesquipedalian.",
        "you piece of SHIT!!",
        "uh........... what..",
        "Steph Curry is GOAT!!",
        "[click here!](http://www.google.com)",
        "Dude! The stock's value grew like 10x in a year!",
        "Yea, I was at the DMV - God what a shit show!",
        "Send an email to help@example.com",
        "@goober, take your question to #corp-help-desk",
        "Example 1 : Joe Shmoe has a big toe.",
        "1. Steal under-pants. 2. ... 3. Profit!",
        "I expect `len(word_list) == 3`",
        "Home | Shop | Contact Us",
        "This is me on cake <(^.^)>",
        "and then he fell right on his face 😂",
        "Putin is from Russia.",
        "Zelenskyy is Ukrainian",
        "In 2013, the Pentagon and other agencies officially acknowledged the existence of Area-51.",
        "The Freedom of Information Act gives us the right to request access to records from any federal agency.",
        "His motives here are totally sus",
        "Yea, he finished by doing the Dab",
        "Be back i'll",
        "hes got a bon to pick",
        "so then he says then he says, \"you'll regret this\" lol",
    ]
    asyncio.run(main(all_examples))
420
+
multi_task_classifier.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from datasets import DatasetDict, load_from_disk
from sklearn.metrics import classification_report, precision_recall_fscore_support
from transformers import (
    DebertaV2Config,
    DebertaV2Model,
    DebertaV2PreTrainedModel,
    DebertaV2TokenizerFast,
    Trainer,
    TrainingArguments,
)
import argparse
import logging.config
import numpy as np
import torch
import torch.nn as nn

from ud_training_data_maker import get_uniq_training_labels, show_examples
from utils import default_logging_config

logger = logging.getLogger(__name__)

# ------------------------------------------------------------------------------
# Command-line interface for the multi-task token-classifier trainer.
# Flags cover data inspection (--data-only, --show), training hyperparameters
# (-A/-E/-L/-T/-V), and model/data paths (--from-base, --data-path, --save-path).
# ------------------------------------------------------------------------------
arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
                        action="store", type=int, default=8)
arg_parser.add_argument("--data-only", help='Show training data info and exit.',
                        action="store_true", default=False)
arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
                        action="store", default="./training_data")
arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
                        action="store", type=int, default=3)
arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
                        action="store", type=int, default=2)
arg_parser.add_argument("--from-base", help="Load a base model.",
                        action="store", default=None,
                        choices=[
                            "microsoft/deberta-v3-base",  # Requires --deberta-v3
                            "microsoft/deberta-v3-large",  # Requires --deberta-v3
                            # More?
                        ])
arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
                        action="store", type=float, default=5e-5)
arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
                        action="store_true", default=False)
arg_parser.add_argument("--save-path", help="Save final model to specified path.",
                        action="store", default="./final")
arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                        action="store", default=None)
arg_parser.add_argument("--train", help='Train model using loaded examples.',
                        action="store_true", default=False)
arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.",
                        action="store", type=int, default=2)
args = arg_parser.parse_args()
# Configure logging before the first logger.info call.
logging.config.dictConfig(default_logging_config)
logger.info(f"Args {args}")
55
+
56
+
57
+
58
# ------------------------------------------------------------------------------
# Load dataset and show examples for manual inspection
# ------------------------------------------------------------------------------

loaded_dataset = load_from_disk(args.data_path)
show_examples(loaded_dataset, args.show)

# ------------------------------------------------------------------------------
# Convert label analysis data into label sets for each head
# ------------------------------------------------------------------------------

# Label vocabulary per head/column.
# NOTE(review): assumes get_uniq_training_labels returns
# {column_name: iterable-of-label-strings} — confirm in ud_training_data_maker.
ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}

# Forward (label -> id) and inverse (id -> label) maps for every head.
LABEL2ID = {
    feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
    for feat_name in ALL_LABELS
}
ID2LABEL = {
    feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
    for feat_name in LABEL2ID
}

# Each head's number of labels:
NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}

# --data-only: stop after the dataset/label inspection above.
if args.data_only:
    exit()
85
+
86
+ # ------------------------------------------------------------------------------
87
+ # Create a custom config that can store our multi-label info
88
+ # ------------------------------------------------------------------------------
89
+
90
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head label metadata.

    ``label_maps`` holds label<->id mappings per head, ``num_labels_dict``
    holds each head's class count; both are serialized via ``to_dict`` so
    they round-trip through ``save_pretrained``/``from_pretrained``.
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Fall back to empty dicts so attribute access never fails.
        self.label_maps = {} if not label_maps else label_maps
        self.num_labels_dict = {} if not num_labels_dict else num_labels_dict

    def to_dict(self):
        """Serialize the base config plus the multi-head extras."""
        serialized = super().to_dict()
        serialized["label_maps"] = self.label_maps
        serialized["num_labels_dict"] = self.num_labels_dict
        return serialized
101
+
102
+ # ------------------------------------------------------------------------------
103
+ # Define a multi-head model
104
+ # ------------------------------------------------------------------------------
105
+
106
class MultiHeadModel(DebertaV2PreTrainedModel):
    """Shared DeBERTa-v2 encoder with one linear token-classification head per label set."""

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)
        # One Linear head per label set, keyed by the head/column name.
        self.classifiers = nn.ModuleDict()

        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)

        # Initialize newly added weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """
        Run the shared encoder, then every classification head.

        labels_dict: a dict of { label_name: (batch_size, seq_len) } with label ids.
        If provided, we compute and return the sum of CE losses.

        Returns (total_loss, logits_dict) when labels_dict is given,
        otherwise just logits_dict.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )

        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Apply each head to the same shared encoder output.
        logits_dict = {}
        for label_name, classifier in self.classifiers.items():
            logits_dict[label_name] = classifier(sequence_output)

        total_loss = None
        # NOTE(review): loss_dict is populated below but never returned or
        # logged — dead code unless inspected in a debugger.
        loss_dict = {}
        if labels_dict is not None:
            # We'll sum the losses from each head
            loss_fct = nn.CrossEntropyLoss()
            total_loss = 0.0

            for label_name, logits in logits_dict.items():
                if label_name not in labels_dict:
                    continue
                label_ids = labels_dict[label_name]

                # A typical approach for token classification:
                # We ignore positions where label_ids == -100
                active_loss = label_ids != -100  # shape (bs, seq_len)

                # flatten everything
                active_logits = logits.view(-1, logits.shape[-1])[active_loss.view(-1)]
                active_labels = label_ids.view(-1)[active_loss.view(-1)]

                # NOTE(review): if a batch has no active positions for a head,
                # CrossEntropyLoss over an empty selection yields NaN — confirm
                # upstream batching guarantees at least one labeled token.
                loss = loss_fct(active_logits, active_labels)
                loss_dict[label_name] = loss.item()
                total_loss += loss

        if labels_dict is not None:
            # return (loss, predictions)
            return total_loss, logits_dict
        else:
            # just return predictions
            return logits_dict
175
+
176
+ # ------------------------------------------------------------------------------
177
+ # Tokenize with max_length=512, stride=128, and subword alignment
178
+ # ------------------------------------------------------------------------------
179
+
180
def tokenize_and_align_labels(examples):
    """
    Tokenize a batch of pre-split examples and align per-head labels to subwords.

    For each example, the tokenizer may produce multiple overlapping
    chunks if the tokens exceed 512 subwords. Each chunk will be
    length=512, with a stride=128 for the next chunk.
    We'll align labels so that subwords beyond the first in a token get -100.

    NOTE(review): tokens inside the 128-subword overlap are labeled in both
    chunks, so they are counted twice in the loss — confirm this is intended.
    """
    # We rely on is_split_into_words=True because examples["tokens"] is a list of token strings.
    tokenized_batch = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        max_length=512,
        stride=128,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=False,  # not mandatory for basic alignment
        padding="max_length"
    )

    # The tokenizer returns "overflow_to_sample_mapping", telling us
    # which original example index each chunk corresponds to.
    # If the tokenizer didn't need to create overflows, the key might be missing
    if "overflow_to_sample_mapping" not in tokenized_batch:
        # No overflow => each input corresponds 1:1 with the original example
        sample_map = [i for i in range(len(tokenized_batch["input_ids"]))]
    else:
        sample_map = tokenized_batch["overflow_to_sample_mapping"]

    # We'll build lists for final outputs.
    # For each chunk i, we produce:
    #   "input_ids"[i], "attention_mask"[i], plus per-feature label IDs.
    final_input_ids = []
    final_attention_mask = []
    final_labels_columns = {feat: [] for feat in ALL_LABELS}  # store one label-sequence per chunk

    for i in range(len(tokenized_batch["input_ids"])):
        # chunk i
        chunk_input_ids = tokenized_batch["input_ids"][i]
        chunk_attn_mask = tokenized_batch["attention_mask"][i]

        original_index = sample_map[i]  # which example in the original batch
        word_ids = tokenized_batch.word_ids(batch_index=i)

        # We'll build label arrays for each feature
        chunk_labels_dict = {}

        for feat_name in ALL_LABELS:
            # The UD token-level labels for the *original* example
            token_labels = examples[feat_name][original_index]  # e.g. length T
            chunk_label_ids = []

            previous_word_id = None
            for w_id in word_ids:
                if w_id is None:
                    # special token (CLS, SEP, padding)
                    chunk_label_ids.append(-100)
                else:
                    # If it's the same word_id as before, it's a subword => label = -100
                    if w_id == previous_word_id:
                        chunk_label_ids.append(-100)
                    else:
                        # New token => use the actual label
                        label_str = token_labels[w_id]
                        label_id = LABEL2ID[feat_name][label_str]
                        chunk_label_ids.append(label_id)
                previous_word_id = w_id

            chunk_labels_dict[feat_name] = chunk_label_ids

        final_input_ids.append(chunk_input_ids)
        final_attention_mask.append(chunk_attn_mask)
        for feat_name in ALL_LABELS:
            final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])

    # Return the new "flattened" set of chunks
    # So the "map" call will expand each example → multiple chunk examples.
    result = {
        "input_ids": final_input_ids,
        "attention_mask": final_attention_mask,
    }
    # We'll store each feature's label IDs in separate columns (e.g. labels_xpos, labels_deprel, etc.)
    for feat_name in ALL_LABELS:
        result[f"labels_{feat_name}"] = final_labels_columns[feat_name]

    return result
265
+
266
+ # ------------------------------------------------------------------------------
267
+ # Trainer Setup
268
+ # ------------------------------------------------------------------------------
269
+
270
class MultiHeadTrainer(Trainer):
    """Trainer that extracts per-head ``labels_<feat>`` columns from the batch
    and feeds them to MultiHeadModel as a single ``labels_dict`` kwarg."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Forward the batch and return the summed multi-head CE loss."""
        # 1) Gather all your per-feature labels from inputs
        _labels_dict = {}
        for feat_name in ALL_LABELS:
            key = f"labels_{feat_name}"
            if key in inputs:
                _labels_dict[feat_name] = inputs[key]

        # 2) Remove them so they don't get passed incorrectly to the model
        for key in list(inputs.keys()):
            if key.startswith("labels_"):
                del inputs[key]

        # 3) Call model(...) with _labels_dict
        outputs = model(**inputs, labels_dict=_labels_dict)
        # 'outputs' is (loss, logits_dict) in training/eval mode
        loss, logits_dict = outputs

        # Optional: if your special param is used upstream for some logic,
        # you can handle it here or pass it along. For example:
        if num_items_in_batch is not None:
            # ... do something if needed ...
            pass

        if return_outputs:
            # Return (loss, logits_dict) so Trainer sees logits_dict as predictions
            return (loss, logits_dict)
        else:
            return loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Eval-time forward pass; mirrors compute_loss but without gradients."""
        # 1) gather the "labels_xxx" columns
        _labels_dict = {}
        for feat_name in ALL_LABELS:
            key = f"labels_{feat_name}"
            if key in inputs:
                _labels_dict[feat_name] = inputs[key]
                del inputs[key]

        # 2) forward pass without those keys
        with torch.no_grad():
            outputs = model(**inputs, labels_dict=_labels_dict)

        loss, logits_dict = outputs  # you are returning (loss, dict-of-arrays)

        if prediction_loss_only:
            return (loss, None, None)

        # The trainer expects a triple: (loss, predictions, labels)
        # - 'predictions' can be the dictionary
        # - 'labels' can be the dictionary of label IDs
        return (loss, logits_dict, _labels_dict)
324
+
325
+
326
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """
    For each head, generate a classification report (precision, recall, f1, etc. per class).
    Return them as a dict: {head_name: "string report"}.
    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
    :param id2label_dict: dict of {head_name: {id: label_str}}
    :return: A dict of classification-report strings, one per head.
    """
    reports = {}

    for head_name, logits in logits_dict.items():
        if head_name not in labels_dict:
            continue

        predicted_ids = np.argmax(logits, axis=-1)

        # Flatten (prediction, label) pairs, dropping -100 (masked) positions.
        kept_pairs = [
            (pred, gold)
            for pred_seq, gold_seq in zip(predicted_ids, labels_dict[head_name])
            for pred, gold in zip(pred_seq, gold_seq)
            if gold != -100
        ]

        if not kept_pairs:
            reports[head_name] = "No valid predictions."
            continue

        # Convert numeric IDs to string labels before reporting.
        id2label = id2label_dict[head_name]
        pred_names = [id2label[pred] for pred, _ in kept_pairs]
        gold_names = [id2label[gold] for _, gold in kept_pairs]

        reports[head_name] = classification_report(
            gold_names,
            pred_names,
            zero_division=0
        )

    return reports
366
+
367
+
368
def multi_head_compute_metrics(logits_dict, labels_dict):
    """
    Compute aggregate token-level metrics for every classification head
    (e.g. xpos, deprel, Case, ...):

    - Accuracy
    - Precision (macro/micro)
    - Recall (macro/micro)
    - F1 (macro/micro)

    :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)}
    :return: flat dict keyed "<head>_<metric>", e.g. "xpos_accuracy", "xpos_f1_macro".
    """
    results = {}

    for head, head_logits in logits_dict.items():
        # Guard against heads without corresponding gold labels.
        if head not in labels_dict:
            continue

        # (batch_size, seq_len, num_classes) -> (batch_size, seq_len)
        pred_ids = np.argmax(head_logits, axis=-1)

        # Flatten, dropping ignored positions (label == -100).
        kept_pairs = [
            (pred, gold)
            for pred_row, gold_row in zip(pred_ids, labels_dict[head])
            for pred, gold in zip(pred_row, gold_row)
            if gold != -100
        ]

        if not kept_pairs:
            # No valid data for this head — skip it entirely.
            continue

        preds = np.array([pred for pred, _ in kept_pairs])
        golds = np.array([gold for _, gold in kept_pairs])

        # Overall token-level accuracy.
        results[f"{head}_accuracy"] = (preds == golds).mean()

        # Macro treats each class equally; micro aggregates across all tokens.
        for average in ("macro", "micro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                golds, preds, average=average, zero_division=0
            )
            results[f"{head}_precision_{average}"] = precision
            results[f"{head}_recall_{average}"] = recall
            results[f"{head}_f1_{average}"] = f1

    return results
428
+
429
# ------------------------------------------------------------------------------
# Instantiate model and tokenizer
# ------------------------------------------------------------------------------

if args.from_base:
    # Fresh fine-tune: start from the given base checkpoint and attach a config
    # carrying the per-head label inventories built earlier in this script.
    model_name_or_path = args.from_base
    multi_head_model = MultiHeadModel.from_pretrained(
        model_name_or_path,
        config=MultiHeadModelConfig.from_pretrained(
            model_name_or_path,
            num_labels_dict=NUM_LABELS_DICT,
            label_maps=ALL_LABELS
        )
    )
else:
    model_name_or_path = args.save_path
    # For evaluation, always load the saved checkpoint without overriding the config.
    multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)
    # EXTREMELY IMPORTANT!
    # Override the label mapping based on the stored config to ensure consistency with training time ordering.
    ALL_LABELS = multi_head_model.config.label_maps
    LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS}
    ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID}
logger.info(f"using {model_name_or_path}")

# Check if GPU is usable
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():  # For Apple Silicon MPS
    device = torch.device("mps")
else:
    device = torch.device("cpu")
logger.info(f"using {device}")
multi_head_model.to(device)

tokenizer = DebertaV2TokenizerFast.from_pretrained(
    model_name_or_path,
    add_prefix_space=True,
)

# ------------------------------------------------------------------------------
# Shuffle, (optionally) sample, and tokenize final merged dataset
# ------------------------------------------------------------------------------

if args.mini:
    # Tiny deterministic subsets for fast smoke tests.
    loaded_dataset = DatasetDict({
        "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)),
        "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)),
        "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)),
    })

# remove_columns => remove old "text", "tokens", etc. so we keep only model inputs
tokenized_dataset = loaded_dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=loaded_dataset["train"].column_names,
)

# ------------------------------------------------------------------------------
# Train the model!
# ------------------------------------------------------------------------------

"""
Current bests:

deberta-v3-base:
num_train_epochs=3,
learning_rate=5e-5,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
"""

training_args = TrainingArguments(
    # Evaluate less frequently or keep the same
    eval_strategy="epoch",
    num_train_epochs=args.train_epochs,
    learning_rate=args.learning_rate,

    output_dir="training_output",
    overwrite_output_dir=True,
    remove_unused_columns=False,  # important to keep the labels_xxx columns

    logging_dir="training_logs",
    logging_steps=100,

    # Effective batch size = train_batch_size x gradient_accumulation_steps
    per_device_train_batch_size=args.train_batch_size,
    gradient_accumulation_steps=args.accumulation_steps,

    per_device_eval_batch_size=args.eval_batch_size,
)

trainer = MultiHeadTrainer(
    model=multi_head_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
)

if args.train:
    trainer.train()
    trainer.evaluate()
    trainer.save_model(args.save_path)
    tokenizer.save_pretrained(args.save_path)

# ------------------------------------------------------------------------------
# Evaluate the model!
# ------------------------------------------------------------------------------

# NOTE: always evaluates on the test split, whether or not --train was given;
# in the no-train path the checkpoint loaded above from args.save_path is used.
pred_output = trainer.predict(tokenized_dataset["test"])
pred_logits_dict = pred_output.predictions
pred_labels_dict = pred_output.label_ids
id2label_dict = ID2LABEL  # from earlier definitions

# 1) Calculate metrics
metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict)
for k,v in metrics.items():
    print(f"{k}: {v:.4f}")

# 2) Print classification reports
reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict)
for head_name, rstr in reports.items():
    print(f"----- {head_name} classification report -----")
    print(rstr)
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate>=0.26.0 # Required for using Trainer with PyTorch
2
+ conllu
3
+ dataset  # NOTE(review): possibly unintended — the code imports `datasets` (listed below); confirm this PyPI package is really needed
4
+ datasets
5
+ numpy
6
+ openai
7
+ pytest
8
+ pytest-cov
9
+ pytest-xdist
10
+ scikit-learn
11
+ sentencepiece
12
+ tensorflow
13
+ tf-keras # Required until transformers supports Keras 3
14
+ tiktoken
15
+ torch
16
+ transformers
sp.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2676ad813627497b95ce13c8ebe6b3313391c6df4b75909b5d6f68dcdde716b
3
+ size 18104223
sp.vocab ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3a11823032d025ecd19a1e6bfef167b9a9ef6489d81eff726d4b399a20163ce
3
+ size 18715604
sp_model_maker.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Build a Wikipedia text corpus (--wikipedia) and/or train a SentencePiece
# word-level tokenization model on it (--train).
from datasets import load_dataset
import argparse
import logging.config
import os
import random
import re
import sentencepiece as spm

from utils import default_logging_config

logger = logging.getLogger(__name__)

arg_parser = argparse.ArgumentParser(description="Train a sentencepiece tokenization model.")
arg_parser.add_argument("--train", action="store_true", default=False,
                        help="Train a sentencepiece tokenization model.")
arg_parser.add_argument("--wikipedia", action="store_true", default=False,
                        help="Use wikipedia dataset.")
args = arg_parser.parse_args()
logging.config.dictConfig(default_logging_config)

# Corpus / model sizing knobs.
input_sentence_size = 9_000_000  # sentences sampled by the SentencePiece trainer
max_line_char_len = 4192         # chunks longer than this are diverted to uber_chunk_file
vocab_size = 900_000

corpus_dir = "sp_data"
corpus_file_prefix = f"{corpus_dir}/sp_corpus"
model_file_prefix = "sp"
uber_chunk_file = f"{corpus_dir}/wikipedia_uber_chunks.txt"
white_space_pattern = re.compile(r"\s+")

if args.wikipedia:
    # Stream English Wikipedia into ~1 GB text partitions under corpus_dir.
    wikipedia_dataset_name = "20231101.en"
    wikipedia_dataset = load_dataset("wikimedia/wikipedia", wikipedia_dataset_name)
    total_page_cnt = len(wikipedia_dataset["train"])
    logger.info(f"loaded {wikipedia_dataset_name} containing {total_page_cnt} pages")

    max_processed_pages = total_page_cnt  # Change to single digits for spot checking / debugging
    pages_processed_cnt = 0

    corpus_file_part_idx = 0
    current_corpus_file_char_len = 0
    is_completed = False
    iter_idx = 0
    while not is_completed:  # Do till completed
        # One partition file per outer iteration; the inner loop breaks WITHOUT
        # advancing iter_idx when the partition would exceed ~1 GB, so the page
        # that overflowed is re-attempted in the next partition.
        with open(f"{corpus_file_prefix}_{corpus_file_part_idx}.txt", "a", encoding="utf-8") as f:
            while iter_idx < (total_page_cnt - 1):
                # NOTE(review): the `< total_page_cnt - 1` bound means the final
                # page (index total_page_cnt - 1) is never processed — looks
                # like an off-by-one; confirm.
                page = wikipedia_dataset["train"][iter_idx]
                page_char_len = len(page["text"])  # Character len because bytes requires encoding
                if page_char_len + current_corpus_file_char_len > 1_000_000_000:
                    corpus_file_part_idx += 1  # New partition
                    current_corpus_file_char_len = 0  # Reset tally
                    break

                page_chunk_cnt = 0
                for page_chunk in page["text"].split("\n\n"):
                    page_chunk_len = len(page_chunk)
                    # Skip empty chunks and chunks that start indented.
                    if not page_chunk or page_chunk[0] == " ":
                        continue
                    elif page_chunk_len > max_line_char_len:
                        # Oversized chunks are diverted to a separate overflow file.
                        with open(uber_chunk_file, "a", encoding="utf-8") as uber_chunk_f:
                            uber_chunk_f.write(page_chunk + "\n\n")
                        continue

                    page_chunk_lines = page_chunk.split("\n")
                    for chunk_line in page_chunk_lines:
                        if not chunk_line or chunk_line[0] == " ":
                            continue
                        elif len(white_space_pattern.split(chunk_line)) > 10:  # Require at least 10 naive tokens
                            f.write(chunk_line + "\n")
                            current_corpus_file_char_len += len(chunk_line)
                            page_chunk_cnt += 1

                iter_idx += 1
                pages_processed_cnt += 1

                if (pages_processed_cnt % 100) == 0:
                    logger.info(f"processed {pages_processed_cnt}/{total_page_cnt} pages")
                if pages_processed_cnt >= max_processed_pages:
                    is_completed = True
                    break
        if not is_completed and iter_idx == (total_page_cnt - 1):
            is_completed = True

if args.train:
    # Train on a random sample of the corpus partitions produced above.
    corpus_files = [f"{corpus_dir}/{f}" for f in os.listdir(corpus_dir) if f.startswith("sp_corpus")]
    logger.info(f"corpus_files: {corpus_files}")

    spm_training_args = [
        f"--model_prefix={model_file_prefix}",
        "--model_type=word",
        "--shuffle_input_sentence=true",
        #"--split_digits=true",
        "--split_digits=false",
        # NOTE(review): random.sample raises ValueError when fewer than 15
        # partitions exist — this assumes a large corpus has been built; confirm.
        f"--input={','.join(random.sample(corpus_files, 15))}",
        f"--input_sentence_size={input_sentence_size}",
        f"--max_sentence_length={max_line_char_len}",
        f"--vocab_size={vocab_size}",
    ]
    spm.SentencePieceTrainer.Train(" ".join(spm_training_args))

    # Now you can load the model and test it:
    sp = spm.SentencePieceProcessor()
    sp.LoadFromFile(f"{model_file_prefix}.model")

    print(sp.EncodeAsPieces("Hello world!"))
    print(sp.EncodeAsPieces("127.0.0.1 is the localhost address."))
    print(sp.EncodeAsPieces("1/2 is equivalent to 0.5 or 50%"))
    print(sp.EncodeAsPieces("John was running so fast, you can just tell he's a runner."))
    print(sp.EncodeAsPieces("He excels at math and competed in the Math Olympiad"))
    print(sp.EncodeAsPieces("Watson was on his way to 221B Baker Street when the robbery occurred."))
    print(sp.EncodeAsPieces("That's Uncopyrightable."))
    print(sp.EncodeAsPieces("She's full of incomprehensibilities."))
    print(sp.EncodeAsPieces("He's a total sesquipedalian."))
ud_dataset_maker.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from collections import Counter
from typing import Optional
import argparse
import ast
import logging.config
import random

from datasets import load_dataset, DatasetDict, concatenate_datasets

from utils.typos import generate_typo
from utils import default_logging_config
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ allowed_xpos = [
14
+ "''",
15
+ '$',
16
+ ',',
17
+ '-LRB-', # (
18
+ '-RRB-', # )
19
+ '.',
20
+ ':',
21
+ 'ADD', # URLs, email addresses, or other “address” forms (like Twitter handles) that do not fit elsewhere.
22
+ 'CC',
23
+ 'CD',
24
+ 'DT',
25
+ 'EX',
26
+ 'FW',
27
+ 'HYPH',
28
+ 'IN',
29
+ 'JJ',
30
+ 'JJR',
31
+ 'JJS',
32
+ 'LS', # List item marker
33
+ 'MD',
34
+ 'NFP', # “Non-Final Punctuation” for punctuation that doesn’t fit typical labels, in unexpected or stray positions
35
+ 'NN',
36
+ 'NNP',
37
+ 'NNPS',
38
+ 'NNS',
39
+ 'PDT',
40
+ 'POS',
41
+ 'PRP$',
42
+ 'PRP',
43
+ 'RB',
44
+ 'RBR',
45
+ 'RBS',
46
+ 'RP',
47
+ 'SYM',
48
+ 'TO',
49
+ 'UH',
50
+ 'VB',
51
+ 'VBD',
52
+ 'VBG',
53
+ 'VBN',
54
+ 'VBP',
55
+ 'VBZ',
56
+ 'WDT',
57
+ 'WP$',
58
+ 'WP',
59
+ 'WRB',
60
+ '``',
61
+ ]
62
+
63
+ allowed_deprel = [
64
+ 'acl',
65
+ 'acl:relcl',
66
+ 'advcl',
67
+ 'advmod',
68
+ 'amod',
69
+ 'appos',
70
+ 'aux',
71
+ 'aux:pass',
72
+ 'case',
73
+ 'cc',
74
+ 'cc:preconj',
75
+ 'ccomp',
76
+ 'compound',
77
+ 'compound:prt',
78
+ 'conj',
79
+ 'cop',
80
+ 'csubj',
81
+ 'csubj:pass',
82
+ 'dep',
83
+ 'det',
84
+ 'det:predet',
85
+ 'discourse',
86
+ 'dislocated',
87
+ 'expl',
88
+ 'fixed',
89
+ 'flat',
90
+ 'flat:foreign',
91
+ 'goeswith',
92
+ 'iobj',
93
+ 'list',
94
+ 'mark',
95
+ 'nmod',
96
+ 'nmod:npmod',
97
+ 'nmod:poss',
98
+ 'nmod:tmod',
99
+ 'nsubj',
100
+ 'nsubj:pass',
101
+ 'nummod',
102
+ 'obj',
103
+ 'obl',
104
+ 'obl:npmod',
105
+ 'obl:tmod',
106
+ 'orphan',
107
+ 'parataxis',
108
+ 'punct',
109
+ 'reparandum',
110
+ 'root',
111
+ 'vocative',
112
+ 'xcomp',
113
+ ]
114
+
115
+ target_feats = [
116
+ "Case", "Definite", "Degree", "Gender", "Mood", "NumType", "Number",
117
+ "Person", "Poss", "PronType", "Reflex", "Tense", "Typo", "VerbForm"
118
+ ]
119
+
120
+ non_target_feats = { # Found programmatically and added after analysis
121
+ "Abbr": [],
122
+ "Foreign": [],
123
+ "Polarity": [],
124
+ "Voice": [],
125
+ }
126
+
127
+
128
def add_target_feat_columns(exp):
    """
    Expand example["feats"] (one feats entry per token) into one new column
    per target feature. Always returns a dict with the same structure.
    """
    # Parse each token's feats into a {feat_name: value} mapping.
    per_token_feats = [
        parse_morphological_feats(token_feats, target_feats)
        for token_feats in exp["feats"]
    ]

    # One new column per target feature, aligned with the token list.
    for feat_name in target_feats:
        exp[feat_name] = [tf[feat_name] for tf in per_token_feats]

    return exp
144
+
145
+
146
def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: Optional[set[str]] = None):
    """
    Collect the unique token-level label values (and log their frequencies)
    for every trainable column across all splits of *ds*.

    :param ds: DatasetDict whose examples hold one list of label values per column.
    :param columns_to_exclude: column names to skip; defaults to {"text", "tokens"}.
    :return: dict of {column_name: set of unique label values seen in any split}.
    """
    excluded = {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in excluded]

    # One frequency Counter per column; replaces the previous hand-rolled
    # dict-of-dict counting plus a parallel dict-of-set of unique values
    # (the unique values are simply the counter keys).
    label_counters = {col: Counter() for col in columns_to_train_on}

    # Each column value is a list (one entry per token), so Counter.update
    # tallies every token-level label at once.
    for split_name, dataset_split in ds.items():
        for example in dataset_split:
            for col in columns_to_train_on:
                label_counters[col].update(example[col])

    unique_label_values = {col: set(label_counters[col]) for col in columns_to_train_on}

    logger.info(f"Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        # Sorted for a nice, stable ordering in the log output.
        vals = sorted(unique_label_values[col])
        logger.info(f"    {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")

    return unique_label_values
174
+
175
+
176
def introduce_typos(exp, typo_probability=0.03):
    """
    Randomly corrupt roughly `typo_probability` of an example's tokens,
    updating both the `tokens` and `Typo` columns in place.
    """
    new_tokens, new_typo_labels = [], []

    for token, existing_typo_label in zip(exp["tokens"], exp["Typo"]):
        if random.random() < typo_probability:
            # Corrupt the token and flag it as a typo.
            new_tokens.append(generate_typo(token))
            new_typo_labels.append("Yes")
        else:
            # Keep the token and whatever Typo label it already had.
            new_tokens.append(token)
            new_typo_labels.append(existing_typo_label)

    exp["tokens"] = new_tokens
    exp["Typo"] = new_typo_labels
    return exp
199
+
200
+
201
def is_evenly_shaped(exp):
    """Return True iff every label column is exactly as long as the token list."""
    n_tokens = len(exp["tokens"])
    return all(
        len(exp[feat_name]) == n_tokens
        for feat_name in ("xpos", "deprel", *target_feats)
    )
209
+
210
+
211
def is_valid_example(exp, dataset_name="ewt"):
    """Return True if all xpos & deprel labels are in the common sets, else False."""
    # Degenerate sentences consisting solely of "_" placeholders are dropped.
    if set(exp["tokens"]) == {"_"}:
        return False

    # xpos values that are silently excluded: missing labels plus tags used by
    # only a subset of the UD corpora — -LSB-/-RSB- (square brackets, en_gum
    # only), AFX (bound affixes), GW (gap words), XX (unknown/placeholder).
    silently_dropped_xpos = {None, "_", "-LSB-", "-RSB-", "AFX", "GW", "XX"}
    for x in exp["xpos"]:
        if x not in allowed_xpos:
            if x in silently_dropped_xpos:
                return False
            # Anything else is unexpected — log it before excluding the example.
            logger.info(f"[{dataset_name}] Filtering example with: xpos={x}\n{exp['tokens']}\n{exp['xpos']}")
            return False

    for d in exp["deprel"]:
        if d not in allowed_deprel:
            # Missing deprel labels are dropped quietly; anything else is logged.
            if d in {None, "_"}:
                return False
            logger.info(f"[{dataset_name}] Filtering example with: deprel={d}\n{exp['tokens']}\n{exp['deprel']}")
            return False

    return True
246
+
247
+
248
def parse_morphological_feats(feats_in, targeted_feats):
    """
    Return a dict {feat_name: feat_value} for each target_feat.
    If a feature is absent or doesn't apply, use "O".
    If feats_in is a dict, read from it.
    If feats_in is a string, parse it.
    If feats_in is None/'_'/'' => no features => all "O".

    Side effect: values for features listed in the module-level
    `non_target_feats` dict are appended there for later analysis.
    """
    # Default: every targeted feature maps to the "no value" marker.
    out = {feat: "O" for feat in targeted_feats}

    # Case A: feats_in is None or "_" or an empty string
    if not feats_in or feats_in == "_" or feats_in == "None":
        return out

    # Keep the raw input around for the unknown-type warning below.
    pristine_feats_in = feats_in

    # Case B: feats_in is a dict string: "{'Number': 'Sing', 'Person': '3'}"
    # literal_eval only parses Python literals, so this is safe on corpus data.
    if isinstance(feats_in, str):
        feats_in = ast.literal_eval(feats_in)

    # Case C: feats_in is a dictionary (some UD data does that)
    if isinstance(feats_in, dict):
        for k, v in feats_in.items():
            if k in targeted_feats:
                out[k] = v
            else:
                # Track known non-target features (module-level accumulator);
                # log anything completely unexpected.
                if k in non_target_feats:
                    non_target_feats[k].append(v)
                else:
                    logger.info(f"Unhandled non-target feat '{k}={v}'")
        return out

    # Otherwise, unknown type
    logger.warning(f"Unknown feats type {type(pristine_feats_in)} => {pristine_feats_in}")
    return out
284
+
285
+
286
def replace_bracket_label(exp):
    """Rewrite bare bracket xpos tags as their PTB-style names in place."""
    bracket_names = {"(": "-LRB-", ")": "-RRB-"}
    exp["xpos"] = [bracket_names.get(tag, tag) for tag in exp["xpos"]]
    return exp
290
+
291
+
292
def show_examples(ds: DatasetDict, show_expr: Optional[str]):
    """
    Log a handful of examples from *ds*.

    `show_expr` has the form "<split>/<col>/<label>/<count>" and selects
    examples whose column contains the given label; when omitted, the first
    two training examples are shown.
    """
    logger.info(f"Dataset:\n{ds}")

    if show_expr:
        split_to_show, col_to_show, label_to_show, count_str = show_expr.split("/")
        count_to_show = int(count_str)
        matching = ds[split_to_show].filter(lambda exp: label_to_show in exp[col_to_show])
        examples_to_show = matching.shuffle(seed=42)[:count_to_show]
    else:
        count_to_show = 2
        examples_to_show = ds["train"][:count_to_show]

    # examples_to_show is column-oriented: {feature_name: [value_0, value_1, ...]}.
    for i in range(count_to_show):
        logger.info(f"Example {i}:")
        for feature, values in examples_to_show.items():
            logger.info(f"  {feature}: {values[i]}")
307
+
308
+
309
def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
    """
    Normalize, filter, and expand every split of a UD DatasetDict.

    ud_dataset is a DatasetDict with splits such as 'train'/'validation'/'test'.
    Returns a new DatasetDict with the same splits, containing only valid,
    evenly shaped examples with one column per target feature.
    """
    processed_splits = {}
    for split_name, split_ds in ud_dataset.items():
        # Only PUD labels brackets literally; normalize them to -LRB-/-RRB- first.
        if dataset_name == "pud":
            split_ds = split_ds.map(replace_bracket_label)

        kept = split_ds.filter(lambda ex: is_valid_example(ex, dataset_name=dataset_name))
        expanded = kept.map(
            add_target_feat_columns,
            batched=False
        )
        # Drop raw UD columns we no longer need after feature expansion.
        expanded = expanded.remove_columns(["deps", "feats", "head", "idx", "lemmas", "misc", "upos"])
        processed_splits[split_name] = expanded.filter(is_evenly_shaped)

    return DatasetDict(processed_splits)
326
+
327
+
328
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Make training dataset.")
    arg_parser.add_argument("--augment-typos", help='Augment final merged training data with typos.',
                            action="store_true", default=False)
    arg_parser.add_argument("--log-level", help='Log level.',
                            action="store", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
    arg_parser.add_argument("--save", help='Save dataset to disk.',
                            action="store_true", default=False)
    arg_parser.add_argument("--save-path", help="Save final model to specified path.",
                            action="store", default="./training_data")
    arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                            action="store", default=None)
    args = arg_parser.parse_args()
    logging.config.dictConfig(default_logging_config)

    # Load UD Datasets: EWT, GUM, PUD
    ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt")
    ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum")
    ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud")

    for loaded_ds_name, loaded_ds in {
        "ud_en_ewt_ds": ud_en_ewt_ds,
        "ud_en_gum_ds": ud_en_gum_ds,
        "ud_en_pud_ds": ud_en_pud_ds
    }.items():
        t_cnt = len(loaded_ds['test']) if 'test' in loaded_ds else 0
        tr_cnt = len(loaded_ds['train']) if 'train' in loaded_ds else 0
        # Bug fix: the validation count was previously gated on 'train' being
        # present, which mis-reports (and can KeyError on) datasets that have a
        # train split but no validation split.
        v_cnt = len(loaded_ds['validation']) if 'validation' in loaded_ds else 0
        logger.info(f"Loaded {loaded_ds_name}: t:{t_cnt}, tr:{tr_cnt}, v:{v_cnt}")

    # Apply transform + filtering to each split in each dataset
    en_ewt_processed = transform_and_filter_dataset(ud_en_ewt_ds, "ewt")
    en_gum_processed = transform_and_filter_dataset(ud_en_gum_ds, "gum")
    en_pud_processed = transform_and_filter_dataset(ud_en_pud_ds, "pud")

    def is_rare_case(exp):
        """
        True when an example carries at least one under-represented label.
        Used to (optionally) oversample rare phenomena from GUM/PUD below.
        """
        rare_labels_by_column = {
            "xpos": {"ADD", "LS", "WP$"},
            "Degree": {"Cmp", "Sup"},
            "Gender": {"Fem"},
            "Mood": {"Imp"},
            "NumType": {"Mult", "Ord"},
            "Person": {"1", "2"},
            "PronType": {"Int", "Rel"},
            "Reflex": {"Yes"},
            "Typo": {"Yes"},
            "VerbForm": {"Ger"},
        }
        return any(
            not rare_labels.isdisjoint(exp[column])
            for column, rare_labels in rare_labels_by_column.items()
        )

    # Concatenate Datasets
    final_dataset = DatasetDict()
    final_dataset["test"] = concatenate_datasets(
        [
            en_ewt_processed["test"],
            #en_gum_processed["test"].filter(is_rare_case),
        ]
    )

    final_dataset["train"] = concatenate_datasets(
        [
            en_ewt_processed["train"],
            #en_gum_processed["train"].filter(is_rare_case),
            #en_pud_processed["test"].filter(is_rare_case),
        ]
    )
    if args.augment_typos:
        final_dataset["train"] = final_dataset["train"].map(introduce_typos, batched=False)

    final_dataset["validation"] = concatenate_datasets(
        [
            en_ewt_processed["validation"],
            #en_gum_processed["validation"].filter(is_rare_case),
        ]
    )
    show_examples(final_dataset, args.show)
    get_uniq_training_labels(final_dataset)
    if args.save:
        final_dataset.save_to_disk(args.save_path)
427
+
utils/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Shared dictConfig-style logging configuration; the sibling scripts apply it
# via logging.config.dictConfig(default_logging_config).
default_logging_config = {
    "version": 1,
    # Keep loggers created at import time (before configuration) alive.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "loggers": {
        # "" is the root logger: everything logs to the console at INFO+.
        "": {
            "level": "INFO",
            "handlers": ["console"],
        },
    },
}
utils/typos.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import string
3
+
4
# Map of correctly spelled words to their documented common misspellings
# (or commonly confused homophones).  Used by generate_typo() to emit a
# realistic real-world misspelling instead of a mechanical keyboard typo.
# NOTE: keys are matched against single tokens, so multi-token phrases
# ("a lot", "hors d'oeuvres") cannot be captured and stay commented out.
common_misspelled_words = {
    "absence": ["absense", "absentse", "abcense", "absance"],
    "acceptable": ["acceptible"],
    "accidentally": ["accidently", "ccidentaly"],
    "accommodate": ["accomodate", "acommodate"],
    "achieve": ["acheive"],
    "acknowledge": ["acknowlege", "aknowledge"],
    "acquaintance": ["acquaintence", "aquaintance"],
    "acquire": ["aquire", "adquire"],
    "acquit": ["aquit"],
    "acreage": ["acrage", "acerage"],
    "address": ["adress"],
    "adultery": ["adultary"],
    "advisable": ["adviseable", "advizable"],
    "affect": ["effect"],
    "aggression": ["agression"],
    "aggressive": ["agressive"],
    "allegiance": ["allegaince", "allegience", "alegiance"],
    "almost": ["allmost"],
    #"a lot": ["alot", "allot"] # Not captured since "a lot" is two tokens.
    "amateur": ["amatuer", "amature"],
    "annually": ["anually", "annualy"],
    "apparent": ["apparant", "aparent", "apparrent", "aparrent"],
    "arctic": ["artic"],
    "argument": ["arguement"],
    "atheist": ["athiest", "athist"],
    "awful": ["awfull", "aweful"],
    "because": ["becuase", "becasue"],
    "beautiful": ["beatiful"],
    "becoming": ["becomeing"],
    "beginning": ["begining"],
    "believe": ["beleive"],
    "bellwether": ["bellweather"],
    "benefit": ["benifit"],
    "buoy": ["bouy"],
    "buoyant": ["bouyant"],
    "business": ["buisness"],
    "calendar": ["calender"],
    "camouflage": ["camoflage", "camoflague"],
    "capitol": ["capital"],
    "Caribbean": ["Carribean"], # More names?
    "category": ["catagory"],
    "caught": ["cauhgt", "caugt"],
    "cemetery": ["cemetary", "cematery"],
    "changeable": ["changable"],
    "chief": ["cheif"],
    "colleague": ["collaegue", "collegue"],
    "column": ["colum"],
    "coming": ["comming"],
    "committed": ["commited", "comitted"],
    "comparison": ["comparsion"],
    "concede": ["conceed"],
    "congratulate": ["congradulate"],
    "conscientious": ["consciencious"],
    "conscious": ["concious", "consious"],
    "consensus": ["concensus"],
    "controversy": ["contraversy"],
    "coolly": ["cooly"],
    "daiquiri": ["dacquiri", "daquiri"],
    "deceive": ["decieve"],
    "definite": ["definate", "definit"],
    "definitely": ["definitly", "definately", "definatly", "defiantly"],
    "desperate": ["desparate"],
    "difference": ["diffrence"],
    "dilemma": ["dilema"],
    "disappoint": ["dissapoint"],
    "disastrous": ["disasterous"],
    "drunkenness": ["drunkeness"],
    "dumbbell": ["dumbell"],
    "embarrass": ["embarass"],
    "equipment": ["equiptment"],
    "exceed": ["excede"],
    "exhilarate": ["exilerate"],
    "existence": ["existance"],
    "experience": ["experiance"],
    "extreme": ["extreem"],
    "fascinating": ["facinating"],
    "fiery": ["firey"],
    "fluorescent": ["flourescent"],
    "foreign": ["foriegn"],
    "forty": ["fourty"],
    "friend": ["freind"],
    "fulfil": ["fullfil", "fulfill"],
    "gauge": ["guage"],
    "grateful": ["gratefull", "greatful"],
    "great": ["grate", "grat"],
    "guarantee": ["garantee", "garentee", "garanty"],
    "guidance": ["guidence"],
    "harass": ["harrass"],
    "height": ["heighth", "heigth"],
    "hierarchy": ["heirarchy"],
    # "hors d'oeuvres": ["hors derves", "ordeurves"] # Not captured since "hors d'oeuvres" is two tokens.
    "humorous": ["humerous"],
    "hygiene": ["hygene", "hygine", "hiygeine", "higeine", "hygeine"],
    "hypocrite": ["hipocrit"],
    "ignorance": ["ignorence"],
    "imitate": ["immitate"],
    "immediately": ["imediately"],
    "indict": ["indite"],
    "independent": ["independant"],
    "indispensable": ["indispensible"],
    "inoculate": ["innoculate"],
    "intelligence": ["inteligence", "intelligance"],
    "jewelry": ["jewellery", "jewelery"],
    "judgment": ["judgement"],
    "kernel": ["kernal"],
    "leisure": ["liesure"],
    "liaison": ["liason"],
    "library": ["libary", "liberry"],
    "license": ["lisence", "licence"],
    "lightning": ["lightening"],
    "lose": ["loose"],
    "maintenance": ["maintainance", "maintnance"],
    "marshmallow": ["marshmellow"],
    "medieval": ["medeval", "medevil", "mideval"],
    "memento": ["momento"],
    "millennium": ["millenium", "milennium"],
    "miniature": ["miniture"],
    "minuscule": ["miniscule"],
    "mischievous": ["mischievious", "mischevous", "mischevious"],
    "misspell": ["mispell", "misspel"],
    "necessary": ["neccessary", "necessery"],
    "niece": ["neice"],
    "neighbour": ["nieghbor"],
    "noticeable": ["noticable"],
    "occasion": ["occassion"],
    "occasionally": ["occasionaly", "occassionally"],
    "occurrence": ["occurrance", "occurence"],
    "occurred": ["occured"],
    "omission": ["ommision", "omision"],
    "original": ["orignal"],
    "outrageous": ["outragous"],
    "parliament": ["parliment"],
    "pastime": ["passtime", "pasttime"],
    "pedagogue": ["pedagoge"],
    "perceive": ["percieve"],
    "perseverance": ["perseverence"],
    "personnel": ["personell", "personel"],
    "plagiarize": ["plagerize"],
    "playwright": ["playright", "playwrite"],
    "possession": ["posession", "possesion"],
    "potatoes": ["potatos"],
    "precede": ["preceed"],
    "presence": ["presance"],
    "principle": ["principal"],
    "privilege": ["privelege", "priviledge"],
    "professor": ["professer"],
    "protester": ["protestor"],
    "promise": ["promiss"],
    # Fixed: entry was truncated to "pronounciatio" (missing final "n").
    "pronunciation": ["pronounciation"],
    "proof": ["prufe"],
    "prophecy": ["prophesy"],
    "publicly": ["publically"],
    "quarantine": ["quarentine"],
    "queue": ["que"],
    "questionnaire": ["questionaire", "questionnair"],
    "readable": ["readible"],
    "really": ["realy"],
    "receive": ["recieve"],
    "receipt": ["reciept"],
    "recommend": ["recomend", "reccommend"],
    "referred": ["refered"],
    "reference": ["referance", "refrence"],
    "relevant": ["relevent", "revelant"],
    "religious": ["religous", "religius"],
    "repetition": ["repitition"],
    "restaurant": ["restarant", "restaraunt"],
    "rhyme": ["rime"],
    "rhythm": ["rythm", "rythem"],
    "secretary": ["secratary", "secretery"],
    "seize": ["sieze"],
    "separate": ["seperate"],
    "sergeant": ["sargent"],
    "similar": ["similer"],
    "skilful": ["skilfull", "skillful"],
    "speech": ["speach", "speeche"],
    "successful": ["succesful", "successfull", "sucessful"],
    "supersede": ["supercede"],
    "surprise": ["suprise", "surprize"],
    "than": ["then"],
    "their": ["there", "they're"],
    "tomatoes": ["tomatos"],
    "tomorrow": ["tommorow", "tommorrow"],
    "Tucson": ["Tuscon"],
    "twelfth": ["twelth"],
    "tyranny": ["tyrany"],
    "underrate": ["underate"],
    "until": ["untill"],
    "upholstery": ["upholstry"],
    "usable": ["useable", "usible"],
    "vacuum": ["vaccuum", "vaccum", "vacume"],
    "vehicle": ["vehical"],
    "vicious": ["visious"],
    "what": ["wat"],
    "weather": ["wether", "whether"],
    "weird": ["wierd"],
    "welfare": ["wellfare", "welfair"],
    "whether": ["wether"],
    "wilful": ["wilfull", "willful"],
    "withhold": ["withold"],
    "writing": ["writting", "writeing"],
    "you're": ["your"],
    "your": ["you're"],
}
208
+
209
+
210
def apostrophe_error(word: str) -> str:
    """
    Simulate common errors with apostrophes.

    Words containing an apostrophe receive one of four random mutations:
    the apostrophe is removed, shifted one position left, shifted one
    position right, or duplicated (shifts at a word boundary fall back to
    removal).  Words without an apostrophe that end in 's' sometimes gain
    one before the final letter, mimicking a mistaken possessive; all
    other words come back unchanged.
    """
    if "'" not in word:
        # No apostrophe present: occasionally fake a possessive on s-final words.
        if word.endswith("s") and random.random() < 0.5:
            return f"{word[:-1]}'{word[-1]}"
        return word

    # Pick one of the apostrophes at random to mutate.
    positions = [pos for pos, char in enumerate(word) if char == "'"]
    pos = random.choice(positions)
    action = random.choice(['remove', 'shift_left', 'shift_right', 'duplicate'])

    removed = word[:pos] + word[pos + 1:]
    if action == 'remove':
        return removed
    if action == 'shift_left':
        # Swap the apostrophe with the character before it (if any).
        if pos > 0:
            return word[:pos - 1] + word[pos] + word[pos - 1] + word[pos + 1:]
        return removed
    if action == 'shift_right':
        # Swap the apostrophe with the character after it (if any).
        if pos < len(word) - 1:
            return word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:]
        return removed
    # action == 'duplicate': insert a second apostrophe right after the first.
    return word[:pos + 1] + "'" + word[pos + 1:]
250
+
251
+
252
def delete_random_letter(word: str) -> str:
    """Simulate an omission error by dropping one randomly chosen letter."""
    if len(word) < 2:
        # Too short to delete from without producing an empty token.
        return word
    drop = random.randrange(len(word))
    return word[:drop] + word[drop + 1:]
258
+
259
+
260
def duplicate_random_letter(word: str) -> str:
    """Simulate an extra keypress by doubling one randomly chosen letter."""
    if not word:
        return word
    spot = random.randrange(len(word))
    return word[:spot] + word[spot] * 2 + word[spot + 1:]
266
+
267
+
268
def insert_random_letter(word: str) -> str:
    """Simulate an insertion error: add a random lowercase letter anywhere."""
    spot = random.randint(0, len(word))  # len(word) allows appending at the end
    extra = random.choice(string.ascii_lowercase)
    return f"{word[:spot]}{extra}{word[spot:]}"
273
+
274
+
275
def replace_with_adjacent_key(word: str) -> str:
    """
    Simulate a fat-finger error by swapping one letter for a QWERTY neighbor.

    Only characters that have defined neighbors (ASCII letters) can be
    replaced; if the word contains none, it is returned unchanged.  The
    substitute letter keeps the case of the character it replaces.
    """
    # Neighboring keys on a QWERTY keyboard (lowercase letters only).
    qwerty_neighbors = {
        'q': ['w', 'a'],
        'w': ['q', 'e', 's'],
        'e': ['w', 'r', 'd'],
        'r': ['e', 't', 'f'],
        't': ['r', 'y', 'g'],
        'y': ['t', 'u', 'h'],
        'u': ['y', 'i', 'j'],
        'i': ['u', 'o', 'k'],
        'o': ['i', 'p', 'l'],
        'p': ['o'],
        'a': ['q', 's', 'z'],
        's': ['a', 'd', 'w', 'x'],
        'd': ['s', 'f', 'e', 'c'],
        'f': ['d', 'g', 'r', 'v'],
        'g': ['f', 'h', 't', 'b'],
        'h': ['g', 'j', 'y', 'n'],
        'j': ['h', 'k', 'u', 'm'],
        'k': ['j', 'l', 'i'],
        'l': ['k', 'o'],
        'z': ['a', 'x'],
        'x': ['z', 'c', 's'],
        'c': ['x', 'v', 'd'],
        'v': ['c', 'b', 'f'],
        'b': ['v', 'n', 'g'],
        'n': ['b', 'm', 'h'],
        'm': ['n', 'j']
    }
    # Positions of characters that can actually be replaced.
    candidates = [pos for pos, char in enumerate(word) if char.lower() in qwerty_neighbors]
    if not candidates:
        return word
    pos = random.choice(candidates)
    original = word[pos]
    substitute = random.choice(qwerty_neighbors[original.lower()])
    # Preserve the case of the replaced character.
    if original.isupper():
        substitute = substitute.upper()
    return word[:pos] + substitute + word[pos + 1:]
321
+
322
+
323
def swap_adjacent_letters(word: str) -> str:
    """Simulate a transposition error by exchanging two neighboring letters."""
    if len(word) < 2:
        return word
    # Choose the left index of the pair to transpose.
    pos = random.randint(0, len(word) - 2)
    return word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:]
331
+
332
+
333
def switch_ie_ei(word: str) -> str:
    """
    Switch an occurrence of 'ie' to 'ei' (or 'ei' to 'ie') to simulate a
    common vowel-pair error.

    'ie' takes priority: if the word contains it, one occurrence (chosen at
    random when there are several) becomes 'ei'.  Otherwise, if the word
    contains 'ei', one occurrence becomes 'ie'.  Words with neither digraph
    are returned unchanged.
    """
    def _occurrences(sub: str) -> list:
        # All start indices of `sub` in `word`, overlapping matches included.
        found = []
        start = word.find(sub)
        while start != -1:
            found.append(start)
            start = word.find(sub, start + 1)
        return found

    # Try 'ie' first, then 'ei' — the original scan logic was duplicated
    # for each digraph; this single loop performs the identical search.
    for target, replacement in (('ie', 'ei'), ('ei', 'ie')):
        indices = _occurrences(target)
        if indices:
            idx = random.choice(indices)
            return word[:idx] + replacement + word[idx + 2:]
    return word
364
+
365
+
366
def generate_typo(word: str) -> str:
    """
    Return a version of *word* containing a plausible typo.

    Known commonly-misspelled words have a 50% chance of being replaced by
    one of their recorded real-world misspellings.  Otherwise one error type
    is chosen at random from: apostrophe errors, letter deletion, letter
    duplication, random insertion, adjacent-key replacement, adjacent-letter
    transposition, and 'ie'/'ei' switching.  With 10% probability a second
    random transformation is chained on top for extra variability.  While
    not exhaustive, this reflects many of the typical errors documented.
    """
    if not word:
        return word

    # 50% chance of emitting a documented misspelling for known words.
    if word in common_misspelled_words and random.random() < 0.5:
        return random.choice(common_misspelled_words[word])

    # Mechanical keyboard-style mutations, applied with equal probability.
    mutators = [
        apostrophe_error,
        delete_random_letter,
        duplicate_random_letter,
        insert_random_letter,
        replace_with_adjacent_key,
        swap_adjacent_letters,
        switch_ie_ei
    ]
    mutated = random.choice(mutators)(word)

    # Occasionally (10%) stack a second error for added variability.
    if random.random() < 0.1:
        mutated = random.choice(mutators)(mutated)

    return mutated
406
+
407
+
408
# Example usage: corrupt a handful of sample words and show the results.
if __name__ == '__main__':
    sample_words = [
        "accommodate", "definitely", "receive", "mischievous", "calendar",
        "equipment", "pronunciation", "consensus", "friend", "beautiful",
        "doesn't", "books",
    ]
    for test_word in sample_words:
        typo = generate_typo(test_word)
        print(f"Original: {test_word:15s} -> Typo: {typo}")