Spaces:
Runtime error
Runtime error
| import torch | |
| from swapae.evaluation.base_evaluator import BaseEvaluator | |
| import swapae.util as util | |
def find_evaluator_using_name(filename):
    """Look up the evaluator class that lives in ``swapae.evaluation.<filename>``.

    ``filename`` is used twice: as the module's base name and as the class
    name handed to ``util.find_class_in_module`` (how that helper matches the
    name against the module's classes is defined in ``swapae.util``).

    Returns:
        The class object returned by ``util.find_class_in_module``.

    Raises:
        AssertionError: if the resolved class is not a subclass of
            ``BaseEvaluator``.
    """
    qualified_module = 'swapae.evaluation.' + filename
    evaluator_cls = util.find_class_in_module(filename, qualified_module)
    # Every evaluator must share the BaseEvaluator interface.
    assert issubclass(evaluator_cls, BaseEvaluator), \
        "Class %s should be a subclass of BaseEvaluator" % evaluator_cls
    return evaluator_cls
def find_evaluator_classes(opt):
    """Resolve the evaluator classes named in ``opt.evaluation_metrics``.

    ``opt.evaluation_metrics`` is a comma-separated string. An entry that
    starts with ``train`` or ``test`` has that prefix stripped and is paired
    with the matching phase; an entry with neither prefix is paired with
    ``"test"``. The remaining name ``X`` is resolved through
    ``find_evaluator_using_name("X_evaluator")``.

    Returns:
        A pair ``(all_classes, target_phases)`` of equal-length lists, where
        ``target_phases[i]`` is the phase string for ``all_classes[i]``.
        Both lists are empty when no metrics are configured.
    """
    all_classes = []
    target_phases = []

    if not opt.evaluation_metrics:
        # Keep the same two-list shape as the non-empty path: callers unpack
        # the result into two names, so a bare [] would raise ValueError.
        return all_classes, target_phases

    for metric in opt.evaluation_metrics.split(","):
        if metric.startswith("train"):
            phase = "train"
            metric = metric[len("train"):]
        elif metric.startswith("test"):
            phase = "test"
            metric = metric[len("test"):]
        else:
            # No explicit prefix: the metric is paired with the test phase.
            phase = "test"
        all_classes.append(find_evaluator_using_name("%s_evaluator" % metric))
        target_phases.append(phase)

    return all_classes, target_phases
class GroupEvaluator(BaseEvaluator):
    """Evaluator that runs every evaluator named in ``--evaluation_metrics``.

    The sub-evaluators are built once in ``__init__``; ``evaluate`` runs them
    in order and merges the metric dicts they return into a single dict.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register ``--evaluation_metrics`` plus each selected evaluator's options.

        Declared as a static method because it takes no ``self``; it can be
        called on the class or on an instance.

        Args:
            parser: argparse-style parser to extend.
            is_train: forwarded unchanged to each evaluator's own
                ``modify_commandline_options``.

        Returns:
            The parser as returned by the last evaluator hook (or the input
            parser when no evaluator is selected).
        """
        parser.add_argument("--evaluation_metrics", default="structure_style_grid_generation")
        # parse_known_args tolerates arguments that are not registered yet, so
        # the metric list can be read before the other options are added.
        opt, _ = parser.parse_known_args()
        evaluator_classes, _ = find_evaluator_classes(opt)
        for eval_class in evaluator_classes:
            parser = eval_class.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt, target_phase=None):
        """Instantiate one sub-evaluator per configured metric.

        Args:
            opt: options object; ``opt.evaluation_metrics`` selects the
                evaluators.
            target_phase: accepted for signature compatibility.
        """
        # NOTE(review): the base class is always given target_phase=None and
        # the ``target_phase`` argument is ignored - confirm this is intended.
        super().__init__(opt, target_phase=None)
        self.opt = opt
        evaluator_classes, target_phases = find_evaluator_classes(opt)
        self.evaluators = [cls(opt, target_phase=phs) for cls, phs in zip(evaluator_classes, target_phases)]

    def evaluate(self, model, dataset, nsteps=None):
        """Run all sub-evaluators and return their merged metrics.

        Before each evaluator runs, the dataset is switched to that
        evaluator's ``target_phase``. The phase the dataset had on entry is
        restored before returning, including when an evaluator raises.

        Args:
            model: forwarded to each evaluator's ``evaluate``.
            dataset: object exposing ``phase`` and ``set_phase(phase)``.
            nsteps: forwarded to each evaluator's ``evaluate``.

        Returns:
            A dict built by ``dict.update`` over each evaluator's result, in
            evaluator order; on key collisions the later evaluator wins.
        """
        original_phase = dataset.phase
        metrics = {}
        try:
            for evaluator in self.evaluators:
                print("Entering evaluation using %s on %s images" % (type(evaluator).__name__, evaluator.target_phase))
                dataset.set_phase(evaluator.target_phase)
                # Gradient tracking is disabled while the evaluator runs.
                with torch.no_grad():
                    new_metrics = evaluator.evaluate(model, dataset, nsteps)
                    metrics.update(new_metrics)
                print("Finished evaluation of %s" % type(evaluator).__name__)
        finally:
            # Always hand the dataset back in the phase the caller left it in.
            dataset.set_phase(original_phase)
        return metrics