veryfansome committed on
Commit
ab4b5ab
·
1 Parent(s): 360e354

feat: working training

Browse files
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ .idea/
3
+ *_data/
4
+ *_final/
5
+ __pycache__/
6
+ cache-*.arrow
7
+ final/
8
+ training_data/
9
+ training_logs/
10
+ training_output/
dataset/{o3-mini/20250218 → o3-mini_20250218}/data-00000-of-00001.arrow RENAMED
File without changes
dataset/{o3-mini/20250218 → o3-mini_20250218}/dataset_info.json RENAMED
File without changes
dataset/{o3-mini/20250218 → o3-mini_20250218}/state.json RENAMED
File without changes
dataset_maker.py CHANGED
@@ -1,14 +1,13 @@
1
  from asyncio import Task
2
-
3
  from datasets import Dataset, load_dataset
 
4
  from openai import AsyncOpenAI
5
  from traceback import format_exc
6
  from typing import Union
7
  import asyncio
8
  import itertools
9
  import json
10
- import logging.config
11
- import random
12
  import sentencepiece as spm
13
 
14
  from utils import default_logging_config
@@ -19,21 +18,103 @@ logger = logging.getLogger(__name__)
19
  sp = spm.SentencePieceProcessor()
20
  sp.LoadFromFile(f"sp.model")
21
 
22
- #OPENAI_MODEL = "gpt-4o"
23
- #OPENAI_MODEL = "o1"
24
- OPENAI_MODEL = "o3-mini"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str]):
28
  tok_len = len(tokens)
29
  example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
30
  try:
