| | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \ |
| | TrainerCallback |
| | from datasets import load_from_disk |
| | from data_handler import DataCollatorCTCWithPadding |
| | from transformers import TrainingArguments |
| | from transformers import Trainer, logging |
| | from metric_utils import compute_metrics_fn |
| | from transformers.trainer_utils import get_last_checkpoint |
| | import json |
| | import os, glob |
| | from callbacks import BreakEachEpoch |
| | import subprocess |
| | from multiprocessing import Process |
| | import shutil |
| |
|
# Raise transformers logging to INFO so Trainer progress/eval output is visible.
logging.set_verbosity_info()
| |
|
| |
|
def load_pretrained_model(checkpoint_path=None):
    """Load a Wav2Vec2 CTC model and its processor.

    Args:
        checkpoint_path: Directory of a fine-tuning checkpoint to load
            tokenizer/extractor/model from. When None, the base pretrained
            model is loaded with a fresh CTC tokenizer built from the
            fine-tuning vocabulary, and its feature encoder is frozen.

    Returns:
        (model, processor) tuple.
    """
    from_base = checkpoint_path is None
    if from_base:
        load_path = './model-bin/pretrained/base'
        # Fresh tokenizer from the fine-tuning vocabulary.
        tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json",
                                         unk_token="<unk>",
                                         pad_token="<pad>",
                                         word_delimiter_token="|")
    else:
        load_path = checkpoint_path
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path)

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(load_path)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    model = Wav2Vec2ForCTC.from_pretrained(
        load_path,
        gradient_checkpointing=True,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    if from_base:
        # Freeze the convolutional feature encoder only when starting from
        # the base pretrained weights (checkpoints resume their own state).
        model.freeze_feature_extractor()

    model_total_params = sum(p.numel() for p in model.parameters())
    model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
                                                                            model_total_params_trainable))
    return model, processor
| |
|
| |
|
def prepare_dataset(batch, processor):
    """Turn a batch of raw speech arrays + transcripts into model inputs.

    Adds ``input_values``, ``length`` and ``labels`` columns to the batch.
    All examples in the batch must share one sampling rate.
    """
    distinct_rates = set(batch["sampling_rate"])
    assert (
        len(distinct_rates) == 1
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    rate = batch["sampling_rate"][0]
    batch["input_values"] = processor(batch["speech"], sampling_rate=rate).input_values

    # Per-example lengths, used by the trainer for length-based grouping.
    batch["length"] = [len(values) for values in batch["input_values"]]

    # Tokenize transcripts with the tokenizer half of the processor.
    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch
| |
|
| |
|
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=5):
    """Load a dataset shard from disk, filter long clips, and map to model inputs.

    Processing results are cached at ``cache_file_filter_name`` /
    ``cache_file_map_name``. When the cache paths point into the prefetch
    folder, the call only warms that cache and returns None. Prefetched cache
    files are promoted into the main cache folder before (re)processing.

    Returns:
        The processed dataset, or None when prefetching or when any step fails
        (callers skip the shard in that case).
    """
    try:
        dataset = load_from_disk(path)
        list_cache_prefetch_files = glob.glob(
            cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace(
                '.arrow', '*'))

        # Prefetch mode: nothing to do if this shard is already cached in
        # either the main or the prefetch cache.
        if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
            if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
                                                         cache_processing_dataset_folder).replace('.arrow', '*'))) > 0:
                return
            if len(list_cache_prefetch_files) > 0:
                return

        # Promote previously prefetched cache files into the main cache folder.
        if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
            for item_file in list_cache_prefetch_files:
                shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
                                                         cache_processing_dataset_folder))
        # Cache hit: map() will reuse the cached arrow file instead of recomputing.
        if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
            return dataset.map(prepare_dataset,
                               remove_columns=dataset.column_names,
                               batch_size=32,
                               num_proc=num_proc,
                               batched=True,
                               fn_kwargs={"processor": processor},
                               cache_file_name=cache_file_map_name)

        # Drop clips of >= 160000 samples (10 s at an assumed 16 kHz rate — TODO confirm).
        dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
                                 batch_size=32,
                                 num_proc=num_proc,
                                 cache_file_name=cache_file_filter_name)
        processed_dataset = dataset.map(prepare_dataset,
                                        remove_columns=dataset.column_names,
                                        batch_size=32,
                                        num_proc=num_proc,
                                        batched=True,
                                        fn_kwargs={"processor": processor},
                                        cache_file_name=cache_file_map_name)
        processed_dataset.cleanup_cache_files()
        return processed_dataset
    except Exception:
        # Was a bare ``except`` returning None: keep the best-effort contract
        # (callers skip the shard) but log the cause instead of hiding it.
        logging.get_logger().warning('Failed to prepare dataset shard at {}'.format(path), exc_info=True)
        return None
