Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| import logging | |
| import os | |
| import sys | |
| import speechbrain as sb | |
| import torch | |
| import torchaudio | |
| import librosa | |
| from common_accent_prepare import prepare_common_accent | |
| from hyperpyyaml import load_hyperpyyaml | |
| import pickle | |
"""Recipe for training an Accent Classification system with CommonVoice Accent.

To run this recipe, do the following:
> python train_w2v2.py hparams/train_w2v2.yaml

Author
------
 * Juan Pablo Zuluaga 2023
"""
# NOTE: this string follows the imports, so Python does not store it as the
# module's __doc__; it is kept here as the recipe description.
logger = logging.getLogger(__name__)

# ipdb is only used for interactive debugging (every `ipdb.set_trace()` call
# in this file is commented out), so a missing install must not prevent the
# recipe from running.
try:
    import ipdb  # noqa: F401
except ImportError:
    ipdb = None
# Brain class for Accent ID training
class AID(sb.Brain):
    """Accent-identification ``Brain``: wav2vec2 encoder, pooling, MLP head.

    Two optimizers are used: one for ``modules.wav2vec2`` and one for
    ``hparams.model`` (see ``init_optimizers``).

    Arguments
    ---------
    gemeinde_df : object, optional
        Stored as ``self.gemeinde_df``; not used by any method of this class.
    accents_encoder : object, optional
        Stored as ``self.accents_encoder``; not used by any method of this
        class (presumably the label encoder returned by ``dataio_prep``).
    *args, **kwargs
        Forwarded unchanged to ``sb.Brain.__init__``.

    Both extra arguments default to ``None`` so the class can be built with
    only the usual ``sb.Brain`` keyword arguments.
    """

    def __init__(self, gemeinde_df=None, accents_encoder=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gemeinde_df = gemeinde_df
        self.accents_encoder = accents_encoder
        # Mean-only statistics pooling. It has no parameters, so one instance
        # is built here instead of a new one on every forward pass.
        # NOTE(review): hparams["avg_pool"] (set by get_pooling_layer) is not
        # used; pooling is always statistics pooling.
        self._pooling = sb.nnet.pooling.StatisticsPooling(return_std=False)

    def _augmentation_enabled(self, stage):
        """Return True when batch-doubling augmentation applies to `stage`."""
        return stage == sb.Stage.TRAIN and self.hparams.apply_augmentation

    def prepare_features(self, wavs, stage):
        """Prepare the features for computation, including augmentation.

        Arguments
        ---------
        wavs : tuple
            Input signals (tensor) and their relative lengths (tensor).
        stage : sb.Stage
            The current stage of training.

        Returns
        -------
        feats : torch.Tensor
            Output of ``modules.wav2vec2`` on the (possibly augmented) batch.
        wav_lens : torch.Tensor
            Relative lengths; doubled along dim 0 when augmentation is on.
        """
        wavs, wav_lens = wavs

        # When augmentation is on, the original and corrupted batches are
        # concatenated into one batch twice as large. This uses more memory
        # but helps performance; turn it off if you run OOM.
        if self._augmentation_enabled(stage):
            wavs_noise = self.modules.env_corrupt(wavs, wav_lens)
            wavs = torch.cat([wavs, wavs_noise], dim=0)
            wav_lens = torch.cat([wav_lens, wav_lens], dim=0)
            if hasattr(self.hparams, "augmentation"):
                wavs = self.hparams.augmentation(wavs, wav_lens)

        # Only the padded signals are fed to wav2vec2 (wav_lens is not passed).
        feats = self.modules.wav2vec2(wavs)
        return feats, wav_lens

    def compute_forward(self, batch, stage):
        """Run the model from raw signals to class scores.

        Arguments
        ---------
        batch : PaddedBatch
            This batch object contains all the relevant tensors for computation.
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.

        Returns
        -------
        outputs : torch.Tensor
            Output of ``modules.output_mlp`` (whether these are
            log-probabilities depends on its definition in the hparams file).
        lens : torch.Tensor
            Relative lengths of the (possibly augmented) batch.
        embeddings : torch.Tensor
            Pooled utterance embeddings, flattened to (batch, -1).
        """
        batch = batch.to(self.device)
        feats, lens = self.prepare_features(batch.sig, stage)
        logger.debug("feats shape: %s", tuple(feats.shape))

        # Length-aware pooling over the time axis, then flatten per utterance.
        embeddings = self._pooling(feats, lens)
        embeddings = embeddings.view(embeddings.shape[0], -1)

        outputs = self.modules.output_mlp(embeddings)
        logger.debug("outputs shape (after output_mlp): %s", tuple(outputs.shape))
        return outputs, lens, embeddings

    def compute_objectives(self, inputs, batch, stage):
        """Compute the loss and, outside training, update the metrics.

        Arguments
        ---------
        inputs : tuple
            The ``(predictions, lens, embeddings)`` from `compute_forward`.
        batch : PaddedBatch
            This batch object contains all the relevant tensors for computation.
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.

        Returns
        -------
        loss : torch.Tensor
            A one-element tensor used for backpropagating the gradient.
        """
        predictions, lens, embeddings = inputs
        # Encoded labels; squeeze(1) drops the singleton label dimension.
        targets = batch.accent_encoded.data.squeeze(1)

        # The forward pass saw [original; corrupted], so the labels are
        # repeated. `lens` was already doubled in prepare_features.
        if self._augmentation_enabled(stage):
            targets = torch.cat([targets, targets], dim=0)

        # Optional export, enabled with `export_embeddings: True` in hparams.
        if stage == sb.Stage.TEST and getattr(self.hparams, "export_embeddings", False):
            self.save_embeddings_and_labels(embeddings, targets, stage)

        loss = self.hparams.compute_cost(predictions, targets)

        if stage != sb.Stage.TRAIN:
            self.error_metrics.append(batch.id, predictions, targets)
            self.error_metrics2.append(batch.id, predictions.argmax(-1), targets)
            # NOTE(review): targets are reshaped to (1, batch) before being
            # passed to acc_computer; confirm this matches its expected layout.
            self.acc_metric.append(predictions, targets.view(1, -1), lens)
            self.acc_metric2.append(predictions.argmax(-1), targets.view(1, -1), lens)

        # Optional export, enabled with `export_predictions: True` in hparams.
        if stage == sb.Stage.TEST and getattr(self.hparams, "export_predictions", False):
            self.save_predictions(batch.id, predictions, targets, lens)
        return loss

    def save_predictions(self, batch_ids, predictions, targets, lens):
        """Append one batch of predictions to ``<save_folder>/predictions/test_predictions.pkl``.

        The file is opened in append mode, so successive calls (and runs)
        add further pickled dicts to the same file.
        """
        results = {
            "batch_ids": batch_ids,
            "predictions": predictions.cpu().detach().numpy(),
            "targets": targets.cpu().detach().numpy(),
            "lens": lens.cpu().detach().numpy(),
        }
        save_path = os.path.join(self.hparams.save_folder, "predictions", "test_predictions.pkl")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "ab") as f:
            pickle.dump(results, f)

    def fit_batch(self, batch):
        """Trains the parameters given a single batch in input."""
        should_step = self.step % self.grad_accumulation_factor == 0
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)

        # Skip gradient sync on accumulation-only steps.
        with self.no_sync(not should_step):
            (loss / self.grad_accumulation_factor).backward()

        if should_step:
            if self.check_gradients(loss):
                self.wav2vec2_optimizer.step()
                self.optimizer.step()
            self.wav2vec2_optimizer.zero_grad()
            self.optimizer.zero_grad()
            self.optimizer_step += 1

        self.on_fit_batch_end(batch, predictions[0:2], loss, should_step)
        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches.

        Metrics are NOT reset here: ``sb.Brain`` calls ``on_stage_start``
        once per stage, so they accumulate over every batch of the stage.
        """
        with torch.no_grad():
            predictions = self.compute_forward(batch, stage=stage)
            loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch=None):
        """Gets called at the beginning of each epoch.

        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """
        self.loss_metric = sb.utils.metric_stats.MetricStats(
            metric=sb.nnet.losses.nll_loss
        )

        # Evaluation-only statistics trackers.
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()
            self.acc_metric = self.hparams.acc_computer()
            self.error_metrics2 = self.hparams.error_stats()
            self.acc_metric2 = self.hparams.acc_computer()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch.

        Arguments
        ---------
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST
        stage_loss : float
            The average loss for all of the data processed in this stage.
        epoch : int
            The currently-starting epoch. This is passed
            `None` during the test stage.
        """
        stage_stats = {"loss": stage_loss}

        if stage == sb.Stage.TRAIN:
            # Keep the train loss until the validation stage logs it.
            self.train_loss = stage_loss
        else:
            stage_stats["ACC"] = self.acc_metric.summarize()
            stage_stats["error_rate"] = self.error_metrics.summarize("average")

        # Anneal learning rates, log, and checkpoint at the end of validation.
        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
            old_lr, new_lr = self.hparams.lr_annealing(stage_stats["error_rate"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            (
                old_lr_wav2vec2,
                new_lr_wav2vec2,
            ) = self.hparams.lr_annealing_wav2vec2(stage_stats["error_rate"])
            sb.nnet.schedulers.update_learning_rate(
                self.wav2vec2_optimizer, new_lr_wav2vec2
            )

            epoch_stats = {
                "epoch": epoch,
                "lr": old_lr,
                "wave2vec_lr": old_lr_wav2vec2,
                "steps": self.optimizer_step,
            }
            self.hparams.train_logger.log_stats(
                stats_meta=epoch_stats,
                train_stats={"loss": self.train_loss},
                valid_stats=stage_stats,
            )
            # Only the single best-ACC checkpoint is kept.
            self.checkpointer.save_and_keep_only(
                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )

        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )

    def init_optimizers(self):
        """Initializes the wav2vec2 optimizer and model optimizer."""
        self.wav2vec2_optimizer = self.hparams.wav2vec2_opt_class(
            self.modules.wav2vec2.parameters()
        )
        self.optimizer = self.hparams.opt_class(self.hparams.model.parameters())

        if self.checkpointer is not None:
            self.checkpointer.add_recoverable(
                "wav2vec2_opt", self.wav2vec2_optimizer
            )
            self.checkpointer.add_recoverable("optimizer", self.optimizer)

    def zero_grad(self, set_to_none=False):
        """Zero the gradients of both optimizers."""
        self.wav2vec2_optimizer.zero_grad(set_to_none)
        self.optimizer.zero_grad(set_to_none)

    def save_embeddings_and_labels(self, embeddings, labels, stage):
        """Append ``(label, embedding)`` pairs for one TEST batch to a pickle file.

        The file is ``<save_folder>/embeddings/embeddings_<stage>.pkl``, opened
        in append mode; each call pickles one list covering the whole batch.
        Does nothing for stages other than TEST.
        """
        if stage != sb.Stage.TEST:
            return

        embeddings_np = embeddings.detach().cpu().numpy()
        labels_np = labels.detach().cpu().numpy()

        save_dir = os.path.join(self.hparams.save_folder, "embeddings")
        os.makedirs(save_dir, exist_ok=True)
        filename = os.path.join(save_dir, f"embeddings_{stage}.pkl")

        # One (label, embedding) tuple per utterance in the batch.
        pairs = [
            (labels_np[i], embeddings_np[i]) for i in range(embeddings_np.shape[0])
        ]
        with open(filename, "ab") as f:
            pickle.dump(pairs, f)
def dataio_prep(hparams):
    """Prepare the datasets used by the Brain class and their data pipelines.

    We expect `common_accent_prepare` to have been called before this, so
    that the `train.csv`, `dev.csv`, and `test.csv` manifest files are
    available in ``hparams["csv_prepared_folder"]``.

    Arguments
    ---------
    hparams : dict
        This dictionary is loaded from the hparams YAML file, and it includes
        all the hyperparameters needed for dataset construction and loading.
        Side effect: ``hparams["train_dataloader_opts"]["shuffle"]`` is set
        to False when training data is sorted by duration.

    Returns
    -------
    tuple
        ``(train_data, valid_data, test_data, train_batch_sampler,
        valid_batch_sampler, accent_encoder)``. The two samplers are ``None``
        unless ``hparams["dynamic_batching"]`` is true.

    Raises
    ------
    NotImplementedError
        If ``hparams["sorting"]`` is not "random", "ascending" or "descending".
    """
    # 1. Define train/valid/test datasets.
    data_folder = hparams["csv_prepared_folder"]
    train_csv = os.path.join(data_folder, "train.csv")
    valid_csv = os.path.join(data_folder, "dev.csv")
    test_csv = os.path.join(data_folder, "test.csv")

    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=train_csv, replacements={"data_root": data_folder},
    )

    if hparams["sorting"] == "ascending":
        # Sort training data to speed up training and get better results.
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # Shuffling in the dataloader would undo the sorting.
        hparams["train_dataloader_opts"]["shuffle"] = False
    elif hparams["sorting"] == "descending":
        train_data = train_data.filtered_sorted(
            sort_key="duration",
            reverse=True,
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
        # Shuffling in the dataloader would undo the sorting.
        hparams["train_dataloader_opts"]["shuffle"] = False
    elif hparams["sorting"] == "random":
        train_data = train_data.filtered_sorted(
            key_max_value={"duration": hparams["avoid_if_longer_than"]},
        )
    else:
        raise NotImplementedError(
            "sorting must be random, ascending or descending"
        )

    # Validation and test data are sorted by duration so evaluation is faster.
    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=valid_csv, replacements={"data_root": data_folder},
    )
    valid_data = valid_data.filtered_sorted(sort_key="duration")

    test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=test_csv, replacements={"data_root": data_folder},
    )
    test_data = test_data.filtered_sorted(sort_key="duration")

    datasets = [train_data, valid_data, test_data]

    # The label encoder assigns each observed label a unique index
    # (e.g. 'accent01': 0, 'accent02': 1, ...).
    accent_encoder = sb.dataio.encoder.CategoricalEncoder()

    # 2. Audio pipeline. add_dynamic_item needs the takes/provides
    # declarations; a bare function is rejected.
    @sb.utils.data_pipeline.takes("wav", "duration", "offset")
    @sb.utils.data_pipeline.provides("sig")
    def audio_offset_pipeline(wav, duration, offset):
        """Load up to 10 s of `wav`, starting `offset` seconds in, at the target rate."""
        # CSV values arrive as strings; float() also accepts "0.0"-style
        # offsets, which int() would reject.
        sig, _ = librosa.load(
            wav, sr=hparams["sample_rate"], offset=float(offset), duration=10
        )
        return torch.tensor(sig)

    sb.dataio.dataset.add_dynamic_item(datasets, audio_offset_pipeline)

    # 3. Label pipeline: yields the raw label, then its encoded tensor.
    @sb.utils.data_pipeline.takes("accent")
    @sb.utils.data_pipeline.provides("accent", "accent_encoded")
    def label_pipeline(accent):
        yield accent
        yield accent_encoder.encode_label_torch(accent)

    sb.dataio.dataset.add_dynamic_item(datasets, label_pipeline)

    # 4. Set output keys.
    sb.dataio.dataset.set_output_keys(
        datasets, ["id", "sig", "accent_encoded"],
    )

    # Load or compute the label encoder (with multi-GPU DDP support).
    # See accent_encoder.txt for the label-to-index mapping. It is built
    # from the training set only.
    accent_encoder_file = os.path.join(hparams["save_folder"], "accent_encoder.txt")
    accent_encoder.load_or_create(
        path=accent_encoder_file,
        from_didatasets=[train_data],
        output_key="accent",
    )

    # 5. If dynamic batching is used, instantiate the needed samplers.
    train_batch_sampler = None
    valid_batch_sampler = None
    if hparams["dynamic_batching"]:
        from speechbrain.dataio.sampler import DynamicBatchSampler  # noqa

        dynamic_hparams = hparams["dynamic_batch_sampler"]
        num_buckets = dynamic_hparams["num_buckets"]

        train_batch_sampler = DynamicBatchSampler(
            train_data,
            dynamic_hparams["max_batch_len"],
            num_buckets=num_buckets,
            length_func=lambda x: x["duration"],
            shuffle=dynamic_hparams["shuffle_ex"],
            batch_ordering=dynamic_hparams["batch_ordering"],
        )
        valid_batch_sampler = DynamicBatchSampler(
            valid_data,
            dynamic_hparams["max_batch_len_val"],
            num_buckets=num_buckets,
            length_func=lambda x: x["duration"],
            shuffle=dynamic_hparams["shuffle_ex"],
            batch_ordering=dynamic_hparams["batch_ordering"],
        )

    return (
        train_data,
        valid_data,
        test_data,
        train_batch_sampler,
        valid_batch_sampler,
        accent_encoder,
    )
