"""Recipe for training a language model (LM) with SpeechBrain.

Hyperparameters and the model are defined in a HyperPyYAML file passed on the
command line, e.g. ``python train.py <hparams.yaml>`` (the exact YAML file
depends on your setup).
"""
import sys

import torch
import speechbrain as sb
from speechbrain.dataio import dataset
from speechbrain.utils.distributed import run_on_main
from hyperpyyaml import load_hyperpyyaml
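
# NOTE: the HyperPyYAML file is expected to define (at least) the entries this
# script reads from `hparams`: tokenizer, bos_index, eos_index, data_folder,
# output_folder, modules, model, log_softmax, compute_cost, optimizer,
# accumulation_steps, lr_annealing, train_logger, checkpointer, epoch_counter,
# pretrainer, and the train/valid/test dataloader options. (Names taken from
# the accesses below; the YAML itself is not shown in this file.)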


class LM(sb.core.Brain):
    def compute_forward(self, batch, stage):
        """Forward pass: predicts the next-token distribution from the
        BOS-prefixed token sequence."""
        batch = batch.to(self.device)
        tokens_bos, _ = batch.tokens_bos
        logits = self.hparams.model(tokens_bos)
        pred = self.hparams.log_softmax(logits)
        return pred

    def compute_objectives(self, predictions, batch, stage):
        """Computes the LM loss between the predictions and the EOS-appended
        targets."""
        batch = batch.to(self.device)
        tokens_eos, tokens_len = batch.tokens_eos
        loss = self.hparams.compute_cost(
            predictions, tokens_eos, length=tokens_len
        )
        return loss

    def fit_batch(self, batch):
        """Trains on one batch, accumulating gradients over
        ``accumulation_steps`` sub-batches before each optimizer update."""
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)

        # Scale the loss so that the accumulated gradients match one update
        # computed on the full effective batch.
        (loss / self.hparams.accumulation_steps).backward()

        if self.step % self.hparams.accumulation_steps == 0:
            # Gradient clipping / non-finite gradient check.
            self.check_gradients(loss)

            self.optimizer.step()
            self.optimizer.zero_grad()

            # Noam and cyclic-cosine schedules are stepped once per update;
            # other schedulers are handled per epoch in on_stage_end.
            if isinstance(
                self.hparams.lr_annealing,
                (
                    sb.nnet.schedulers.NoamScheduler,
                    sb.nnet.schedulers.CyclicCosineScheduler,
                ),
            ):
                self.hparams.lr_annealing(self.optimizer)

        return loss

    def on_stage_end(self, stage, stage_loss, epoch):
        """Called at the end of each stage: logs statistics, anneals the
        learning rate (on VALID), and saves checkpoints."""
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        if stage == sb.Stage.VALID and sb.utils.distributed.if_main_process():
            # Per-epoch schedulers are updated here; Noam and cyclic-cosine
            # schedulers are already updated per step in fit_batch.
            if not isinstance(
                self.hparams.lr_annealing,
                (
                    sb.nnet.schedulers.NoamScheduler,
                    sb.nnet.schedulers.CyclicCosineScheduler,
                ),
            ):
                old_lr, new_lr = self.hparams.lr_annealing(stage_loss)
                sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            else:
                old_lr = self.hparams.lr_annealing.current_lr

            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": old_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            # Keep only the best (lowest-loss) checkpoints.
            self.checkpointer.save_and_keep_only(
                meta=stage_stats, min_keys=["loss"],
            )

        if stage == sb.Stage.TEST and sb.utils.distributed.if_main_process():
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
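

# Assumed data format (not shown in this file): each manifest read below
# (train.json / dev.json / test.json) maps an utterance id to an entry with at
# least a "transcription" field, e.g. (hypothetical example):
#     {"utt_0001": {"transcription": "hello world"}}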


def dataio_prepare(hparams):
    """Creates the train/dev/test DynamicItemDatasets and defines the
    text-to-token pipeline used to produce the LM inputs and targets."""

    @sb.utils.data_pipeline.takes("transcription")
    @sb.utils.data_pipeline.provides(
        "transcription", "tokens_bos", "tokens_eos"
    )
    def transcription_pipeline(transcription):
        yield transcription
        tokens_list = hparams["tokenizer"].encode_as_ids(transcription)
        # Prepend BOS for the model inputs, append EOS for the targets.
        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
        yield tokens_bos
        tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
        yield tokens_eos

    data_folder = hparams["data_folder"]
    datasets = {}
    for dataset_name in ["train", "dev", "test"]:
        json_path = f"{data_folder}/{dataset_name}.json"
        datasets[dataset_name] = dataset.DynamicItemDataset.from_json(
            json_path=json_path,
            replacements={"data_root": data_folder},
            dynamic_items=[transcription_pipeline],
            output_keys=["transcription", "tokens_bos", "tokens_eos"],
        )

    return datasets
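

# Training entry point: parse the command line, initialize DDP, create the
# experiment directory, fetch any pretrained components declared by the
# pretrainer (e.g. the tokenizer), prepare the datasets, then train and
# evaluate the LM.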


if __name__ == "__main__":
    # Parse the hyperparameter file and command-line overrides.
    hparams_file_path, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file_path) as hparams_file:
        hparams = load_hyperpyyaml(hparams_file, overrides)

    # Initialize the process group for distributed (DDP) training, if used.
    sb.utils.distributed.ddp_init_group(run_opts)

    # Create the output folder and save the (overridden) hyperparameters.
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file_path,
        overrides=overrides,
    )

    # Collect pretrained files on the main process only, then load them on
    # every process.
    run_on_main(hparams["pretrainer"].collect_files)
    hparams["pretrainer"].load_collected(device=run_opts["device"])

    datasets = dataio_prepare(hparams)

    lm_brain = LM(
        modules=hparams["modules"],
        opt_class=hparams["optimizer"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    lm_brain.fit(
        lm_brain.hparams.epoch_counter,
        datasets["train"],
        datasets["dev"],
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )

    # Evaluate on the test set using the checkpoint with the lowest loss.
    lm_brain.evaluate(
        datasets["test"],
        min_key="loss",
        test_loader_kwargs=hparams["test_dataloader_opts"],
    )