veryfansome committed on
Commit
abf3529
·
1 Parent(s): 406d54a

feat: emotions integration

Browse files
Files changed (2) hide show
  1. goemotions_predict.py +63 -0
  2. ud_dataset_maker.py +374 -44
goemotions_predict.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np
import torch

from utils import get_torch_device


class GoEmotionsPredictor:
    """Multi-label emotion classifier wrapping a fine-tuned GoEmotions checkpoint.

    Loads the tokenizer and sequence-classification model from a hub id or a
    local path, reads the threshold/label metadata that the training job
    stored on the model config, and exposes ``predict`` for batched inference
    (sigmoid probabilities + thresholding -> emotion label names per text).
    """

    def __init__(self, model_name_or_path: str, subfolder=None):
        """
        Args:
            model_name_or_path (str): Hub model id or local checkpoint path.
            subfolder: Optional subfolder within the checkpoint repo.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, subfolder=subfolder)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, subfolder=subfolder)

        # Metadata persisted on the config at training time; any of these may
        # be absent on checkpoints not produced by this pipeline.
        self.label_names = getattr(self.model.config, "label_names", None)
        if self.label_names is None:
            # Fall back to the standard HF id2label mapping so predict()
            # never indexes into None.
            id2label = getattr(self.model.config, "id2label", None) or {}
            self.label_names = [id2label[i] for i in sorted(id2label)]
        self.per_label_thresh = getattr(self.model.config, "per_label_thresholds", None)
        self.global_thresh = getattr(self.model.config, "best_global_threshold", 0.65)

        self.device = get_torch_device()
        self.model.to(self.device)
        self.model.eval()

    def predict(self, texts, use_per_label=True):
        """
        Args:
            texts (list[str]): A list of raw text strings to classify.
            use_per_label (bool): If True, apply per-label thresholds. If False, apply global threshold.
        Returns:
            A list of dicts, each with {"text": ..., "emotions": [...]}
        """
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=512,
            return_tensors="pt"
        )
        # Move encodings to same device as the model
        encodings = {k: v.to(self.device) for k, v in encodings.items()}

        # 1) Run the model to get logits
        with torch.no_grad():
            outputs = self.model(**encodings)
            logits = outputs.logits  # shape: (batch_size, num_labels)
            probs = torch.sigmoid(logits).cpu().numpy()  # shape: (batch_size, num_labels)

        # 2) Determine predictions by thresholding. Fall back to the global
        # threshold when the checkpoint carries no per-label thresholds, so
        # use_per_label=True cannot crash comparing against None.
        if use_per_label and self.per_label_thresh is not None:
            threshold_array = np.array(self.per_label_thresh)
            preds = (probs >= threshold_array).astype(int)  # shape: (batch_size, num_labels)
        else:
            preds = (probs >= self.global_thresh).astype(int)

        # 3) Convert integer predictions to label names
        results = []
        for i, text in enumerate(texts):
            row_preds = preds[i]
            predicted_labels = [self.label_names[j] for j, val in enumerate(row_preds) if val == 1]
            results.append({"text": text, "emotions": predicted_labels})

        return results
ud_dataset_maker.py CHANGED
@@ -1,14 +1,21 @@
1
  from datasets import load_dataset, DatasetDict, concatenate_datasets
 
 
2
  import argparse
3
  import ast
 
4
  import logging.config
5
  import random
6
 
 
7
  from utils.typos import generate_typo
8
  from utils import default_logging_config, get_uniq_training_labels, show_examples
9
 
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
12
  allowed_xpos = [
13
  "''",
14
  '$',
@@ -111,11 +118,6 @@ allowed_deprel = [
111
  'xcomp',
112
  ]
113
 
114
- target_feats = [
115
- "Case", "Definite", "Degree", "Gender", "Mood", "NumType", "Number",
116
- "Person", "Poss", "PronType", "Reflex", "Tense", "Typo", "VerbForm"
117
- ]
118
-
119
  non_target_feats = { # Found programmatically and added after analysis
120
  "Abbr": [],
121
  "Foreign": [],
@@ -123,6 +125,68 @@ non_target_feats = { # Found programmatically and added after analysis
123
  "Voice": [],
124
  }
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  def add_target_feat_columns(exp):
128
  """