def get_pooling_layer(hparams):
    """Build the pooling layer selected by ``hparams["avg_pool_class"]``.

    Supported strategies:
      * ``"statpool"``     -> ``StatisticsPooling(return_std=False)``
      * ``"adaptivepool"`` -> ``AdaptivePool(output_size=1)``
      * ``"avgpool"``      -> ``Pooling1d(pool_type="avg", kernel_size=3)``

    The layer is stored in ``hparams["avg_pool"]`` (the dict is mutated in
    place) and the same dict is returned.

    Raises
    ------
    ValueError
        If the strategy name is not one of the three above.
    """
    strategy = hparams["avg_pool_class"]

    # Reject unknown names before importing anything.
    if strategy not in ("statpool", "adaptivepool", "avgpool"):
        raise ValueError("Pooling strategy must be in ['statpool', 'adaptivepool', 'avgpool']")

    if strategy == "adaptivepool":
        from speechbrain.nnet.pooling import AdaptivePool
        layer = AdaptivePool(output_size=1)
    elif strategy == "avgpool":
        from speechbrain.nnet.pooling import Pooling1d
        layer = Pooling1d(pool_type="avg", kernel_size=3)
    else:
        from speechbrain.nnet.pooling import StatisticsPooling
        layer = StatisticsPooling(return_std=False)

    hparams["avg_pool"] = layer
    return hparams
