Commit ·
9a7a0bd
1
Parent(s): 7fcdd24
Add NST+NPSC dataset script
Browse files- run.sh +1 -3
- run_speech_recognition_ctc.py +100 -67
run.sh
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_speech_recognition_ctc.py \
|
| 2 |
-
--dataset_name="NbAiLab/NST" \
|
| 3 |
--model_name_or_path="KBLab/wav2vec2-large-voxrex" \
|
| 4 |
-
--hub_model_id="NbAiLab/wav2vec2-large-voxrex-nst" \
|
| 5 |
-
--dataset_config_name="no-close" \
|
| 6 |
--output_dir="./" \
|
| 7 |
--overwrite_output_dir \
|
| 8 |
--num_train_epochs="15" \
|
|
|
|
| 1 |
WANDB_ENTITY=NbAiLab WANDB_PROJECT=wav2vec2 python run_speech_recognition_ctc.py \
|
|
|
|
| 2 |
--model_name_or_path="KBLab/wav2vec2-large-voxrex" \
|
| 3 |
+
--hub_model_id="NbAiLab/wav2vec2-large-voxrex-npsc-nst" \
|
|
|
|
| 4 |
--output_dir="./" \
|
| 5 |
--overwrite_output_dir \
|
| 6 |
--num_train_epochs="15" \
|
run_speech_recognition_ctc.py
CHANGED
|
@@ -47,13 +47,11 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
|
| 47 |
from transformers.utils import check_min_version
|
| 48 |
from transformers.utils.versions import require_version
|
| 49 |
|
| 50 |
-
|
| 51 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 52 |
check_min_version("4.16.0.dev0")
|
| 53 |
|
| 54 |
require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
| 55 |
|
| 56 |
-
|
| 57 |
logger = logging.getLogger(__name__)
|
| 58 |
|
| 59 |
|
|
@@ -102,8 +100,8 @@ class ModelArguments:
|
|
| 102 |
default=0.05,
|
| 103 |
metadata={
|
| 104 |
"help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
| 105 |
-
|
| 106 |
-
|
| 107 |
},
|
| 108 |
)
|
| 109 |
mask_time_length: int = field(
|
|
@@ -114,7 +112,7 @@ class ModelArguments:
|
|
| 114 |
default=0.0,
|
| 115 |
metadata={
|
| 116 |
"help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
|
| 117 |
-
|
| 118 |
},
|
| 119 |
)
|
| 120 |
mask_feature_length: int = field(
|
|
@@ -129,6 +127,7 @@ class ModelArguments:
|
|
| 129 |
default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
|
| 130 |
)
|
| 131 |
|
|
|
|
| 132 |
@dataclass
|
| 133 |
class DataTrainingArguments:
|
| 134 |
"""
|
|
@@ -176,14 +175,14 @@ class DataTrainingArguments:
|
|
| 176 |
default=None,
|
| 177 |
metadata={
|
| 178 |
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 179 |
-
|
| 180 |
},
|
| 181 |
)
|
| 182 |
max_eval_samples: Optional[int] = field(
|
| 183 |
default=None,
|
| 184 |
metadata={
|
| 185 |
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
| 186 |
-
|
| 187 |
},
|
| 188 |
)
|
| 189 |
chars_to_ignore: Optional[List[str]] = list_field(
|
|
@@ -207,16 +206,16 @@ class DataTrainingArguments:
|
|
| 207 |
default=False,
|
| 208 |
metadata={
|
| 209 |
"help": "Whether to only do data preprocessing and skip training. "
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
},
|
| 214 |
)
|
| 215 |
use_auth_token: bool = field(
|
| 216 |
default=False,
|
| 217 |
metadata={
|
| 218 |
"help": "If :obj:`True`, will use the token generated when running"
|
| 219 |
-
|
| 220 |
},
|
| 221 |
)
|
| 222 |
unk_token: str = field(
|
|
@@ -235,9 +234,9 @@ class DataTrainingArguments:
|
|
| 235 |
default=None,
|
| 236 |
metadata={
|
| 237 |
"help": "The target language that should be used be"
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
},
|
| 242 |
)
|
| 243 |
|
|
@@ -303,10 +302,10 @@ class DataCollatorCTCWithPadding:
|
|
| 303 |
|
| 304 |
|
| 305 |
def create_vocabulary_from_data(
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
):
|
| 311 |
# Given training and test labels create vocabulary
|
| 312 |
def extract_all_chars(batch):
|
|
@@ -344,6 +343,85 @@ def create_vocabulary_from_data(
|
|
| 344 |
return vocab_dict
|
| 345 |
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
def main():
|
| 348 |
# See all possible arguments in src/transformers/training_args.py
|
| 349 |
# or by passing the --help flag to this script.
|
|
@@ -393,45 +471,10 @@ def main():
|
|
| 393 |
# Set seed before initializing model.
|
| 394 |
set_seed(training_args.seed)
|
| 395 |
|
| 396 |
-
# Pre-processing dataset
|
| 397 |
-
import re
|
| 398 |
-
|
| 399 |
-
def map_dataset(entry):
|
| 400 |
-
text = entry["text"].lower()
|
| 401 |
-
text = text.replace("(...Vær stille under dette opptaket...)", "")
|
| 402 |
-
text = re.sub('[áàâ]', 'a', text)
|
| 403 |
-
text = re.sub('[ä]', 'æ', text)
|
| 404 |
-
text = re.sub('[éèëê]', 'e', text)
|
| 405 |
-
text = re.sub('[íìïî]', 'i', text)
|
| 406 |
-
text = re.sub('[óòöô]', 'o', text)
|
| 407 |
-
text = re.sub('[ö]', 'ø', text)
|
| 408 |
-
text = re.sub('[ç]', 'c', text)
|
| 409 |
-
text = re.sub('[úùüû]', 'u', text)
|
| 410 |
-
# text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
|
| 411 |
-
text = re.sub('\s+', ' ', text)
|
| 412 |
-
return {"text": text}
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
def filter_dataset(entry):
|
| 416 |
-
if not (len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3):
|
| 417 |
-
return False # Too short
|
| 418 |
-
if re.match(entry["type"], "pIW|CA"):
|
| 419 |
-
return False # Spelling out words
|
| 420 |
-
return True
|
| 421 |
-
|
| 422 |
# 1. First, let's load the dataset
|
| 423 |
-
raw_datasets =
|
| 424 |
|
| 425 |
if training_args.do_train:
|
| 426 |
-
raw_datasets["train"] = load_dataset(
|
| 427 |
-
data_args.dataset_name,
|
| 428 |
-
data_args.dataset_config_name,
|
| 429 |
-
split=data_args.train_split_name,
|
| 430 |
-
use_auth_token=data_args.use_auth_token,
|
| 431 |
-
).shuffle()
|
| 432 |
-
raw_datasets["train"] = raw_datasets["train"].filter(filter_dataset)
|
| 433 |
-
raw_datasets["train"] = raw_datasets["train"].map(map_dataset)
|
| 434 |
-
|
| 435 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
| 436 |
raise ValueError(
|
| 437 |
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
|
|
@@ -450,28 +493,18 @@ def main():
|
|
| 450 |
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
| 451 |
|
| 452 |
if training_args.do_eval:
|
| 453 |
-
raw_datasets["eval"] = load_dataset(
|
| 454 |
-
data_args.dataset_name,
|
| 455 |
-
data_args.dataset_config_name,
|
| 456 |
-
split=data_args.eval_split_name,
|
| 457 |
-
use_auth_token=data_args.use_auth_token,
|
| 458 |
-
).shuffle()
|
| 459 |
-
raw_datasets["eval"] = raw_datasets["eval"].filter(filter_dataset)
|
| 460 |
-
raw_datasets["eval"] = raw_datasets["eval"].map(map_dataset)
|
| 461 |
-
|
| 462 |
if data_args.max_eval_samples is not None:
|
| 463 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
| 464 |
|
| 465 |
-
|
| 466 |
# 2. We remove some special characters from the datasets
|
| 467 |
# that make training complicated and do not help in transcribing the speech
|
| 468 |
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
| 469 |
# that could be easily picked up by the model
|
| 470 |
-
#chars_to_ignore_regex = (
|
| 471 |
# f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
|
| 472 |
-
#)
|
| 473 |
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\–\_\\\+\#\/]'
|
| 474 |
-
|
| 475 |
text_column_name = data_args.text_column_name
|
| 476 |
|
| 477 |
def remove_special_characters(batch):
|
|
|
|
| 47 |
from transformers.utils import check_min_version
|
| 48 |
from transformers.utils.versions import require_version
|
| 49 |
|
|
|
|
| 50 |
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 51 |
check_min_version("4.16.0.dev0")
|
| 52 |
|
| 53 |
require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
| 54 |
|
|
|
|
| 55 |
logger = logging.getLogger(__name__)
|
| 56 |
|
| 57 |
|
|
|
|
| 100 |
default=0.05,
|
| 101 |
metadata={
|
| 102 |
"help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
| 103 |
+
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
|
| 104 |
+
"vectors will be masked along the time axis."
|
| 105 |
},
|
| 106 |
)
|
| 107 |
mask_time_length: int = field(
|
|
|
|
| 112 |
default=0.0,
|
| 113 |
metadata={
|
| 114 |
"help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
|
| 115 |
+
"span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
|
| 116 |
},
|
| 117 |
)
|
| 118 |
mask_feature_length: int = field(
|
|
|
|
| 127 |
default=False, metadata={"help": "If True, will try yo aboud the CTC loss goinf to infinity."}
|
| 128 |
)
|
| 129 |
|
| 130 |
+
|
| 131 |
@dataclass
|
| 132 |
class DataTrainingArguments:
|
| 133 |
"""
|
|
|
|
| 175 |
default=None,
|
| 176 |
metadata={
|
| 177 |
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 178 |
+
"value if set."
|
| 179 |
},
|
| 180 |
)
|
| 181 |
max_eval_samples: Optional[int] = field(
|
| 182 |
default=None,
|
| 183 |
metadata={
|
| 184 |
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
| 185 |
+
"value if set."
|
| 186 |
},
|
| 187 |
)
|
| 188 |
chars_to_ignore: Optional[List[str]] = list_field(
|
|
|
|
| 206 |
default=False,
|
| 207 |
metadata={
|
| 208 |
"help": "Whether to only do data preprocessing and skip training. "
|
| 209 |
+
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
|
| 210 |
+
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
|
| 211 |
+
"so that the cached datasets can consequently be loaded in distributed training"
|
| 212 |
},
|
| 213 |
)
|
| 214 |
use_auth_token: bool = field(
|
| 215 |
default=False,
|
| 216 |
metadata={
|
| 217 |
"help": "If :obj:`True`, will use the token generated when running"
|
| 218 |
+
":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
|
| 219 |
},
|
| 220 |
)
|
| 221 |
unk_token: str = field(
|
|
|
|
| 234 |
default=None,
|
| 235 |
metadata={
|
| 236 |
"help": "The target language that should be used be"
|
| 237 |
+
" passed to the tokenizer for tokenization. Note that"
|
| 238 |
+
" this is only relevant if the model classifies the"
|
| 239 |
+
" input audio to a sequence of phoneme sequences."
|
| 240 |
},
|
| 241 |
)
|
| 242 |
|
|
|
|
| 302 |
|
| 303 |
|
| 304 |
def create_vocabulary_from_data(
|
| 305 |
+
datasets: DatasetDict,
|
| 306 |
+
word_delimiter_token: Optional[str] = None,
|
| 307 |
+
unk_token: Optional[str] = None,
|
| 308 |
+
pad_token: Optional[str] = None,
|
| 309 |
):
|
| 310 |
# Given training and test labels create vocabulary
|
| 311 |
def extract_all_chars(batch):
|
|
|
|
| 343 |
return vocab_dict
|
| 344 |
|
| 345 |
|
| 346 |
+
def make_dataset(seed=42):
|
| 347 |
+
# Pre-processing dataset
|
| 348 |
+
import re
|
| 349 |
+
|
| 350 |
+
def map_nst(entry):
|
| 351 |
+
text = entry["text"].lower()
|
| 352 |
+
text = text.replace("(...Vær stille under dette opptaket...)", "")
|
| 353 |
+
text = re.sub('[áàâ]', 'a', text)
|
| 354 |
+
text = re.sub('[ä]', 'æ', text)
|
| 355 |
+
text = re.sub('[éèëê]', 'e', text)
|
| 356 |
+
text = re.sub('[íìïî]', 'i', text)
|
| 357 |
+
text = re.sub('[óòöô]', 'o', text)
|
| 358 |
+
text = re.sub('[ö]', 'ø', text)
|
| 359 |
+
text = re.sub('[ç]', 'c', text)
|
| 360 |
+
text = re.sub('[úùüû]', 'u', text)
|
| 361 |
+
# text = re.sub('\\(?=(Punktum|Komma|Utropstegn|Spørsmålstegn))', ' ', text)
|
| 362 |
+
text = re.sub('\s+', ' ', text)
|
| 363 |
+
return {"text": text}
|
| 364 |
+
|
| 365 |
+
def filter_nst(entry):
|
| 366 |
+
if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
|
| 367 |
+
return False # Too short
|
| 368 |
+
if re.match(entry["type"], "pIW|CA"):
|
| 369 |
+
return False # Spelling out words
|
| 370 |
+
return True
|
| 371 |
+
|
| 372 |
+
def filter_npsc(entry):
|
| 373 |
+
# False if there are digits in the text
|
| 374 |
+
if not ((len(entry["text"]) <= len(entry["audio"]["array"]) // 320) and (len(entry["text"].strip()) >= 3)):
|
| 375 |
+
return False # Too short
|
| 376 |
+
if re.search("\d", entry["text"]):
|
| 377 |
+
return False
|
| 378 |
+
return True
|
| 379 |
+
|
| 380 |
+
def map_npsc(entry):
|
| 381 |
+
batch = {"text": entry["text"].lower()}
|
| 382 |
+
batch["text"] = re.sub('[áàâ]', 'a', batch["text"])
|
| 383 |
+
batch["text"] = re.sub('[ä]', 'æ', batch["text"])
|
| 384 |
+
batch["text"] = re.sub('[éèëê]', 'e', batch["text"])
|
| 385 |
+
batch["text"] = re.sub('[íìïî]', 'i', batch["text"])
|
| 386 |
+
batch["text"] = re.sub('[óòöô]', 'o', batch["text"])
|
| 387 |
+
batch["text"] = re.sub('[ö]', 'ø', batch["text"])
|
| 388 |
+
batch["text"] = re.sub('[ç]', 'c', batch["text"])
|
| 389 |
+
batch["text"] = re.sub('[úùüû]', 'u', batch["text"])
|
| 390 |
+
batch["text"] = re.sub('\s', ' ', batch["text"])
|
| 391 |
+
batch["text"] = re.sub('<ee>', 'eee', batch["text"])
|
| 392 |
+
batch["text"] = re.sub('<qq>', 'qqq', batch["text"])
|
| 393 |
+
batch["text"] = re.sub('<mm>', 'mmm', batch["text"])
|
| 394 |
+
batch["text"] = re.sub('<inaudible>', 'xxx', batch["text"])
|
| 395 |
+
# batch["text"] = re.sub('<inaudible>', '?', batch["text"])
|
| 396 |
+
if "<" in batch["text"]:
|
| 397 |
+
raise ValueError(batch["text"])
|
| 398 |
+
return batch
|
| 399 |
+
|
| 400 |
+
nst = datasets.load_dataset("NbAiLab/NST", "no-close")
|
| 401 |
+
npsc = datasets.load_dataset("NbAiLab/NPSC", "16K_mp3")
|
| 402 |
+
# TODO NST_hesitate
|
| 403 |
+
|
| 404 |
+
split = len(npsc["train"]) / (len(npsc["train"]) + len(npsc["validation"])) # Use same train/val ratio as NPSC
|
| 405 |
+
nst_train = nst["train"].train_test_split(train_size=split, seed=seed)
|
| 406 |
+
nst["train"] = nst_train["train"]
|
| 407 |
+
nst["validation"] = nst_train["test"]
|
| 408 |
+
|
| 409 |
+
nst = nst.filter(filter_nst).map(map_nst).shuffle(seed=seed)
|
| 410 |
+
npsc = npsc.filter(filter_npsc).map(map_npsc).shuffle(seed=seed)
|
| 411 |
+
|
| 412 |
+
npsc_base = npsc.remove_columns([col for col in npsc["train"].column_names if col not in ["text", "audio"]])
|
| 413 |
+
nst_base = nst.remove_columns([col for col in nst["train"].column_names if col not in ["text", "audio"]])
|
| 414 |
+
|
| 415 |
+
combined = {}
|
| 416 |
+
for split in "train", "validation", "test":
|
| 417 |
+
probs = np.array([len(nst_base[split]), len(npsc_base[split])]) # Weight by number of examples
|
| 418 |
+
probs = (probs / probs.sum()).tolist()
|
| 419 |
+
comb = datasets.interleave_datasets([nst_base[split], npsc_base[split]], probabilities=probs, seed=seed)
|
| 420 |
+
combined[split] = comb
|
| 421 |
+
|
| 422 |
+
return datasets.DatasetDict(**combined)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
def main():
|
| 426 |
# See all possible arguments in src/transformers/training_args.py
|
| 427 |
# or by passing the --help flag to this script.
|
|
|
|
| 471 |
# Set seed before initializing model.
|
| 472 |
set_seed(training_args.seed)
|
| 473 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
# 1. First, let's load the dataset
|
| 475 |
+
raw_datasets = make_dataset(seed=training_args.seed)
|
| 476 |
|
| 477 |
if training_args.do_train:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
if data_args.audio_column_name not in raw_datasets["train"].column_names:
|
| 479 |
raise ValueError(
|
| 480 |
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
|
|
|
|
| 493 |
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
| 494 |
|
| 495 |
if training_args.do_eval:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
if data_args.max_eval_samples is not None:
|
| 497 |
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
| 498 |
|
|
|
|
| 499 |
# 2. We remove some special characters from the datasets
|
| 500 |
# that make training complicated and do not help in transcribing the speech
|
| 501 |
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
| 502 |
# that could be easily picked up by the model
|
| 503 |
+
# chars_to_ignore_regex = (
|
| 504 |
# f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
|
| 505 |
+
# )
|
| 506 |
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\–\_\\\+\#\/]'
|
| 507 |
+
|
| 508 |
text_column_name = data_args.text_column_name
|
| 509 |
|
| 510 |
def remove_special_characters(batch):
|