@@ -142,6 +206,283 @@ def add_target_feat_columns(exp):
142
  return exp
143
 
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  def introduce_typos(exp, typo_probability=0.03):
146
  """
147
  Randomly introduce typos in some % of tokens.
@@ -268,10 +609,34 @@ def transform_and_filter_dataset(ud_dataset, dataset_name="ewt"):
268
  if dataset_name == "pud":
269
  _split_ds = _split_ds.map(replace_bracket_label)
270
  filtered_split = _split_ds.filter(lambda ex: is_valid_example(ex, dataset_name=dataset_name))
 
271
  transformed_split = filtered_split.map(
272
  add_target_feat_columns,
273
  batched=False
274
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  transformed_split = transformed_split.remove_columns(["deps", "feats", "head", "idx", "lemmas", "misc", "upos"])
276
  new_splits[_split_name] = transformed_split.filter(is_evenly_shaped)
277
  return DatasetDict(new_splits)
@@ -312,55 +677,20 @@ if __name__ == "__main__":
312
  en_gum_processed = transform_and_filter_dataset(ud_en_gum_ds, "gum")
313
  en_pud_processed = transform_and_filter_dataset(ud_en_pud_ds, "pud")
314
 
315
- def is_rare_case(exp):
316
- if "ADD" in exp["xpos"]:
317
- return True
318
- if "LS" in exp["xpos"]:
319
- return True
320
- if "WP$" in exp["xpos"]:
321
- return True
322
- if "Cmp" in exp["Degree"]:
323
- return True
324
- if "Sup" in exp["Degree"]:
325
- return True
326
- if "Fem" in exp["Gender"]:
327
- return True
328
- if "Imp" in exp["Mood"]:
329
- return True
330
- if "Mult" in exp["NumType"]:
331
- return True
332
- if "Ord" in exp["NumType"]:
333
- return True
334
- if "1" in exp["Person"]:
335
- return True
336
- if "2" in exp["Person"]:
337
- return True
338
- if "Int" in exp["PronType"]:
339
- return True
340
- if "Rel" in exp["PronType"]:
341
- return True
342
- if "Yes" in exp["Reflex"]:
343
- return True
344
- if "Yes" in exp["Typo"]:
345
- return True
346
- if "Ger" in exp["VerbForm"]:
347
- return True
348
- return False
349
-
350
  # Concatenate Datasets
351
  final_dataset = DatasetDict()
352
  final_dataset["test"] = concatenate_datasets(
353
  [
354
  en_ewt_processed["test"],
355
- en_gum_processed["test"], #.filter(is_rare_case),
356
- en_pud_processed["test"], #.filter(is_rare_case),
357
  ]
358
  )
359
 
360
  final_dataset["train"] = concatenate_datasets(
361
  [
362
  en_ewt_processed["train"],
363
- en_gum_processed["train"], #.filter(is_rare_case),
364
  ]
365
  )
366
  if args.augment_typos:
@@ -369,7 +699,7 @@ if __name__ == "__main__":
369
  final_dataset["validation"] = concatenate_datasets(
370
  [
371
  en_ewt_processed["validation"],
372
- en_gum_processed["validation"], #.filter(is_rare_case),
373
  ]
374
  )
375
  show_examples(final_dataset, args.show)
 
1
  from datasets import load_dataset, DatasetDict, concatenate_datasets
2
+ from openai import OpenAI
3
+ from traceback import format_exc
4
  import argparse
5
  import ast
6
+ import json
7
  import logging.config
8
  import random
9
 
10
+ from goemotions_predict import GoEmotionsPredictor
11
  from utils.typos import generate_typo
12
  from utils import default_logging_config, get_uniq_training_labels, show_examples
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
+ goemotions_predictor = GoEmotionsPredictor(
17
+ "veryfansome/deberta-goemotions", subfolder="pos_weight_best")
18
+
19
  allowed_xpos = [
20
  "''",
21
  '$',
 
118
  'xcomp',
119
  ]
120
 
 
 
 
 
 
121
  non_target_feats = { # Found programmatically and added after analysis
122
  "Abbr": [],
123
  "Foreign": [],
 
125
  "Voice": [],
126
  }
127
 
128
+ openai_classification_params = {
129
+ "model": "gpt-4o",
130
+ "temperature": 0.0,
131
+
132
+ #"model": "o3-mini",
133
+ #"reasoning_effort": "high",
134
+
135
+ "top_p": 1.0,
136
+ "presence_penalty": 0.0,
137
+ "frequency_penalty": 0.0,
138
+ "timeout": 30,
139
+ }
140
+
141
+ target_feats = [
142
+ "Case", "Definite", "Degree", "Gender", "Mood", "NumType", "Number",
143
+ "Person", "Poss", "PronType", "Reflex", "Tense", "Typo", "VerbForm"
144
+ ]
145
+
146
+ word_lists_limiting_adjectives = [
147
+ "any",
148
+ "certain",
149
+ "each",
150
+ "every",
151
+ "other",
152
+ "some",
153
+
154
+ # Demonstrative adjectives / determiners
155
+ "that",
156
+ "these",
157
+ "this",
158
+ "those",
159
+ ]
160
+ word_lists_difference_adjectives = [
161
+ "contrasting",
162
+ "different",
163
+ "disparate",
164
+ "dissimilar",
165
+ "distinct",
166
+ "divergent",
167
+ "diverse",
168
+ "heterogeneous",
169
+ "varied",
170
+ "various",
171
+ ]
172
+
173
+ word_lists_similarity_adjectives = [
174
+ "alike",
175
+ "analogous",
176
+ "comparable",
177
+ "equal",
178
+ "equivalent",
179
+ "homogeneous",
180
+ "identical",
181
+ "interchangeable",
182
+ "same",
183
+ "similar",
184
+ ]
185
+
186
+ word_lists_states_of_being_verbs = [
187
+ "am", "are", "be", "been", "being", "is", "was", "were",
188
+ ]
189
+
190
 
191
  def add_target_feat_columns(exp):
192
  """
 