| |
|
| |
|
def commit_checkpoint():
    """Commit and push the latest checkpoint directory via git (best effort).

    Each command's stdout is printed; failures are not raised — a failed push
    just shows git's output.
    """
    # Explicit argv lists: the old ``command.split()`` on
    # 'git commit -m "auto-commit"' passed the literal double quotes through
    # to git, producing a commit message of '"auto-commit"'.
    submit_commands = [
        ['git', 'add', 'model-bin/finetune/base/*'],  # git expands the pathspec glob itself
        ['git', 'commit', '-m', 'auto-commit'],
        ['git', 'push', 'origin', 'main'],
    ]
    for command in submit_commands:
        print(subprocess.run(command, stdout=subprocess.PIPE).stdout.decode('utf-8'))
| |
|
| |
|
def get_train_test_shard_id(epoch_count, num_train=None, num_test=None):
    """Map an epoch index to the train/test shard indices for that epoch.

    Args:
        epoch_count: Zero-based epoch index; cycles through the train shards.
        num_train: Total number of train shards; defaults to the module-level
            ``num_train_shards`` (backward-compatible with existing callers).
        num_test: Total number of test shards; defaults to the module-level
            ``num_test_shards``.

    Returns:
        (train_shard_idx, test_shard_idx, num_test_sub_shards, idx_sub_shard)
    """
    if num_train is None:
        num_train = num_train_shards
    if num_test is None:
        num_test = num_test_shards
    # One train shard per epoch, wrapping around.
    _train_dataset_shard_idx = epoch_count % num_train
    # Pick the test shard covering the same fraction of the cycle; clamp to range.
    _test_dataset_shard_idx = min(round(_train_dataset_shard_idx / (num_train / num_test)), num_test - 1)
    # Evaluation uses 1/8 of the test shard, rotating sub-shard each epoch.
    _num_test_sub_shard = 8
    _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard
    return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard
| |
|
| |
|
def process_prefetch_epoch(epoch_count):
    """Warm the prefetch cache with the train/test shards of ``epoch_count``.

    Runs in a background Process; results land in the prefetch cache folder
    and are promoted into the main cache by load_prepared_dataset later.
    """
    train_shard_idx, test_shard_idx, _, _ = get_train_test_shard_id(epoch_count)
    shard_plan = (
        (train_dataset_root_folder, 'train', train_shard_idx),
        (test_dataset_root_folder, 'test', test_shard_idx),
    )
    for root_folder, split_name, shard_idx in shard_plan:
        shard_path = os.path.join(root_folder, 'shard_{}'.format(shard_idx))
        filter_cache = os.path.join(cache_processing_dataset_folder_prefetch,
                                    split_name,
                                    'cache-{}-filter-shard-{}.arrow'.format(split_name, shard_idx))
        map_cache = os.path.join(cache_processing_dataset_folder_prefetch,
                                 split_name,
                                 'cache-{}-map-shard-{}.arrow'.format(split_name, shard_idx))
        load_prepared_dataset(shard_path,
                              w2v_ctc_processor,
                              cache_file_filter_name=filter_cache,
                              cache_file_map_name=map_cache)
