| | |
| | |
| | |
| | |
| | import os |
| | import glob |
| | import argparse |
| | import pprint |
| | import omegaconf |
| |
|
| | from omegaconf import OmegaConf |
| | from torch.utils.data import DataLoader |
| |
|
| | from mmpt.utils import load_config, set_seed |
| | from mmpt.evaluators import Evaluator |
| | from mmpt.evaluators import predictor as predictor_path |
| | from mmpt.tasks import Task |
| | from mmpt import processors |
| | from mmpt.datasets import MMDataset |
| |
|
| |
|
def get_dataloader(config, num_workers=6):
    """Build the non-shuffled evaluation DataLoader described by ``config``.

    The meta/video/text processor and aligner classes are looked up by name on
    the ``processors`` module and each is instantiated with ``config.dataset``;
    the four instances are wrapped in an ``MMDataset``.  The dataset length and
    (when non-empty) its first example are printed as a sanity check.

    Args:
        config: Loaded task config.  Reads ``config.dataset.meta_processor``,
            ``config.dataset.video_processor``, ``config.dataset.text_processor``
            and ``config.dataset.aligner`` (class names), passes
            ``config.dataset`` to each constructor, and uses
            ``config.fairseq.dataset.batch_size`` as the batch size.
        num_workers: Number of DataLoader worker processes.  Defaults to 6,
            the value that used to be hard-coded.

    Returns:
        A ``DataLoader`` over the dataset with ``shuffle=False`` that batches
        with the dataset's own ``collater``.
    """
    dataset_cfg = config.dataset

    # Resolve every class first so a misspelled name fails before any
    # (potentially expensive) processor is constructed.
    meta_processor_cls = getattr(processors, dataset_cfg.meta_processor)
    video_processor_cls = getattr(processors, dataset_cfg.video_processor)
    text_processor_cls = getattr(processors, dataset_cfg.text_processor)
    aligner_cls = getattr(processors, dataset_cfg.aligner)

    meta_processor = meta_processor_cls(dataset_cfg)
    video_processor = video_processor_cls(dataset_cfg)
    text_processor = text_processor_cls(dataset_cfg)
    aligner = aligner_cls(dataset_cfg)

    test_data = MMDataset(
        meta_processor,
        video_processor,
        text_processor,
        aligner,
    )
    num_examples = len(test_data)
    print("test_len", num_examples)
    if num_examples > 0:
        # Show one fully processed sample so pipeline misconfiguration is
        # visible before the (long) evaluation starts.
        test_data.print_example(test_data[0])
    else:
        # Indexing [0] here used to raise IndexError; just report it instead.
        print("warning: test dataset is empty; no example to print.")

    test_dataloader = DataLoader(
        test_data,
        batch_size=config.fairseq.dataset.batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=test_data.collater,
    )
    return test_dataloader
| |
|
| |
|
def main(args):
    """Evaluate checkpoints (``test*`` configs) or run a predictor (``vis*`` configs).

    The mode is selected from the basename prefix of ``args.taskconfig``:

    * ``test*``: unless ``config.fairseq.common_eval.path`` contains "best",
      every ``checkpoint*`` file in the directory of ``config.eval.save_path``
      is evaluated; the checkpoint at ``common_eval.path`` is always evaluated
      last.  The metrics of every evaluated checkpoint are printed, followed
      by the one with the highest ``evaluator.metric.best_metric`` score.
    * ``vis*``: loads ``common_eval.path`` and runs the predictor class named
      by ``config.predictor`` over the test loader.

    Args:
        args: Parsed CLI namespace.  ``args.taskconfig`` is read here; the
            whole namespace is handed to ``load_config``.

    Raises:
        ValueError: if the config file name starts with neither "test" nor "vis".
    """
    config = load_config(args)

    if isinstance(config, omegaconf.dictconfig.DictConfig):
        print(OmegaConf.to_yaml(config))
    else:
        # PrettyPrinter's method is ``pprint()``; the former ``pp.print(config)``
        # raised AttributeError whenever this branch was taken.
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(config)

    mmtask = Task.config_task(config)
    mmtask.build_model()

    test_dataloader = get_dataloader(config)
    checkpoint_search_path = os.path.dirname(config.eval.save_path)
    results = []

    prefix = os.path.basename(args.taskconfig)
    if prefix.startswith("test"):
        # Loop over all checkpoints (for datasets without a validation set).
        if "best" not in config.fairseq.common_eval.path:
            print("eval each epoch.")
            # os.path.join keeps an empty dirname relative to the CWD (string
            # concatenation produced "/checkpoint*", i.e. the filesystem
            # root); glob order is arbitrary, so sort for reproducibility.
            pattern = os.path.join(checkpoint_search_path, "checkpoint*")
            for checkpoint in sorted(glob.glob(pattern)):
                model = mmtask.load_checkpoint(checkpoint)
                ckpt = os.path.basename(checkpoint)
                evaluator = Evaluator(config)
                output = evaluator.evaluate(
                    model, test_dataloader, ckpt + "_merged")
                results.append((checkpoint, output))
        # Lastly, evaluate the checkpoint specified by the config.
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        evaluator = Evaluator(config)
        output = evaluator.evaluate(model, test_dataloader)
        results.append((config.fairseq.common_eval.path, output))

        # Start at -inf rather than 0.0 so non-positive scores can still win:
        # with a 0.0 floor, best_result stayed None (and the indexing below
        # crashed) whenever every score was <= 0.  Ties keep the earliest.
        best_result = None
        best_metric = float("-inf")
        for checkpoint, result in results:
            print(checkpoint)
            evaluator.metric.print_computed_metrics(result)
            best_score = evaluator.metric.best_metric(result)
            if best_score > best_metric:
                best_result = (checkpoint, result)
                best_metric = best_score
        print("best results:")
        print(best_result[0])
        evaluator.metric.print_computed_metrics(best_result[1])

    elif prefix.startswith("vis"):
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        predictor_cls = getattr(predictor_path, config.predictor)
        predictor = predictor_cls(config)
        predictor.predict_loop(model, test_dataloader, mmtask, None)
    else:
        raise ValueError("unknown prefix of the config file", args.taskconfig)
| |
|
| |
|
if __name__ == "__main__":
    # The CLI takes a single positional argument: the task config path.
    # Its basename prefix ("test" / "vis") selects the mode inside main().
    cli = argparse.ArgumentParser()
    cli.add_argument("taskconfig", type=str)
    cli_args = cli.parse_args()
    main(cli_args)
| |
|