| import logging |
| from argparse import ArgumentParser |
|
|
| import torch |
| import torchaudio |
| from datamodule.data_module import DataModule |
| from pytorch_lightning import Trainer |
|
|
|
|
| |
# Import-time default: configure root logging at WARNING level.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers, unless it is called with force=True. After this line runs the root
# logger has a handler, so any later basicConfig call without force=True
# (e.g. in init_logger) will not change the level or format.
logging.basicConfig(level=logging.WARNING)
|
|
|
|
def get_trainer(args):
    """Build the Lightning ``Trainer`` used to run evaluation.

    Args:
        args: Parsed CLI namespace. It is currently not read; the trainer is
            always configured for one node with a single GPU device.

    Returns:
        A ``Trainer`` instance.
    """
    trainer_options = {
        "num_nodes": 1,
        "devices": 1,
        "accelerator": "gpu",
    }
    return Trainer(**trainer_options)
|
|
|
|
def get_lightning_module(args):
    """Create the project's ``ModelModule`` from the parsed CLI arguments.

    Args:
        args: Parsed CLI namespace, passed through to ``ModelModule`` unchanged.

    Returns:
        The constructed ``ModelModule`` instance.
    """
    # Imported lazily: ``ModelModule`` is only loaded when this function runs.
    # NOTE(review): ``lightning`` here is presumably a project-local module;
    # if so, it shares a name with the ``lightning`` PyPI package — confirm
    # which one is resolved on sys.path.
    from lightning import ModelModule

    return ModelModule(args)
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for model evaluation.

    Args:
        argv: Optional list of argument strings to parse. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` (the original
            behaviour).

    Returns:
        ``argparse.Namespace`` with attributes ``modality``, ``root_dir``,
        ``test_file``, ``pretrained_model_path``, ``decode_snr_target`` and
        ``debug``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            ``--modality`` is not one of the allowed choices.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--modality",
        type=str,
        help="Type of input modality",
        required=True,
        choices=["audio", "video"],
    )
    parser.add_argument(
        "--root-dir",
        type=str,
        help="Root directory of preprocessed dataset",
        required=True,
    )
    # Deliberately not ``required``: combining ``required=True`` with a default
    # made the default (and the "(Default: ...)" help text) unreachable.
    parser.add_argument(
        "--test-file",
        default="lrs3_test_transcript_lengths_seg16s.csv",
        type=str,
        help="Filename of testing label list. (Default: lrs3_test_transcript_lengths_seg16s.csv)",
    )
    parser.add_argument(
        "--pretrained-model-path",
        type=str,
        help="Path to the pre-trained model",
        required=True,
    )
    # The default stays the int 999999: argparse applies ``type`` only to
    # string defaults. NOTE(review): 999999 looks like a sentinel for
    # "no added noise" — confirm against how consumers read decode_snr_target.
    parser.add_argument(
        "--decode-snr-target",
        type=float,
        default=999999,
        help="Level of signal-to-noise ratio (SNR)",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Flag to use debug level for logging",
    )
    return parser.parse_args(argv)
|
|
|
|
def init_logger(debug):
    """Configure root logging for this run.

    Args:
        debug: If true, log at DEBUG level with a timestamp prefix; otherwise
            log at INFO level with the bare message.
    """
    fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
    level = logging.DEBUG if debug else logging.INFO
    # force=True (Python 3.8+): basicConfig is a no-op when the root logger
    # already has handlers, which the import-time basicConfig(WARNING) call
    # ensures. Without force this function changed neither level nor format,
    # so --debug had no effect.
    logging.basicConfig(
        format=fmt,
        level=level,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
|
|
|
|
def cli_main():
    """Run evaluation end to end: parse flags, set up logging, then test.

    The model module, data module and trainer are built in that order, the
    same order as before, in case their constructors have side effects.
    """
    options = parse_args()
    init_logger(options.debug)

    model = get_lightning_module(options)
    data = DataModule(options)

    get_trainer(options).test(model=model, datamodule=data)
|
|
|
|
# Script entry point: run evaluation only when executed directly, not on import.
if __name__ == "__main__":
    cli_main()
|
|