| |
|
| |
|
if __name__ == "__main__":

    # Checkpoint output folder; committed to git every few epochs.
    checkpoint_path = "./model-bin/finetune/base/"

    # Sharded datasets (shard_0, shard_1, ...) on a mounted drive.
    train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
    test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'

    # Main processing cache lives in shared memory for speed; the prefetch
    # cache is filled by background processes and promoted later.
    cache_processing_dataset_folder = '/dev/shm/cache/'
    cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
    # exist_ok=True per directory: the old single-guard pattern
    # (if not exists(train): makedirs(train); makedirs(test)) crashed or
    # silently skipped 'test' when the layout was only partially present.
    os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'), exist_ok=True)
    os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'), exist_ok=True)
    os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'train'), exist_ok=True)
    os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'test'), exist_ok=True)
    num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
    num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
    # Epoch here means "one pass over one shard", so the count is large.
    num_epochs = 5000

    training_args = TrainingArguments(
        output_dir=checkpoint_path,
        fp16=True,
        group_by_length=True,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        gradient_accumulation_steps=2,
        num_train_epochs=num_epochs,
        logging_steps=5,
        learning_rate=1e-5,
        weight_decay=0.005,
        warmup_steps=1000,
        save_total_limit=2,
        ignore_data_skip=True,  # shards rotate each epoch; never replay skipped batches
        logging_dir=os.path.join(checkpoint_path, 'log'),
        metric_for_best_model='wer',
        save_strategy="epoch",
        evaluation_strategy="epoch",
        greater_is_better=False,  # lower WER is better
    )
    trainer = None

    # Work out which epoch (== shard-cycle position) to resume from.
    last_checkpoint_path = None
    last_epoch_idx = 0
    if os.path.exists(checkpoint_path):
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)
        if last_checkpoint_path is not None:
            with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
                trainer_state = json.load(file)
            last_epoch_idx = int(trainer_state['epoch'])

    w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
    data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)

    prefetch_process = []

    for epoch_idx in range(last_epoch_idx, num_epochs):

        train_dataset_shard_idx, test_dataset_shard_idx, num_test_sub_shard, idx_sub_shard = get_train_test_shard_id(
            epoch_idx)

        # Wait for the previous epoch's prefetch workers before touching caches.
        for process_instance in prefetch_process:
            process_instance.join()
        prefetch_process.clear()

        train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                           'shard_{}'.format(train_dataset_shard_idx)),
                                              w2v_ctc_processor,
                                              cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                  'train',
                                                                                  'cache-train-filter-shard-{}.arrow'.format(
                                                                                      train_dataset_shard_idx)),
                                              cache_file_map_name=os.path.join(cache_processing_dataset_folder,
                                                                               'train',
                                                                               'cache-train-map-shard-{}.arrow'.format(
                                                                                   train_dataset_shard_idx)),
                                              )
        test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                          'shard_{}'.format(test_dataset_shard_idx)),
                                             w2v_ctc_processor,
                                             cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                 'test',
                                                                                 'cache-test-filter-shard-{}.arrow'.format(
                                                                                     test_dataset_shard_idx)),
                                             cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                                              'cache-test-map-shard-{}.arrow'.format(
                                                                                  test_dataset_shard_idx))
                                             )
        # load_prepared_dataset returns None on failure; skip the shard.
        if train_dataset is None or test_dataset is None:
            print("Ignore Shard {}".format(train_dataset_shard_idx))
            continue
        # Evaluate on a rotating 1/8 sub-shard to keep evaluation cheap.
        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)

        # Start background preprocessing of the NEXT epoch's shards.
        prefetch_process.append(Process(target=process_prefetch_epoch, args=(epoch_idx + 1,)))
        for process_instance in prefetch_process:
            process_instance.start()

        # Build the Trainer once, then swap datasets in-place each epoch.
        if trainer is None:
            trainer = Trainer(
                model=w2v_ctc_model,
                data_collator=data_collator,
                args=training_args,
                compute_metrics=compute_metrics_fn(w2v_ctc_processor),
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                tokenizer=w2v_ctc_processor.feature_extractor,
                callbacks=[BreakEachEpoch()]  # stop after one pass over the shard
            )
        else:
            trainer.train_dataset = train_dataset
            trainer.eval_dataset = test_dataset

        logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
        logging.get_logger().info(
            'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))

        if last_checkpoint_path is not None:
            # Resume optimizer/scheduler state from the latest checkpoint.
            trainer.train(resume_from_checkpoint=True)
        else:
            trainer.train()
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)

        # Free the shard's shared-memory cache before the next epoch.
        test_dataset.cleanup_cache_files()
        train_dataset.cleanup_cache_files()

        # Push the checkpoint to the remote every 5 epochs.
        if epoch_idx % 5 == 0:
            commit_checkpoint()
| |
|