| | import argparse |
| | import logging |
| | import os |
| | import glob |
| | import tqdm |
| | import torch |
| | import PIL |
| | import cv2 |
| | import numpy as np |
| | import torch.nn.functional as F |
| | from torchvision import transforms |
| | from utils import Config, Logger, CharsetMapper |
| |
|
def get_model(config):
    """Build the model class named by ``config.model_name`` and put it in eval mode.

    ``config.model_name`` is a dotted path such as ``package.module.ClassName``.
    Everything before the last dot is imported as a module, the last component
    is looked up on it, and the class is instantiated with ``config`` as its
    only argument. The instance is logged and returned after ``.eval()``.
    """
    import importlib

    # rpartition splits on the LAST dot: ('package.module', '.', 'ClassName').
    module_path, _, cls_name = config.model_name.rpartition('.')
    model_cls = getattr(importlib.import_module(module_path), cls_name)
    net = model_cls(config)
    logging.info(net)
    return net.eval()
| |
|
def preprocess(img, width, height):
    """Resize ``img`` and turn it into a normalized, batched float tensor.

    The image is converted to a NumPy array and resized with OpenCV to
    ``(width, height)``. ``ToTensor`` then yields a C x H x W tensor, a leading
    batch axis of size 1 is added, and each channel is normalized with the
    fixed mean/std values below (the commonly used ImageNet RGB statistics).
    """
    resized = cv2.resize(np.array(img), (width, height))
    batch = transforms.ToTensor()(resized)[None]
    # Reshape the per-channel statistics to (C, 1, 1) so they broadcast over H and W.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch - channel_mean) / channel_std
| |
|
def postprocess(output, charset, model_eval):
    """Greedily decode the recognition output of one model branch.

    Args:
        output: either a single result dict, or a list/tuple of result dicts
            that each carry a ``'name'`` key. The selected dict must contain
            ``'logits'``.
        charset: object providing ``get_text``, ``null_char`` and
            ``max_length`` (a ``CharsetMapper`` in this script).
        model_eval: name of the branch to select when ``output`` is a
            list/tuple. It is ignored for a single dict.

    Returns:
        ``(pt_text, pt_scores, pt_lengths)``, three lists with one entry per
        batch item:
        - the decoded string, cut at the first null character;
        - the per-step maximum softmax probability;
        - ``min(len(text) + 1, charset.max_length)``.

    Raises:
        ValueError: if ``output`` is a list/tuple and no entry is named
            ``model_eval``.
    """
    def _get_output(last_output, model_eval):
        """Pick the result dict for ``model_eval`` (last match wins)."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        matches = [res for res in last_output if res['name'] == model_eval]
        if not matches:
            # Fail with a descriptive error instead of an unbound local.
            available = [res['name'] for res in last_output]
            raise ValueError(
                f'no model output named {model_eval!r}; available: {available}')
        return matches[-1]

    def _decode(logit):
        """Greedy decode: take the argmax class at every step."""
        # NOTE(review): assumes logit is (batch, steps, classes), given the
        # softmax over dim=2 and the argmax over dim=1 per batch item.
        out = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for o in out:
            text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
            # Everything after the first null character is discarded.
            text = text.split(charset.null_char)[0]
            pt_text.append(text)
            pt_scores.append(o.max(dim=1)[0])
            # "+1" presumably counts the terminating null/end token; clamped to max_length.
            pt_lengths.append(min(len(text) + 1, charset.max_length))
        return pt_text, pt_scores, pt_lengths

    selected = _get_output(output, model_eval)
    return _decode(selected['logits'])
| |
|
def load(model, file, device=None, strict=True):
    """Load the weights stored in ``file`` into ``model`` and return ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        file: checkpoint path. The file holds either a bare state dict, or a
            dict with exactly the keys ``'model'`` and ``'opt'``, in which case
            only the ``'model'`` entry is loaded.
        device: ``map_location`` for ``torch.load``. ``None`` means CPU, an
            ``int`` means that CUDA device index, and any other value (e.g.
            ``'cuda:0'``) is passed through unchanged.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The same ``model`` object, with the weights loaded.

    Raises:
        FileNotFoundError: if ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if not os.path.isfile(file):
        raise FileNotFoundError(f'checkpoint file not found: {file}')
    # NOTE(security): torch.load may unpickle arbitrary objects (depending on
    # the torch version's ``weights_only`` default). Only load trusted checkpoints.
    state = torch.load(file, map_location=device)
    if set(state.keys()) == {'model', 'opt'}:
        # Training checkpoint that bundles model and optimizer state.
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
| |
|
def main():
    """Run text recognition on an image file, a directory of images, or a glob.

    Builds the model described by ``--config`` and loads ``--checkpoint`` into
    it. For every input image, the decoded text of the ``--model_eval`` branch
    is logged.
    """
    # ``import PIL`` alone does not load the ``PIL.Image`` submodule. Import it
    # explicitly instead of relying on another library having done so.
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train_iternet.yaml',
                        help='path to config file')
    parser.add_argument('--input', type=str, default='figures/demo')
    parser.add_argument('--cuda', type=int, default=-1)
    parser.add_argument('--checkpoint', type=str, default='workdir/train-iternet/best-train-iternet.pth')
    parser.add_argument('--model_eval', type=str, default='alignment',
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()
    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    # Presumably stops the model from loading separate vision/language
    # checkpoints. All weights come from model_checkpoint, loaded below.
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    if os.path.isdir(args.input):
        paths = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
    else:
        paths = glob.glob(os.path.expanduser(args.input))
    # Skip sub-directories and other non-files; Image.open would fail on them.
    paths = sorted(p for p in paths if os.path.isfile(p))
    if not paths:
        # Explicit raise: an ``assert`` would be stripped under ``python -O``.
        raise FileNotFoundError("The input path(s) was not found")

    # Inference only: disable autograd so no computation graph is kept.
    with torch.no_grad():
        for path in tqdm.tqdm(paths):
            # Context manager closes the underlying file once the RGB copy exists.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
            res = model(img.to(device))
            pt_text, _, _ = postprocess(res, charset, config.model_eval)
            logging.info(f'{path}: {pt_text[0]}')
| |
|
if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not on import.
    main()
| |
|