206
  return exp
207
 
208
 
209
def extract_label_groups(exp, feat, target_labels=None):
    """
    Return runs of consecutive token indices whose label matches the targets.

    Scans the label sequence ``exp[feat]`` and collects index positions whose
    label is in ``target_labels`` (or, when ``target_labels`` is None, any
    label other than "O"), grouping consecutive matches together.

    For example, with target_labels={"NN", "NNS", "NNP", "NNPS"}:
        ["O", "O", "NN", "NN", "O", "O", "NNS", "O"]
    returns:
        [[2, 3], [6]]

    Args:
        exp: Example dict containing a per-token label sequence under ``feat``.
        feat: Name of the label column to scan (e.g. "xpos").
        target_labels (set of str or None): Labels to match; when None, every
            label other than "O" counts as a match.

    Returns:
        list of lists of int: Groups of consecutive matching indices, in order.
    """
    if target_labels is not None:
        matched = [i for i, label in enumerate(exp[feat]) if label in target_labels]
    else:
        matched = [i for i, label in enumerate(exp[feat]) if label != "O"]
    # Consecutive indices share a constant (value - position) key, so groupby
    # splits `matched` exactly at the gaps between runs.
    return [
        [idx for _, idx in run]
        for _, run in groupby(enumerate(matched), key=lambda pair: pair[1] - pair[0])
    ]
252
+
253
+
254
# Cap label-request retries so a persistently failing API call cannot hang the
# dataset map() forever; unlabeled tokens fall back to "O" below.
_MAX_LABEL_ATTEMPTS = 5