31
  response = await client.chat.completions.create(
32
- model=OPENAI_MODEL,
33
- timeout=30,
34
- #temperature=0,
35
- #presence_penalty=0,
36
- reasoning_effort="low",
37
  messages=[
38
  {
39
  "role": "system",
@@ -90,217 +171,42 @@ async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str]
90
  raise
91
 
92
 
93
- async def classify_with_retry(prompt, labels, tokens, retry=10):
94
  for i in range(retry):
95
  try:
96
- return await classify_tokens(prompt, labels, tokens)
97
  except Exception as e:
98
  logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
99
  await asyncio.sleep(i)
100
 
101
-
102
- async def generate_token_labels(case):
103
  tokens = list(itertools.chain.from_iterable([s.strip("▁").split("▁") for s in sp.EncodeAsPieces(case)]))
104
- (
105
- adj,
106
- adv,
107
- det,
108
- enc,
109
- func,
110
- misc,
111
- ner1,
112
- ner2,
113
- noun,
114
- pronoun,
115
- punct,
116
- verb,
117
- wh,
118
- ) = await asyncio.gather(
119
- classify_with_retry(
120
- f"its semantic role",
121
- {
122
- "JJ": "adjective",
123
- "JJR": "comparative adjective",
124
- "JJS": "superlative adjective",
125
- "O": "out-of-scope",
126
- },
127
- tokens),
128
- classify_with_retry(
129
- f"its semantic role",
130
- {
131
- "RB": "adverb",
132
- "RBR": "comparative adverb",
133
- "RBS": "superlative adverb",
134
- "O": "out-of-scope",
135
- },
136
- tokens),
137
- classify_with_retry(
138
- f"its semantic role",
139
- {
140
- "DT": "articles, demonstratives, and other determiners",
141
- "EX": "existential 'there'",
142
- "PDT": "predeterminer before a determiner to modify a noun phrase",
143
- "O": "out-of-scope",
144
- },
145
- tokens),
146
- classify_with_retry(
147
- f"its sentence chunk classification",
148
- {
149
- "BRACKET": "in or contains bracket wrapped text",
150
- "QUOTE": "in or contains quote wrapped text",
151
- "TICK": "in or contains backtick wrapped text",
152
- "O": "out-of-scope",
153
- },
154
- tokens),
155
- classify_with_retry(
156
- f"its semantic role",
157
- {
158
- "CC": "coordinating conjunction",
159
- "IN": "preposition or subordinating conjunction",
160
- "RP": "particle",
161
- "TO": "to",
162
- "UH": "interjection",
163
- "O": "out-of-scope",
164
- },
165
- tokens),
166
- classify_with_retry(
167
- f"its semantic role",
168
- {
169
- "$": "currency",
170
- "ADD": "address, URLs, usernames, or other non-lexical representations of places or entities",
171
- "CD": "cardinal numbers",
172
- "EMOJI": "emoji",
173
- "TIME": "date or time",
174
- "O": "out-of-scope",
175
- },
176
- tokens),
177
- classify_with_retry(
178
- f"its NER classification",
179
- {
180
- "B-GPE": "beginning of geopolitical entities",
181
- "I-GPE": "inside of geopolitical entities",
182
- "B-ORG": "beginning of organization",
183
- "I-ORG": "inside of organization",
184
- "B-PER": "beginning of person",
185
- "I-PER": "inside of person",
186
- "O": "out-of-scope",
187
- },
188
- tokens),
189
- classify_with_retry(
190
- f"its NER classification",
191
- {
192
- "B-EVENT": "beginning of event",
193
- "I-EVENT": "inside of event",
194
- "B-LOC": "beginning of location",
195
- "I-LOC": "inside of location",
196
- "O": "out-of-scope",
197
- },
198
- tokens),
199
- classify_with_retry(
200
- f"its semantic role",
201
- {
202
- "NN": "common noun singular",
203
- "NNS": "common noun plural",
204
- "NNP": "proper noun singular",
205
- "NNPS": "proper noun plural",
206
- "O": "out-of-scope",
207
- },
208
- tokens),
209
- classify_with_retry(
210
- f"its semantic role",
211
- {
212
- "POS": "possessive ending like the 's",
213
- "PRP$": "possessive pronoun",
214
- "PRP": "personal pronoun",
215
- "O": "out-of-scope",
216
- },
217
- tokens),
218
- classify_with_retry(
219
- f"the punctuation classes it contains",
220
- {
221
- "COLON": "colon or semicolon",
222
- "COMMA": "comma",
223
- "EXCLAIM": "exclamation mark",
224
- "HYPH": "dash or hyphen",
225
- "LS": "list item marker",
226
- "PERIOD": "period",
227
- "QUESTION": "question mark",
228
- "SEP": "any section separator",
229
- "O": "out-of-scope",
230
- },
231
- tokens),
232
- classify_with_retry(
233
- f"its semantic role",
234
- {
235
- "MD": "modal verb",
236
- "VB": "verb base form",
237
- "VBD": "verb past tense",
238
- "VBG": "present participle, gerund",
239
- "VBN": "past participle",
240
- "VBP": "non-3rd person singular present",
241
- "VBZ": "3rd person singular present",
242
- "O": "out-of-scope",
243
- },
244
- tokens),
245
- classify_with_retry(
246
- f"its semantic role",
247
- {
248
- "WDT": "Wh-determiner",
249
- "WP$": "Wh-possessive pronoun",
250
- "WP": "Wh-pronoun",
251
- "WRB": "Wh-adverb",
252
- "O": "out-of-scope",
253
- },
254
- tokens),
255
- )
256
- return {
257
- "text": case,
258
- "tokens": tokens,
259
- "adj": adj,
260
- "adv": adv,
261
- "det": det,
262
- "enc": enc,
263
- "func": func,
264
- "misc": misc,
265
- "ner1": ner1,
266
- "ner2": ner2,
267
- "noun": noun,
268
- "pronoun": pronoun,
269
- "punct": punct,
270
- "verb": verb,
271
- "wh": wh,
272
- }
273
-
274
-
275
- async def main(cases):
276
- ds_dict = {
277
- "text": [],
278
- "tokens": [],
279
- "adj": [],
280
- "adv": [],
281
- "det": [],
282
- "enc": [],
283
- "func": [],
284
- "misc": [],
285
- "ner1": [],
286
- "ner2": [],
287
- "noun": [],
288
- "pronoun": [],
289
- "punct": [],
290
- "verb": [],
291
- "wh": [],
292
- }
293
  drain_completed = False
294
  max_concurrent_tasks = 15
295
  tasks: list[Union[Task, None]] = []
296
 
297
  async def checkpoint_task():
 
298
  while not drain_completed:
299
- # checkpoint
300
- _ds = Dataset.from_dict(ds_dict)
301
- _ds.save_to_disk("./custom_checkpoint_data")
302
- logger.info(f"\n{_ds}")
303
- await asyncio.sleep(600)
 
 
304
  future_checkpoint_task_completion = asyncio.create_task(checkpoint_task())
305
 
306
  async def drain_tasks():
@@ -328,7 +234,7 @@ async def main(cases):
328
  while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
329
  await asyncio.sleep(1)
330
  logger.info(f"scheduling case {case}")
331
- tasks.append(asyncio.create_task(generate_token_labels(case)))
332
 
333
  # Block until done
334
  while len([t for t in tasks if t is not None]) > 0:
@@ -340,46 +246,57 @@ async def main(cases):
340
  await future_checkpoint_task_completion
341
 
342
  ds = Dataset.from_dict(ds_dict)
343
- ds.save_to_disk("./custom_final_data")
344
  logger.info(f"\n{ds}")
345
 
346
 
347
  if __name__ == "__main__":
 
 
 
348
  logging.config.dictConfig(default_logging_config)
349
 
350
- all_examples = []
 
 
 
 
 
 
 
 
 
351
 
352
- ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt")
353
- all_examples += ud_en_ewt_ds["train"]["text"]
354
- all_examples += ud_en_ewt_ds["validation"]["text"]
355
- all_examples += ud_en_ewt_ds["test"]["text"]
 
356
 
357
- ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum")
358
- all_examples += ud_en_gum_ds["train"]["text"]
359
- all_examples += ud_en_gum_ds["validation"]["text"]
360
- all_examples += ud_en_gum_ds["test"]["text"]
361
 
362
- ud_en_lines_ds = load_dataset("universal_dependencies", "en_lines")
363
- all_examples += ud_en_lines_ds["train"]["text"]
364
- all_examples += ud_en_lines_ds["validation"]["text"]
365
- all_examples += ud_en_lines_ds["test"]["text"]
366
 
367
- ud_en_partut_ds = load_dataset("universal_dependencies", "en_partut")
368
- all_examples += ud_en_partut_ds["train"]["text"]
369
- all_examples += ud_en_partut_ds["validation"]["text"]
370
- all_examples += ud_en_partut_ds["test"]["text"]
371
 
372
- ud_en_pronouns_ds = load_dataset("universal_dependencies", "en_pronouns")
373
- all_examples += ud_en_pronouns_ds["test"]["text"]
374
 
375
- ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud")
376
- all_examples += ud_en_pud_ds["test"]["text"]
377
 
378
- logger.info(f"{len(all_examples)} UD examples")
379
- random.shuffle(all_examples)
380
- logger.info(f"{all_examples[:10]}")
381
 
382
- all_examples += [
383
  "Hello world!",
384
  "127.0.0.1 is the localhost address.",
385
  "1/2 is equivalent to 0.5 or 50%",
@@ -416,5 +333,5 @@ if __name__ == "__main__":
416
  "hes got a bon to pick",
417
  "so then he says then he says, \"you'll regret this\" lol",
418
  ]
419
- asyncio.run(main(all_examples))
420
 
 
1
  from asyncio import Task
 
2
  from datasets import Dataset, load_dataset
3
+ from datetime import datetime
4
  from openai import AsyncOpenAI
5
  from traceback import format_exc
6
  from typing import Union
7
  import asyncio
8
  import itertools
9
  import json
10
+ import logging
 
11
  import sentencepiece as spm
12
 
13
  from utils import default_logging_config
 
18
  sp = spm.SentencePieceProcessor()
19
  sp.LoadFromFile(f"sp.model")
20
 
21
# Label inventories for every classification head. Each column maps a tag set
# (Penn-Treebank-style POS tags, BIO NER tags, or custom chunk/punctuation
# classes) to a human-readable description fed to the labeling model.
# Every column includes "O" as the out-of-scope fallback.
features = {
    "adj": {
        "JJ": "adjective",
        "JJR": "comparative adjective",
        "JJS": "superlative adjective",
        "O": "out-of-scope",
    },
    "adv": {
        "RB": "adverb",
        "RBR": "comparative adverb",
        "RBS": "superlative adverb",
        "O": "out-of-scope",
    },
    "det": {
        "DT": "articles, demonstratives, and other determiners",
        "EX": "existential 'there'",
        "PDT": "predeterminer before a determiner to modify a noun phrase",
        "O": "out-of-scope",
    },
    "enc": {
        "BRACKET": "in or contains bracket wrapped text",
        "QUOTE": "in or contains quote wrapped text",
        "TICK": "in or contains backtick wrapped text",
        "O": "out-of-scope",
    },
    "func": {
        "CC": "coordinating conjunction",
        "IN": "preposition or subordinating conjunction",
        "RP": "particle",
        "TO": "to",
        "UH": "interjection",
        "O": "out-of-scope",
    },
    "misc": {
        "$": "currency",
        "ADD": "address, URLs, usernames, or other non-lexical representations of places or entities",
        "CD": "cardinal numbers",
        "EMOJI": "emoji",
        "TIME": "date or time",
        "O": "out-of-scope",
    },
    "ner1": {
        "B-GPE": "beginning of geopolitical entities",
        "I-GPE": "inside of geopolitical entities",
        "B-ORG": "beginning of organization",
        "I-ORG": "inside of organization",
        "B-PER": "beginning of person",
        "I-PER": "inside of person",
        "O": "out-of-scope",
    },
    "ner2": {
        "B-EVENT": "beginning of event",
        "I-EVENT": "inside of event",
        "B-LOC": "beginning of location",
        "I-LOC": "inside of location",
        "O": "out-of-scope",
    },
    "noun": {
        "NN": "common noun singular",
        "NNS": "common noun plural",
        "NNP": "proper noun singular",
        "NNPS": "proper noun plural",
        "O": "out-of-scope",
    },
    "pronoun": {
        "POS": "possessive ending like the 's",
        "PRP$": "possessive pronoun",
        "PRP": "personal pronoun",
        "O": "out-of-scope",
    },
    "punct": {
        "COLON": "colon or semicolon",
        "COMMA": "comma",
        "EXCLAIM": "exclamation mark",
        "HYPH": "dash or hyphen",
        "LS": "list item marker",
        "PERIOD": "period",
        "QUESTION": "question mark",
        "SEP": "any section separator",
        "O": "out-of-scope",
    },
    "verb": {
        "MD": "modal verb",
        "VB": "verb base form",
        "VBD": "verb past tense",
        "VBG": "present participle, gerund",
        "VBN": "past participle",
        "VBP": "non-3rd person singular present",
        "VBZ": "3rd person singular present",
        "O": "out-of-scope",
    },
    "wh": {
        "WDT": "Wh-determiner",
        "WP$": "Wh-possessive pronoun",
        "WP": "Wh-pronoun",
        "WRB": "Wh-adverb",
        "O": "out-of-scope",
    },
}
94
 
95
# Per-column instruction fragment interpolated into the classification prompt.
# (Plain string literals: the original used f-strings with no placeholders.)
prompts = {
    "adj": "its semantic role",
    "adv": "its semantic role",
    "det": "its semantic role",
    "enc": "its sentence chunk classification",
    "func": "its semantic role",
    "misc": "its semantic role",
    "ner1": "its NER classification",
    "ner2": "its NER classification",
    "noun": "its semantic role",
    "pronoun": "its semantic role",
    "punct": "the punctuation classes it contains",
    "verb": "its semantic role",
    "wh": "its semantic role",
}
110
 
111
+ async def classify_tokens(prompt: str, labels: dict[str, str], tokens: list[str], model="gpt-4o"):
112
  tok_len = len(tokens)
113
  example = "[" + (", ".join([f'"{tok}"' for tok in tokens])) + "]"
114
  try:
115
  response = await client.chat.completions.create(
116
+ model=model, timeout=30,
117
+ **({"reasoning_effort": "low"} if model.startswith("o") else {"presence_penalty": 0, "temperature": 0}),
 
 
 
118
  messages=[
119
  {
120
  "role": "system",
 
171
  raise
172
 
173
 
174
async def classify_with_retry(prompt, labels, tokens, model="gpt-4o", retry=10):
    """Call classify_tokens with linear-backoff retries.

    :param prompt: instruction fragment describing the classification task.
    :param labels: dict of label -> description for this head.
    :param tokens: list of token strings to classify.
    :param model: OpenAI model name passed through to classify_tokens.
    :param retry: maximum number of attempts.
    :raises: the last exception if every attempt fails.  (The original fell
             through and returned None, which would silently insert a null
             label list into the dataset.)
    """
    last_exc = None
    for i in range(retry):
        try:
            return await classify_tokens(prompt, labels, tokens, model=model)
        except Exception as e:
            last_exc = e
            logger.error(f"attempt {i} failed {tokens} {prompt} {format_exc()}")
            # Linear backoff; the first retry is immediate (sleep(0)).
            await asyncio.sleep(i)
    raise last_exc
181
 
182
async def generate_token_labels(case, model="gpt-4o"):
    """Tokenize `case` and concurrently label every token for each feature column.

    Returns a dict mapping column name -> per-token label list, one entry per
    key in `features`.
    NOTE(review): unlike the pre-refactor version, "text"/"tokens" are not part
    of the returned example — confirm the downstream dataset columns match.
    """
    # Split SentencePiece pieces on the word-boundary marker so labels align
    # with plain word-level tokens.
    tokens = list(itertools.chain.from_iterable(
        [s.strip("▁").split("▁") for s in sp.EncodeAsPieces(case)]))
    # Stable column order so gather() results can be zipped back to columns.
    sorted_cols = sorted(features.keys())
    results = await asyncio.gather(
        *[classify_with_retry(prompts[col], features[col], tokens, model=model)
          for col in sorted_cols])
    return dict(zip(sorted_cols, results))
190
+
191
+
192
+ async def main(args, cases):
193
+ ds_dict = {k: [] for k in features.keys()}
194
+
195
+ ts_run = datetime.now().strftime("%Y%m%d")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  drain_completed = False
197
  max_concurrent_tasks = 15
198
  tasks: list[Union[Task, None]] = []
199
 
200
  async def checkpoint_task():
201
+ tick_cnt = 0
202
  while not drain_completed:
203
+ tick_cnt += 1
204
+ if tick_cnt % 600 == 0:
205
+ # checkpoint
206
+ _ds = Dataset.from_dict(ds_dict)
207
+ _ds.save_to_disk(f"{args.save_path}/{args.openai_model}_{ts_run}_checkpoint")
208
+ logger.info(f"\n{_ds}")
209
+ await asyncio.sleep(1)
210
  future_checkpoint_task_completion = asyncio.create_task(checkpoint_task())
211
 
212
  async def drain_tasks():
 
234
  while len([t for t in tasks if t is not None]) >= max_concurrent_tasks:
235
  await asyncio.sleep(1)
236
  logger.info(f"scheduling case {case}")
237
+ tasks.append(asyncio.create_task(generate_token_labels(case, model=args.openai_model)))
238
 
239
  # Block until done
240
  while len([t for t in tasks if t is not None]) > 0:
 
246
  await future_checkpoint_task_completion
247
 
248
  ds = Dataset.from_dict(ds_dict)
249
+ ds.save_to_disk(f"{args.save_path}/{args.openai_model}_{ts_run}")
250
  logger.info(f"\n{ds}")
251
 
252
 
253
  if __name__ == "__main__":
254
+ import argparse
255
+ import logging.config
256
+
257
  logging.config.dictConfig(default_logging_config)
258
 
259
+ arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
260
+ arg_parser.add_argument("--openai-model", help="OpenAI model.",
261
+ action="store", default="o3-mini", choices=["gpt-4o", "o3-mini", "o1"])
262
+ arg_parser.add_argument("--save-path", help="Save final dataset to specified path.",
263
+ action="store", default="./dataset")
264
+ arg_parser.add_argument("--ud", help='Use UD datasets.',
265
+ action="store_true", default=False)
266
+ parsed_args = arg_parser.parse_args()
267
+
268
+ all_text = []
269
 
270
+ if parsed_args.ud:
271
+ ud_en_ewt_ds = load_dataset("universal_dependencies", "en_ewt")
272
+ all_text += ud_en_ewt_ds["train"]["text"]
273
+ all_text += ud_en_ewt_ds["validation"]["text"]
274
+ all_text += ud_en_ewt_ds["test"]["text"]
275
 
276
+ ud_en_gum_ds = load_dataset("universal_dependencies", "en_gum")
277
+ all_text += ud_en_gum_ds["train"]["text"]
278
+ all_text += ud_en_gum_ds["validation"]["text"]
279
+ all_text += ud_en_gum_ds["test"]["text"]
280
 
281
+ ud_en_lines_ds = load_dataset("universal_dependencies", "en_lines")
282
+ all_text += ud_en_lines_ds["train"]["text"]
283
+ all_text += ud_en_lines_ds["validation"]["text"]
284
+ all_text += ud_en_lines_ds["test"]["text"]
285
 
286
+ ud_en_partut_ds = load_dataset("universal_dependencies", "en_partut")
287
+ all_text += ud_en_partut_ds["train"]["text"]
288
+ all_text += ud_en_partut_ds["validation"]["text"]
289
+ all_text += ud_en_partut_ds["test"]["text"]
290
 
291
+ ud_en_pronouns_ds = load_dataset("universal_dependencies", "en_pronouns")
292
+ all_text += ud_en_pronouns_ds["test"]["text"]
293
 
294
+ ud_en_pud_ds = load_dataset("universal_dependencies", "en_pud")
295
+ all_text += ud_en_pud_ds["test"]["text"]
296
 
297
+ logger.info(f"{len(all_text)} UD examples")
 
 
298
 
299
+ all_text += [
300
  "Hello world!",
301
  "127.0.0.1 is the localhost address.",
302
  "1/2 is equivalent to 0.5 or 50%",
 
333
  "hes got a bon to pick",
334
  "so then he says then he says, \"you'll regret this\" lol",
335
  ]
336
+ asyncio.run(main(parsed_args, all_text))
337
 
dataset_splitter.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import DatasetDict, load_from_disk
2
+ import argparse
3
+
4
+ from dataset_maker import features
5
+
6
def has_all_valid_labels(exp):
    """Return True when every label in every feature column is a known label.

    Non-label columns ("text", "tokens") are ignored; every other value is a
    list of labels that must all appear in the corresponding `features` map.
    """
    return all(
        label in features[col]
        for col, labels in exp.items()
        if col not in {"text", "tokens"}
        for label in labels
    )
14
+
15
def is_evenly_shaped(exp):
    """Return True when all per-token columns have the same length.

    Every column except "text" (including "tokens") must contain the same
    number of entries; ragged examples are rejected.
    """
    lengths = {len(values) for col, values in exp.items() if col != "text"}
    return len(lengths) == 1
22
+
23
+
24
if __name__ == '__main__':
    # Fixed copy-pasted description (this script splits, it does not train).
    arg_parser = argparse.ArgumentParser(description="Filter and split a labeled dataset.")
    arg_parser.add_argument("data_path", help="Load dataset from specified path.",
                            action="store")
    arg_parser.add_argument("--save-path", help="Save final dataset to specified path.",
                            action="store", default="./training_data")
    args = arg_parser.parse_args()

    # Drop malformed examples before splitting: ragged label columns first,
    # then any example containing a label outside the known inventories.
    loaded_dataset = load_from_disk(args.data_path)
    loaded_dataset = loaded_dataset.filter(is_evenly_shaped)
    loaded_dataset = loaded_dataset.filter(has_all_valid_labels)

    # 9% held out as test; 10% of the remainder becomes validation.
    # Both splits are seeded — the original left the second split unseeded,
    # making the train/validation partition non-reproducible.
    first_split = loaded_dataset.train_test_split(shuffle=True, seed=42, test_size=0.09)
    second_split = first_split["train"].train_test_split(seed=42, test_size=0.1)

    new_ds = DatasetDict()
    new_ds["test"] = first_split["test"]
    new_ds["train"] = second_split["train"]
    new_ds["validation"] = second_split["test"]
    new_ds.save_to_disk(args.save_path)
multi_task_classifier.py CHANGED
@@ -14,8 +14,7 @@ import numpy as np
14
  import torch
15
  import torch.nn as nn
16
 
17
- from ud_training_data_maker import get_uniq_training_labels, show_examples
18
- from utils import default_logging_config
19
 
20
  logger = logging.getLogger(__name__)
21
 
@@ -53,8 +52,6 @@ args = arg_parser.parse_args()
53
  logging.config.dictConfig(default_logging_config)
54
  logger.info(f"Args {args}")
55
 
56
-
57
-
58
  # ------------------------------------------------------------------------------
59
  # Load dataset and show examples for manual inspection
60
  # ------------------------------------------------------------------------------
 
14
  import torch
15
  import torch.nn as nn
16
 
17
+ from utils import default_logging_config, get_uniq_training_labels, show_examples
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
52
  logging.config.dictConfig(default_logging_config)
53
  logger.info(f"Args {args}")
54
 
 
 
55
  # ------------------------------------------------------------------------------
56
  # Load dataset and show examples for manual inspection
57
  # ------------------------------------------------------------------------------
ud_dataset_maker.py CHANGED
@@ -1,12 +1,11 @@
1
  from datasets import load_dataset, DatasetDict, concatenate_datasets
2
- from typing import Optional
3
  import argparse
4
  import ast
5
  import logging.config
6
  import random
7
 
8
  from utils.typos import generate_typo
9
- from utils import default_logging_config
10
 
11
  logger = logging.getLogger(__name__)
12
 
@@ -143,36 +142,6 @@ def add_target_feat_columns(exp):
143
  return exp
144
 
145
 
146
- def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: set[str] = None):
147
- columns_to_train_on = [k for k in ds["train"].features.keys() if k not in (
148
- {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude)]
149
-
150
- # Create a dictionary of sets, keyed by each column name
151
- label_counters = {col: dict() for col in columns_to_train_on}
152
- unique_label_values = {col: set() for col in columns_to_train_on}
153
-
154
- # Loop through each split and each example, and collect values
155
- for split_name, dataset_split in ds.items():
156
- for example in dataset_split:
157
- # Each of these columns is a list (one entry per token),
158
- # so we update our set with each token-level value
159
- for col in columns_to_train_on:
160
- unique_label_values[col].update(example[col])
161
- for label_val in example[col]:
162
- if label_val not in label_counters[col]:
163
- label_counters[col][label_val] = 0 # Inits with 0
164
- label_counters[col][label_val] += 1
165
-
166
- logger.info(f"Columns:")
167
- for col in columns_to_train_on:
168
- logger.info(f" {col}:")
169
- # Convert to a sorted list just to have a nice, stable ordering
170
- vals = sorted(unique_label_values[col])
171
- logger.info(f" {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")
172
-
173
- return unique_label_values
174
-
175
-
176
  def introduce_typos(exp, typo_probability=0.03):
177
  """
178
  Randomly introduce typos in some % of tokens.
@@ -289,23 +258,6 @@ def replace_bracket_label(exp):
289
  return exp
290
 
291
 
292
- def show_examples(ds: DatasetDict, show_expr: Optional[str]):
293
- logger.info(f"Dataset:\n{ds}")
294
- if not show_expr:
295
- count_to_show = 2
296
- examples_to_show = ds["train"][:count_to_show]
297
- else:
298
- args_show_tokens = show_expr.split("/")
299
- split_to_show, col_to_show, label_to_show, count_to_show = args_show_tokens
300
- count_to_show = int(count_to_show)
301
- examples_to_show = ds[split_to_show].filter(
302
- lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
303
- for i in range(count_to_show):
304
- logger.info(f"Example {i}:")
305
- for feature in examples_to_show.keys():
306
- logger.info(f" {feature}: {examples_to_show[feature][i]}")
307
-
308
-
309
  def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
310
  """
311
  ud_dataset is a DatasetDict with splits: 'train', 'validation', 'test' etc.
 
1
  from datasets import load_dataset, DatasetDict, concatenate_datasets
 
2
  import argparse
3
  import ast
4
  import logging.config
5
  import random
6
 
7
  from utils.typos import generate_typo
8
+ from utils import default_logging_config, get_uniq_training_labels, show_examples
9
 
10
  logger = logging.getLogger(__name__)
11
 
 
142
  return exp
143
 
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def introduce_typos(exp, typo_probability=0.03):
146
  """
147
  Randomly introduce typos in some % of tokens.
 
258
  return exp
259
 
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
262
  """
263
  ud_dataset is a DatasetDict with splits: 'train', 'validation', 'test' etc.
ud_multi_task_classifier.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import DatasetDict, load_from_disk
2
+ from sklearn.metrics import classification_report, precision_recall_fscore_support
3
+ from transformers import (
4
+ DebertaV2Config,
5
+ DebertaV2Model,
6
+ DebertaV2PreTrainedModel,
7
+ DebertaV2TokenizerFast,
8
+ Trainer,
9
+ TrainingArguments,
10
+ )
11
+ import argparse
12
+ import logging.config
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from utils import default_logging_config, get_uniq_training_labels, show_examples
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
# CLI configuration for the multi-task training run; parsed eagerly at import
# time, so this module is script-only.
arg_parser = argparse.ArgumentParser(description="Train multi-task model.")
arg_parser.add_argument("-A", "--accumulation-steps", help="Gradient accumulation steps.",
                        action="store", type=int, default=8)
arg_parser.add_argument("--data-only", help='Show training data info and exit.',
                        action="store_true", default=False)
arg_parser.add_argument("--data-path", help="Load training dataset from specified path.",
                        action="store", default="./training_data")
arg_parser.add_argument("-E", "--train-epochs", help="Number of epochs to train for.",
                        action="store", type=int, default=3)
arg_parser.add_argument("-V", "--eval-batch-size", help="Per device eval batch size.",
                        action="store", type=int, default=2)
arg_parser.add_argument("--from-base", help="Load a base model.",
                        action="store", default=None,
                        choices=[
                            "microsoft/deberta-v3-base",  # Requires --deberta-v3
                            "microsoft/deberta-v3-large",  # Requires --deberta-v3
                            # More?
                        ])
arg_parser.add_argument("-L", "--learning-rate", help="Learning rate.",
                        action="store", type=float, default=5e-5)
arg_parser.add_argument("--mini", help='Train model using small subset of examples for pipeline testing.',
                        action="store_true", default=False)
arg_parser.add_argument("--save-path", help="Save final model to specified path.",
                        action="store", default="./final")
arg_parser.add_argument("--show", help="Show examples: <split>/<col>/<label>/<count>",
                        action="store", default=None)
arg_parser.add_argument("--train", help='Train model using loaded examples.',
                        action="store_true", default=False)
arg_parser.add_argument("-T", "--train-batch-size", help="Per device train batch size.",
                        action="store", type=int, default=2)
args = arg_parser.parse_args()
logging.config.dictConfig(default_logging_config)
logger.info(f"Args {args}")
54
+
55
+
56
+
57
# ------------------------------------------------------------------------------
# Load dataset and show examples for manual inspection
# ------------------------------------------------------------------------------

# Dataset is expected to already be split on disk (see dataset_splitter.py
# output layout — TODO confirm).
loaded_dataset = load_from_disk(args.data_path)
show_examples(loaded_dataset, args.show)

# ------------------------------------------------------------------------------
# Convert label analysis data into label sets for each head
# ------------------------------------------------------------------------------

# Column name -> list of unique label strings observed across the dataset.
ALL_LABELS = {col: list(vals) for col, vals in get_uniq_training_labels(loaded_dataset).items()}

# Per-column label <-> integer id maps used by the classification heads.
LABEL2ID = {
    feat_name: {label: i for i, label in enumerate(ALL_LABELS[feat_name])}
    for feat_name in ALL_LABELS
}
ID2LABEL = {
    feat_name: {i: label for label, i in LABEL2ID[feat_name].items()}
    for feat_name in LABEL2ID
}

# Each head's number of labels:
NUM_LABELS_DICT = {k: len(v) for k, v in ALL_LABELS.items()}

# --data-only: inspect label statistics above without training.
if args.data_only:
    exit()
84
+
85
+ # ------------------------------------------------------------------------------
86
+ # Create a custom config that can store our multi-label info
87
+ # ------------------------------------------------------------------------------
88
+
89
class MultiHeadModelConfig(DebertaV2Config):
    """DeBERTa-v2 config extended with per-head label metadata.

    # label_maps: per-head label mapping (presumably {head: {label: id}};
    #   verify against the caller that constructs the config)
    # num_labels_dict: {head_name: label_count} used to size each head
    """

    def __init__(self, label_maps=None, num_labels_dict=None, **kwargs):
        super().__init__(**kwargs)
        # Fall back to empty dicts so the config is always serializable.
        self.label_maps = label_maps or {}
        self.num_labels_dict = num_labels_dict or {}

    def to_dict(self):
        """Serialize the base config plus the multi-head metadata."""
        serialized = super().to_dict()
        serialized["label_maps"] = self.label_maps
        serialized["num_labels_dict"] = self.num_labels_dict
        return serialized
100
+
101
+ # ------------------------------------------------------------------------------
102
+ # Define a multi-head model
103
+ # ------------------------------------------------------------------------------
104
+
105
class MultiHeadModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with one linear token-classification head per label column."""

    def __init__(self, config: MultiHeadModelConfig):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)
        self.classifiers = nn.ModuleDict()

        # One linear head per column, sized by that column's label count.
        hidden_size = config.hidden_size
        for label_name, n_labels in config.num_labels_dict.items():
            self.classifiers[label_name] = nn.Linear(hidden_size, n_labels)

        # Initialize newly added weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels_dict=None,
        **kwargs
    ):
        """Run the encoder and every classification head.

        labels_dict: a dict of { label_name: (batch_size, seq_len) } with label ids.
        Returns (total_loss, logits_dict) when labels_dict is provided, where
        total_loss is the sum of per-head cross-entropy losses; otherwise
        returns logits_dict alone.  (The original re-tested labels_dict for
        None twice; restructured with an early return.)
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs
        )

        sequence_output = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        logits_dict = {
            label_name: classifier(sequence_output)
            for label_name, classifier in self.classifiers.items()
        }

        if labels_dict is None:
            # Inference path: just return per-head predictions.
            return logits_dict

        # Training/eval path: sum the cross-entropy losses of every head that
        # has labels in labels_dict.
        loss_fct = nn.CrossEntropyLoss()
        total_loss = 0.0
        loss_dict = {}
        for label_name, logits in logits_dict.items():
            if label_name not in labels_dict:
                continue
            label_ids = labels_dict[label_name]

            # Token-classification convention: positions labeled -100
            # (padding / continuation subwords) are excluded from the loss.
            active_mask = label_ids != -100  # shape (bs, seq_len)

            # Flatten and keep only the active positions.
            active_logits = logits.view(-1, logits.shape[-1])[active_mask.view(-1)]
            active_labels = label_ids.view(-1)[active_mask.view(-1)]

            loss = loss_fct(active_logits, active_labels)
            loss_dict[label_name] = loss.item()  # per-head loss, kept for debugging
            total_loss += loss

        return total_loss, logits_dict
174
+
175
+ # ------------------------------------------------------------------------------
176
+ # Tokenize with max_length=512, stride=128, and subword alignment
177
+ # ------------------------------------------------------------------------------
178
+
179
def tokenize_and_align_labels(examples):
    """
    For each example, the tokenizer may produce multiple overlapping
    chunks if the tokens exceed 512 subwords. Each chunk will be
    length=512, with a stride=128 for the next chunk.
    We'll align labels so that subwords beyond the first in a token get -100.

    Relies on the module-level ``tokenizer``, ``ALL_LABELS`` and ``LABEL2ID``.
    Returns a dict of chunk-level columns: input_ids, attention_mask, and one
    ``labels_<feat>`` column per feature in ALL_LABELS.
    """
    # We rely on is_split_into_words=True because examples["tokens"] is a list of token strings.
    tokenized_batch = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        max_length=512,
        stride=128,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=False,  # not mandatory for basic alignment
        padding="max_length"
    )

    # The tokenizer returns "overflow_to_sample_mapping", telling us
    # which original example index each chunk corresponds to.
    # If the tokenizer didn't need to create overflows, the key might be missing
    if "overflow_to_sample_mapping" not in tokenized_batch:
        # No overflow => each input corresponds 1:1 with the original example
        sample_map = [i for i in range(len(tokenized_batch["input_ids"]))]
    else:
        sample_map = tokenized_batch["overflow_to_sample_mapping"]

    # We'll build lists for final outputs.
    # For each chunk i, we produce:
    #   "input_ids"[i], "attention_mask"[i], plus per-feature label IDs.
    final_input_ids = []
    final_attention_mask = []
    final_labels_columns = {feat: [] for feat in ALL_LABELS}  # store one label-sequence per chunk

    for i in range(len(tokenized_batch["input_ids"])):
        # chunk i
        chunk_input_ids = tokenized_batch["input_ids"][i]
        chunk_attn_mask = tokenized_batch["attention_mask"][i]

        original_index = sample_map[i]  # which example in the original batch
        word_ids = tokenized_batch.word_ids(batch_index=i)

        # We'll build label arrays for each feature
        chunk_labels_dict = {}

        for feat_name in ALL_LABELS:
            # The UD token-level labels for the *original* example
            token_labels = examples[feat_name][original_index]  # e.g. length T
            chunk_label_ids = []

            # previous_word_id resets per chunk, so a word that falls inside the
            # 128-token overlap region is supervised again in the next chunk.
            previous_word_id = None
            for w_id in word_ids:
                if w_id is None:
                    # special token (CLS, SEP, padding)
                    chunk_label_ids.append(-100)
                else:
                    # If it's the same word_id as before, it's a subword => label = -100
                    if w_id == previous_word_id:
                        chunk_label_ids.append(-100)
                    else:
                        # New token => use the actual label
                        label_str = token_labels[w_id]
                        label_id = LABEL2ID[feat_name][label_str]
                        chunk_label_ids.append(label_id)
                previous_word_id = w_id

            chunk_labels_dict[feat_name] = chunk_label_ids

        final_input_ids.append(chunk_input_ids)
        final_attention_mask.append(chunk_attn_mask)
        for feat_name in ALL_LABELS:
            final_labels_columns[feat_name].append(chunk_labels_dict[feat_name])

    # Return the new "flattened" set of chunks
    # So the "map" call will expand each example → multiple chunk examples.
    result = {
        "input_ids": final_input_ids,
        "attention_mask": final_attention_mask,
    }
    # We'll store each feature's label IDs in separate columns (e.g. labels_xpos, labels_deprel, etc.)
    for feat_name in ALL_LABELS:
        result[f"labels_{feat_name}"] = final_labels_columns[feat_name]

    return result
264
+
265
+ # ------------------------------------------------------------------------------
266
+ # Trainer Setup
267
+ # ------------------------------------------------------------------------------
268
+
269
class MultiHeadTrainer(Trainer):
    """HF Trainer that routes the per-head ``labels_<feat>`` dataset columns
    into the model as a single ``labels_dict`` keyword argument, for both the
    training loss and prediction steps."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # 1) Gather all your per-feature labels from inputs
        _labels_dict = {}
        for feat_name in ALL_LABELS:
            key = f"labels_{feat_name}"
            if key in inputs:
                _labels_dict[feat_name] = inputs[key]

        # 2) Remove them so they don't get passed incorrectly to the model
        for key in list(inputs.keys()):
            if key.startswith("labels_"):
                del inputs[key]

        # 3) Call model(...) with _labels_dict
        outputs = model(**inputs, labels_dict=_labels_dict)
        # 'outputs' is (loss, logits_dict) in training/eval mode
        loss, logits_dict = outputs

        # Optional: if your special param is used upstream for some logic,
        # you can handle it here or pass it along. For example:
        if num_items_in_batch is not None:
            # ... do something if needed ...
            pass

        if return_outputs:
            # Return (loss, logits_dict) so Trainer sees logits_dict as predictions
            return (loss, logits_dict)
        else:
            return loss

    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        # NOTE: ignore_keys is accepted for Trainer API compatibility but is
        # not used by this implementation.
        # 1) gather the "labels_xxx" columns
        _labels_dict = {}
        for feat_name in ALL_LABELS:
            key = f"labels_{feat_name}"
            if key in inputs:
                _labels_dict[feat_name] = inputs[key]
                del inputs[key]

        # 2) forward pass without those keys
        with torch.no_grad():
            outputs = model(**inputs, labels_dict=_labels_dict)

        loss, logits_dict = outputs  # you are returning (loss, dict-of-arrays)

        if prediction_loss_only:
            return (loss, None, None)

        # The trainer expects a triple: (loss, predictions, labels)
        #  - 'predictions' can be the dictionary
        #  - 'labels' can be the dictionary of label IDs
        return (loss, logits_dict, _labels_dict)
323
+
324
+
325
def multi_head_classification_reports(logits_dict, labels_dict, id2label_dict):
    """
    For each head, generate a classification report (precision, recall, f1, etc. per class).
    Return them as a dict: {head_name: "string report"}.
    :param logits_dict: dict of {head_name: np.array(batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array(batch_size, seq_len)}
    :param id2label_dict: dict of {head_name: {id: label_str}}
    :return: A dict of classification-report strings, one per head.
    """
    reports = {}

    for head_name, head_logits in logits_dict.items():
        if head_name not in labels_dict:
            continue

        predicted_ids = np.argmax(head_logits, axis=-1)

        # Flatten (pred, label) pairs, dropping the -100 "ignore" positions.
        flattened = [
            (p, lab)
            for pred_seq, label_seq in zip(predicted_ids, labels_dict[head_name])
            for p, lab in zip(pred_seq, label_seq)
            if lab != -100
        ]

        if not flattened:
            reports[head_name] = "No valid predictions."
            continue

        # Map numeric IDs back to label strings for a readable report.
        pred_names = [id2label_dict[head_name][p] for p, _ in flattened]
        true_names = [id2label_dict[head_name][lab] for _, lab in flattened]

        reports[head_name] = classification_report(
            true_names,
            pred_names,
            zero_division=0
        )

    return reports
365
+
366
+
367
def multi_head_compute_metrics(logits_dict, labels_dict):
    """
    For each head (e.g. xpos, deprel, Case, etc.), computes:
      - Accuracy
      - Precision (macro/micro)
      - Recall (macro/micro)
      - F1 (macro/micro)

    :param logits_dict: dict of {head_name: np.array of shape (batch_size, seq_len, num_classes)}
    :param labels_dict: dict of {head_name: np.array of shape (batch_size, seq_len)}
    :return: A dict with aggregated metrics. Keys prefixed by head_name, e.g. "xpos_accuracy", "xpos_f1_macro", etc.
    """
    results = {}

    for head_name, head_logits in logits_dict.items():
        # Skip any head we have no gold labels for.
        if head_name not in labels_dict:
            continue

        # (batch_size, seq_len, num_classes) -> (batch_size, seq_len)
        predicted_ids = np.argmax(head_logits, axis=-1)

        # Flatten (pred, label) pairs, dropping the -100 "ignore" positions.
        flattened = [
            (p, lab)
            for pred_seq, label_seq in zip(predicted_ids, labels_dict[head_name])
            for p, lab in zip(pred_seq, label_seq)
            if lab != -100
        ]
        if not flattened:
            # No valid data for this head—skip
            continue

        y_pred = np.array([p for p, _ in flattened])
        y_true = np.array([lab for _, lab in flattened])

        # Overall token-level accuracy
        results[f"{head_name}_accuracy"] = (y_pred == y_true).mean()

        # macro => treat each class equally; micro => aggregate across classes
        for avg in ("macro", "micro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average=avg, zero_division=0
            )
            results[f"{head_name}_precision_{avg}"] = precision
            results[f"{head_name}_recall_{avg}"] = recall
            results[f"{head_name}_f1_{avg}"] = f1

    return results
427
+
428
+ # ------------------------------------------------------------------------------
429
+ # Instantiate model and tokenizer
430
+ # ------------------------------------------------------------------------------
431
+
432
# Either fine-tune from a base checkpoint (fresh heads) or reload a previously
# saved multi-head checkpoint for evaluation.
if args.from_base:
    model_name_or_path = args.from_base
    multi_head_model = MultiHeadModel.from_pretrained(
        model_name_or_path,
        config=MultiHeadModelConfig.from_pretrained(
            model_name_or_path,
            num_labels_dict=NUM_LABELS_DICT,
            label_maps=ALL_LABELS
        )
    )
else:
    model_name_or_path = args.save_path
    # For evaluation, always load the saved checkpoint without overriding the config.
    multi_head_model = MultiHeadModel.from_pretrained(model_name_or_path)
    # EXTREMELY IMPORTANT!
    # Override the label mapping based on the stored config to ensure consistency with training time ordering.
    ALL_LABELS = multi_head_model.config.label_maps
    LABEL2ID = {feat: {label: i for i, label in enumerate(ALL_LABELS[feat])} for feat in ALL_LABELS}
    ID2LABEL = {feat: {i: label for label, i in LABEL2ID[feat].items()} for feat in LABEL2ID}
logger.info(f"using {model_name_or_path}")

# Check if GPU is usable
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():  # For Apple Silicon MPS
    device = torch.device("mps")
else:
    device = torch.device("cpu")
logger.info(f"using {device}")
multi_head_model.to(device)

# Tokenizer comes from the same checkpoint path as the model so the vocab matches.
tokenizer = DebertaV2TokenizerFast.from_pretrained(
    model_name_or_path,
    add_prefix_space=True,
)
467
+
468
+ # ------------------------------------------------------------------------------
469
+ # Shuffle, (optionally) sample, and tokenize final merged dataset
470
+ # ------------------------------------------------------------------------------
471
+
472
# Optional smoke-test mode: shrink every split to a small random sample.
if args.mini:
    loaded_dataset = DatasetDict({
        "train": loaded_dataset["train"].shuffle(seed=42).select(range(1000)),
        "validation": loaded_dataset["validation"].shuffle(seed=42).select(range(100)),
        "test": loaded_dataset["test"].shuffle(seed=42).select(range(100)),
    })

# remove_columns => remove old "text", "tokens", etc. so we keep only model inputs
tokenized_dataset = loaded_dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=loaded_dataset["train"].column_names,
)
485
+
486
+ # ------------------------------------------------------------------------------
487
+ # Train the model!
488
+ # ------------------------------------------------------------------------------
489
+
490
+ """
491
+ Current bests:
492
+
493
+ deberta-v3-base:
494
+ num_train_epochs=3,
495
+ learning_rate=5e-5,
496
+ per_device_train_batch_size=2,
497
+ gradient_accumulation_steps=8,
498
+ """
499
+
500
+ training_args = TrainingArguments(
501
+ # Evaluate less frequently or keep the same
502
+ eval_strategy="epoch",
503
+ num_train_epochs=args.train_epochs,
504
+ learning_rate=args.learning_rate,
505
+
506
+ output_dir="training_output",
507
+ overwrite_output_dir=True,
508
+ remove_unused_columns=False, # important to keep the labels_xxx columns
509
+
510
+ logging_dir="training_logs",
511
+ logging_steps=100,
512
+
513
+ # Effective batch size = train_batch_size x gradient_accumulation_steps
514
+ per_device_train_batch_size=args.train_batch_size,
515
+ gradient_accumulation_steps=args.accumulation_steps,
516
+
517
+ per_device_eval_batch_size=args.eval_batch_size,
518
+ )
519
+
520
+ trainer = MultiHeadTrainer(
521
+ model=multi_head_model,
522
+ args=training_args,
523
+ train_dataset=tokenized_dataset["train"],
524
+ eval_dataset=tokenized_dataset["validation"],
525
+ )
526
+
527
+ if args.train:
528
+ trainer.train()
529
+ trainer.evaluate()
530
+ trainer.save_model(args.save_path)
531
+ tokenizer.save_pretrained(args.save_path)
532
+
533
+ # ------------------------------------------------------------------------------
534
+ # Evaluate the model!
535
+ # ------------------------------------------------------------------------------
536
+
537
# Final evaluation on the held-out test split.
pred_output = trainer.predict(tokenized_dataset["test"])
pred_logits_dict = pred_output.predictions
pred_labels_dict = pred_output.label_ids
id2label_dict = ID2LABEL  # from earlier definitions

# 1) Calculate metrics
metrics = multi_head_compute_metrics(pred_logits_dict, pred_labels_dict)
for k,v in metrics.items():
    print(f"{k}: {v:.4f}")

# 2) Print classification reports
reports = multi_head_classification_reports(pred_logits_dict, pred_labels_dict, id2label_dict)
for head_name, rstr in reports.items():
    print(f"----- {head_name} classification report -----")
    print(rstr)
utils/__init__.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
 
2
  default_logging_config = {
3
  "version": 1,
@@ -19,4 +24,51 @@ default_logging_config = {
19
  "handlers": ["console"],
20
  },
21
  },
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import DatasetDict
2
+ from typing import Optional
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
 
7
  default_logging_config = {
8
  "version": 1,
 
24
  "handlers": ["console"],
25
  },
26
  },
27
+ }
28
+
29
+
30
def get_uniq_training_labels(ds: DatasetDict, columns_to_exclude: Optional[set[str]] = None):
    """Collect the unique token-level label values of every trainable column.

    Scans every split of *ds* (train/validation/test alike) and, for each
    column not excluded, gathers the set of distinct label values plus their
    occurrence counts (counts are only logged, not returned).

    :param ds: the dataset to scan; columns come from ds["train"].features.
    :param columns_to_exclude: columns to skip; defaults to {"text", "tokens"}.
    :return: dict of {column_name: set of unique label values}.
    """
    excluded = {"text", "tokens"} if columns_to_exclude is None else columns_to_exclude
    columns_to_train_on = [k for k in ds["train"].features.keys() if k not in excluded]

    # Per-column occurrence counts and unique values.
    label_counters = {col: {} for col in columns_to_train_on}
    unique_label_values = {col: set() for col in columns_to_train_on}

    # Loop through each split and each example, and collect values.
    for _split_name, dataset_split in ds.items():
        for example in dataset_split:
            # Each of these columns is a list (one entry per token),
            # so we tally every token-level value.
            for col in columns_to_train_on:
                unique_label_values[col].update(example[col])
                counts = label_counters[col]
                for label_val in example[col]:
                    counts[label_val] = counts.get(label_val, 0) + 1

    logger.info("Columns:")
    for col in columns_to_train_on:
        logger.info(f"  {col}:")
        # Convert to a sorted list just to have a nice, stable ordering
        vals = sorted(unique_label_values[col])
        logger.info(f"    {len(vals)} labels: {[f'{v}:{label_counters[col][v]}' for v in vals]}")

    return unique_label_values
58
+
59
+
60
def show_examples(ds: DatasetDict, show_expr: Optional[str]):
    """Log a handful of examples from the dataset.

    :param ds: the dataset to sample from.
    :param show_expr: optional "split/column/label/count" selector, e.g.
        "train/xpos/NN/3" — show 3 shuffled train examples whose xpos column
        contains "NN". When falsy, shows the first 2 train examples.
    """
    logger.info(f"Dataset:\n{ds}")
    if not show_expr:
        count_to_show = 2
        examples_to_show = ds["train"][:count_to_show]
    else:
        split_to_show, col_to_show, label_to_show, count_to_show = show_expr.split("/")
        count_to_show = int(count_to_show)
        examples_to_show = ds[split_to_show].filter(
            lambda exp: label_to_show in exp[col_to_show]).shuffle(seed=42)[:count_to_show]
    # Slicing may return fewer rows than requested (small split or rare label);
    # clamp so indexing below never raises IndexError.
    available = min((len(col_vals) for col_vals in examples_to_show.values()), default=0)
    for i in range(min(count_to_show, available)):
        logger.info(f"Example {i}:")
        for feature in examples_to_show.keys():
            logger.info(f"  {feature}: {examples_to_show[feature][i]}")