# Recipe begins!
if __name__ == "__main__":
    # Reading command line arguments.
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])

    # Initialize ddp (useful only for multi-GPU DDP training).
    sb.utils.distributed.ddp_init_group(run_opts)

    # Load hyperparameters file with command-line overrides.
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create experiment directory.
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Data preparation, to be run on only one process.
    sb.utils.distributed.run_on_main(
        prepare_common_accent,
        kwargs={
            "data_folder": hparams["data_folder"],
            "save_folder": hparams["save_folder"],
            "skip_prep": hparams["skip_prep"],
        },
    )

    # Define the pooling strategy based on the hparams file.
    hparams = get_pooling_layer(hparams)

    # Create dataset objects "train", "valid", and "test", train/valid
    # samplers, and the accent encoder.
    (
        train_data,
        valid_data,
        test_data,
        train_bsampler,
        valid_bsampler,
        accent_encoder,
    ) = dataio_prep(hparams)

    # Move wav2vec2 to the run device instead of a hard-coded "cuda:0"
    # (which breaks CPU runs and non-zero DDP ranks).
    hparams["wav2vec2"] = hparams["wav2vec2"].to(run_opts.get("device", "cuda:0"))

    # Freeze the convolutional feature extractor when the rest is fine-tuned.
    if not hparams["freeze_wav2vec2"] and hparams["freeze_wav2vec2_conv"]:
        hparams["wav2vec2"].model.feature_extractor._freeze_parameters()

    if hparams["load_pretrained"]:
        sb.utils.distributed.run_on_main(hparams["pretrainer"].collect_files)
        hparams["pretrainer"].load_collected(device=run_opts["device"])
        print("Pretrained model loaded")
    else:
        print("No pretrained model loaded")

    # Initialize the Brain object. AID's extra constructor arguments are
    # passed explicitly; gemeinde_df is not produced by this script.
    aid_brain = AID(
        gemeinde_df=None,
        accents_encoder=accent_encoder,
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # Dataloader options; dynamic batch samplers replace them when enabled.
    train_dataloader_opts = hparams["train_dataloader_opts"]
    valid_dataloader_opts = hparams["valid_dataloader_opts"]
    if train_bsampler is not None:
        train_dataloader_opts = {
            "batch_sampler": train_bsampler,
            "num_workers": hparams["num_workers"],
        }
    if valid_bsampler is not None:
        valid_dataloader_opts = {"batch_sampler": valid_bsampler}

    # The `fit()` method iterates the training loop, calling the methods
    # necessary to update the parameters of the model. Since all objects
    # with changing state are managed by the Checkpointer, training can be
    # stopped at any point, and will be resumed on next call.
    aid_brain.fit(
        aid_brain.hparams.epoch_counter,
        train_data,
        valid_data,
        train_loader_kwargs=train_dataloader_opts,
        valid_loader_kwargs=valid_dataloader_opts,
    )

    # Embeddings are appended to this file during TEST, so remove any file
    # left over from a previous run first.
    save_dir = hparams["save_folder"]
    stage = "Stage.TEST"
    filename = f"{save_dir}/embeddings/embeddings_{stage}.pkl"
    print(f"filename for saving test pickle file {filename}")
    if os.path.exists(filename):
        print("Removing embeddings pickle file")
        os.remove(filename)
    else:
        print("The test pickle file does not exist, writing to file: ", filename)

    # Load the best checkpoint for evaluation. Checkpoints are saved with
    # only "ACC" in their meta (max_keys=["ACC"]), so select on that key.
    test_stats = aid_brain.evaluate(
        test_set=test_data,
        max_key="ACC",
        test_loader_kwargs=hparams["test_dataloader_opts"],
    )

    # validate for valid data (to export embeddings)
    # valid_stats = aid_brain.evaluate(
    #     test_set=valid_data,
    #     max_key="ACC",
    #     test_loader_kwargs=hparams["valid_dataloader_opts"],
    # )

    # accents_lists_int=range(int(hparams["n_accents"]))
    # accents_list=[]
    # for a in accents_lists_int:
    #     accents_list.append(str(a))
    # print("Test for all accents")
    # print("accents_list: ", accents_list)
    # #get available accents in test_data using filtered_sorted
    # unique_accents = set([data['accent_encoded'].item() for data in test_data])
    # for acc in accents_list:
    #     if int(acc) in unique_accents:
    #         test_data_acc = test_data.filtered_sorted(key_test={"accent_encoded": lambda x: x.item() == int(acc)})
    #         print("test_data_acc: ", test_data_acc)
    #         # get length of test_data_acc
    #         print("len(test_data_acc): ", len(test_data_acc))
    #         print("Test for: "+acc)
    #         test_stats = aid_brain.evaluate(
    #             test_set=test_data_acc,
    #             max_key="ACC",
    #             test_loader_kwargs=hparams["test_dataloader_opts"],
    #         )