def _request_valid_label(client, messages, labels, failure_msg="failed to get label, trying again"):
    """
    Ask OpenAI to choose one label from ``labels`` for the prompt in ``messages``.

    Uses a strict json_schema response_format so the model must answer with a
    single enum value. Returns the label string on success, or None when the
    request fails or the model returns a value outside ``labels`` — occasional
    hallucinations are retried by the caller.
    """
    try:
        completion = client.chat.completions.create(
            messages=messages,
            **openai_classification_params,
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "label",
                    "strict": True,
                    "schema": {
                        "type": "object",
                        "properties": {
                            "label": {
                                "type": "string",
                                "enum": labels
                            }
                        },
                        "additionalProperties": False,
                        "required": ["label"]
                    }
                }
            },
        )
        new_label = json.loads(completion.choices[0].message.content)['label']
        # Only accept values from the allowed set so hallucinations are retried.
        return new_label if new_label in labels else None
    except Exception:
        logger.error(f"{failure_msg}:\n{format_exc()}")
        return None


def introduce_emotion(exp):
    """
    Add a per-token "Emotion" column to ``exp``.

    Runs the GoEmotions predictor on the whole sentence to get candidate
    emotion labels, then asks OpenAI to assign one of those labels (or O) to
    each content token (adjective/noun/adverb/verb xpos tags). States-of-being
    verbs are labeled O directly; tokens never labeled default to O.
    """
    exp["Emotion"] = ["X" for _ in exp["tokens"]]
    labels = [l.upper() for l in goemotions_predictor.predict([exp["text"]], use_per_label=True)[0]["emotions"] if l != "neutral"]
    labels.append("O")
    labels_len = len(labels)
    label_blob = ", ".join([(f"or {l}" if (labels_len > 1 and i == labels_len - 1) else l) for i, l in enumerate(labels)])
    logger.info(f"label_blob: {label_blob}")
    if label_blob != "O":
        # One client for the whole example instead of one per token.
        with OpenAI() as client:
            for capture_group in extract_label_groups(exp, "xpos", {
                "JJ", "JJR", "JJS",
                "NN", "NNS", "NNP", "NNPS",
                "RB", "RBR", "RBS",
                "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",
            }):
                for token_idx in capture_group:
                    token = exp["tokens"][token_idx]
                    if token in word_lists_states_of_being_verbs:
                        exp["Emotion"][token_idx] = "O"
                        continue
                    messages = [
                        {
                            "role": "system",
                            "content": f"""
Classify '{token}' at token index position {token_idx} by choosing the best fitting emotion label or O if out of scope.
 Pay close attention to semantic context but don't over-generalize if there is not enough context in the provided text.
 Return only the label value, nothing else.
""".replace("\n", "").strip()
                        },
                        {
                            "role": "user",
                            "content": exp["text"]
                        },
                        {
                            "role": "user",
                            "content": str(exp["tokens"])
                        },
                        {
                            "role": "user",
                            "content": f"The word '{token}' at token index position {token_idx} above evokes {label_blob}?"
                        },
                    ]
                    for _ in range(_MAX_LABEL_ATTEMPTS):
                        new_label = _request_valid_label(client, messages, labels)
                        if new_label is not None:
                            logger.info(f"{token_idx}:{token} {new_label}")
                            exp["Emotion"][token_idx] = new_label
                            break
    # Tokens never labeled (out-of-scope xpos or exhausted retries) become O.
    exp["Emotion"] = [("O" if l == "X" else l) for l in exp["Emotion"]]
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", "Emotion"}]))
    return exp


def introduce_adj_type(exp):
    """
    Add a per-token "AdjType" column to ``exp``.

    Adjective tokens (JJ/JJR/JJS) are classified into a closed set of
    adjective types: known limiting/difference/similarity adjectives are
    labeled from the word lists; everything else is classified via OpenAI.
    Non-adjective tokens keep the O label.
    """
    exp["AdjType"] = ["O" for _ in exp["tokens"]]
    labels = ["Quantity", "Quality", "Size", "Age", "Shape", "Color", "Origin", "Material", "Purpose"]
    labels_len = len(labels)
    label_blob = ", ".join([(f"or {l}" if i == labels_len - 1 else l) for i, l in enumerate(labels)])
    if "JJ" in exp["xpos"] or "JJR" in exp["xpos"] or "JJS" in exp["xpos"]:
        # One client for the whole example instead of one per token.
        with OpenAI() as client:
            for jj_group in extract_label_groups(exp, "xpos", {"JJ", "JJR", "JJS"}):
                for jj_idx in jj_group:
                    jj_token = exp["tokens"][jj_idx]
                    if jj_token in word_lists_difference_adjectives:
                        exp["AdjType"][jj_idx] = "Difference"
                    elif jj_token in word_lists_limiting_adjectives:
                        exp["AdjType"][jj_idx] = "Limit"
                    elif jj_token in word_lists_similarity_adjectives:
                        exp["AdjType"][jj_idx] = "Similarity"
                    else:
                        messages = [
                            {
                                "role": "system",
                                "content": f"""
Classify '{jj_token}' at token index position {jj_idx} by choosing the best fitting adjective label. Return only the
 label value, nothing else.
""".replace("\n", "").strip()
                            },
                            {
                                "role": "user",
                                "content": exp["text"]
                            },
                            {
                                "role": "user",
                                "content": str(exp["tokens"])
                            },
                            {
                                "role": "user",
                                "content": f"The adjective '{jj_token}' at token index position {jj_idx} above describes a {label_blob}?"
                            },
                        ]
                        for _ in range(_MAX_LABEL_ATTEMPTS):
                            new_label = _request_valid_label(client, messages, labels)
                            if new_label is not None:
                                logger.info(f"{jj_idx}:{jj_token} {new_label}")
                                exp["AdjType"][jj_idx] = new_label
                                break
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", "AdjType"}]))
    return exp


def introduce_ner_feature(exp, class_name: str, class_desc: str):
    """
    Add a per-token BIO NER column (e.g. "NerPerson") to ``exp``.

    Capitalized adjective/noun tokens are classified via OpenAI into
    B-<CLASS>, I-<CLASS>, or O; tokens never labeled default to O.

    Args:
        exp: Example with "tokens", "xpos", and "text" columns.
        class_name (str): Entity class name, e.g. "person".
        class_desc (str): Short description used in the prompt, e.g. "person's name".
    """
    class_name_capital = class_name.capitalize()
    class_name_upper = class_name.upper()
    class_feature_name = f"Ner{class_name_capital}"
    exp[class_feature_name] = ["X" for _ in exp["tokens"]]

    labels = [f"B-{class_name_upper}", f"I-{class_name_upper}", "O"]
    labels_len = len(labels)
    label_blob = ", ".join([(f"or {l}" if i == labels_len - 1 else l) for i, l in enumerate(labels)])
    capital_indices = [i for i, t in enumerate(exp["tokens"]) if len(t) > 0
                       and t[0].isupper()
                       and exp["xpos"][i] in {
                           "JJ", "JJR", "JJS",
                           "NN", "NNS", "NNP", "NNPS"
                       }]
    if capital_indices:
        # One client for the whole example instead of one per token.
        with OpenAI() as client:
            for capital_idx in capital_indices:
                capital_token = exp["tokens"][capital_idx]
                messages = [
                    {
                        "role": "system",
                        "content": "You are an expert in recognizing all kinds of names.",
                    },
                    {
                        "role": "user",
                        "content": f"""
Classify '{capital_token}' at token index position {capital_idx} by choosing the best fitting BIO named entity label.
 Pay close attention to semantic context and neighboring tokens but don't over-generalize if there is not enough context
 in the provided text. Classify '{capital_token}' as a {class_name_upper} if it is being used as a part of a
 {class_desc}. Use the B-{class_name_upper} label if the token begins a {class_name_upper} name entity and the
 I-{class_name_upper} label if '{capital_token}' continues a {class_name_upper} name entity. Return only the label
 value, nothing else.
""".replace("\n", "").strip()
                    },
                    {
                        "role": "user",
                        "content": exp["text"]
                    },
                    {
                        "role": "user",
                        "content": str(exp["tokens"])
                    },
                    {
                        "role": "user",
                        "content": (f"The token '{capital_token}' at index position {capital_idx} above "
                                    f"is used as a {label_blob} in the text?")
                    },
                ]
                for _ in range(_MAX_LABEL_ATTEMPTS):
                    new_label = _request_valid_label(
                        client, messages, labels,
                        failure_msg=(f"failed to get {class_feature_name} label for {capital_token} "
                                     f"at idx {capital_idx} in \"{exp['text']}\", trying again"))
                    if new_label is not None:
                        logger.info(f"{capital_idx}:{capital_token} {new_label}")
                        exp[class_feature_name][capital_idx] = new_label
                        break
    # Tokens never labeled (lowercase/out-of-scope or exhausted retries) become O.
    exp[class_feature_name] = [("O" if l == "X" else l) for l in exp[class_feature_name]]
    logger.info("\n" + "\n".join([f"{k}\t{v}" for k, v in exp.items() if k in {"tokens", class_feature_name}]))
    return exp
484
+
485
+
486
  def introduce_typos(exp, typo_probability=0.03):
487
  """
488
  Randomly introduce typos in some % of tokens.
 
609
  if dataset_name == "pud":
610
  _split_ds = _split_ds.map(replace_bracket_label)
611
  filtered_split = _split_ds.filter(lambda ex: is_valid_example(ex, dataset_name=dataset_name))
612
+
613
  transformed_split = filtered_split.map(
614
  add_target_feat_columns,
615
  batched=False
616
  )
617
+ # TODO:
618
+ # - Get emotion classes and label adj and adv tokens based on classified emotions. This connects descriptions,
619
+ # with the kind of attribute, with the emotions evoked.
620
+ # - checkpoints after each phase to avoid costly re-dos
621
+ transformed_split = transformed_split.map(introduce_emotion, batched=False)
622
+ transformed_split = transformed_split.map(introduce_adj_type, batched=False)
623
+ transformed_split = transformed_split.map(
624
+ lambda exp: introduce_ner_feature(
625
+ exp, "location",
626
+ "location's name"),
627
+ batched=False)
628
+ transformed_split = transformed_split.map(
629
+ lambda exp: introduce_ner_feature(
630
+ exp, "organization",
631
+ "organization's name"),
632
+ batched=False)
633
+ transformed_split = transformed_split.map(
634
+ lambda exp: introduce_ner_feature(
635
+ exp, "person",
636
+ "person's name"),
637
+ batched=False)
638
+
639
+ new_splits[_split_name] = transformed_split
640
  transformed_split = transformed_split.remove_columns(["deps", "feats", "head", "idx", "lemmas", "misc", "upos"])
641
  new_splits[_split_name] = transformed_split.filter(is_evenly_shaped)
642
  return DatasetDict(new_splits)
 
677
  en_gum_processed = transform_and_filter_dataset(ud_en_gum_ds, "gum")
678
  en_pud_processed = transform_and_filter_dataset(ud_en_pud_ds, "pud")
679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680
  # Concatenate Datasets
681
  final_dataset = DatasetDict()
682
  final_dataset["test"] = concatenate_datasets(
683
  [
684
  en_ewt_processed["test"],
685
+ en_gum_processed["test"],
686
+ en_pud_processed["test"],
687
  ]
688
  )
689
 
690
  final_dataset["train"] = concatenate_datasets(
691
  [
692
  en_ewt_processed["train"],
693
+ en_gum_processed["train"],
694
  ]
695
  )
696
  if args.augment_typos:
 
699
  final_dataset["validation"] = concatenate_datasets(
700
  [
701
  en_ewt_processed["validation"],
702
+ en_gum_processed["validation"],
703
  ]
704
  )
705
  show_examples(final_dataset, args.show)