| import os |
|
|
def _install_fairseq(repo_url="https://github.com/pytorch/fairseq.git", target="fairseq"):
    """Clone fairseq if needed and pip-install it into the running interpreter.

    Args:
        repo_url: git URL of the fairseq repository.
        target: local directory to clone into / install from.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``pip install`` fails,
            so the app stops here instead of failing later on ``import fairseq``.
    """
    import subprocess
    import sys

    # A previous start may already have cloned the repo; cloning again into an
    # existing non-empty directory would fail.
    if not os.path.isdir(target):
        subprocess.run(["git", "clone", repo_url, target], check=True)
    # Use this interpreter's pip so fairseq lands in the environment running the
    # app. In-tree builds are pip's default nowadays, so the old
    # --use-feature=in-tree-build flag (rejected by newer pip) is not passed.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", os.path.join(".", target)],
        check=True,
    )


_install_fairseq()
|
|
| import torch |
| import numpy as np |
| import re |
| from fairseq import utils,tasks |
| from fairseq import checkpoint_utils |
| from fairseq import distributed_utils, options, tasks, utils |
| from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
| from utils.zero_shot_utils import zero_shot_step |
| from tasks.mm_tasks.vqa_gen import VqaGenTask |
| from models.ofa import OFAModel |
| from PIL import Image |
| from torchvision import transforms |
| import gradio as gr |
|
|
| |
# Register OFA's VQA generation task class with fairseq under "vqa_gen".
# NOTE(review): fairseq's register_task(name, dataclass=None) appears to be a
# decorator factory, so this call may not register anything on its own; the
# real registration presumably happens via the decorator in
# tasks/mm_tasks/vqa_gen.py when VqaGenTask is imported — confirm before
# relying on (or removing) this line.
tasks.register_task('vqa_gen', VqaGenTask)

# Provisional runtime flags; both are recomputed from the parsed config once
# the command-line options have been processed further down.
use_cuda, use_fp16 = torch.cuda.is_available(), False
|
|
def _download_checkpoint(
    url="https://www.dropbox.com/s/5al62v0pumbfch7/checkpoint_best_25_004_13_4_480.pt",
    dest="checkpoints/checkpoint_best_25_004_13_4_480.pt",
):
    """Download the fine-tuned VQA checkpoint to *dest* unless it already exists.

    The file is fetched to a temporary ``.part`` path and renamed only after
    wget exits successfully. An interrupted download therefore never leaves a
    truncated checkpoint that a later start would mistake for a complete one.

    Args:
        url: where to fetch the checkpoint from.
        dest: final path of the checkpoint (must match ``--path`` below).

    Raises:
        subprocess.CalledProcessError: if wget fails, so the app stops here
            rather than failing later with a confusing model-loading error.
    """
    import subprocess

    if os.path.isfile(dest):
        return  # already downloaded by a previous start
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    partial = dest + ".part"
    subprocess.run(["wget", "-O", partial, url], check=True)
    os.replace(partial, dest)


_download_checkpoint()
|
|
| |
# Parse fairseq generation options exactly as if they came from the command
# line.
# NOTE(review): the leading "" is presumably consumed as the positional data
# argument and left empty because no dataset is loaded by this script — confirm.
parser = options.get_generation_parser()
input_args = [
    "",
    "--task=vqa_gen",
    "--beam=100",
    "--unnormalized",
    "--path=checkpoints/checkpoint_best_25_004_13_4_480.pt",
    "--bpe-dir=utils/BPE",
    "--ans2label-file=dataset/trainval_ans2label.pkl",
]

args = options.parse_args_and_arch(parser, input_args)
# argparse Namespace -> OmegaConf, so later code can read cfg.<group>.<field>.
cfg = convert_namespace_to_omegaconf(args)
|
|
| |
|
|
|
|
import ast  # local import: only this setup step needs it

# Runtime flags derived from the parsed options.
use_fp16 = cfg.common.fp16
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
if use_cuda:
    torch.cuda.set_device(cfg.distributed_training.device_id)

# --model-overrides holds a Python dict literal in string form. ast.literal_eval
# parses it without executing arbitrary code, which eval() would do.
overrides = ast.literal_eval(cfg.common_eval.model_overrides)

if cfg.task._name == "vqa_gen":
    # Presumably selects "all candidates" answer scoring rather than free-form
    # beam decoding — see VqaGenTask for the exact semantics.
    overrides['val_inference_type'] = "allcand"

# Load the checkpoint(s) together with their task. No task is passed in, so the
# task comes from this call. A separate tasks.setup_task(cfg.task) beforehand
# would be overwritten here, so that redundant setup is not performed.
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    utils.split_paths(cfg.common_eval.path),
    arg_overrides=overrides,
    suffix=cfg.checkpoint.checkpoint_suffix,
    strict=(cfg.checkpoint.checkpoint_shard_count == 1),
    num_shards=cfg.checkpoint.checkpoint_shard_count,
)
|
|
|
|
| |
# Load each model's EMA weights, then put it into inference mode on the target
# precision/device. This happens in a single pass: every model is cast/moved
# once, and prepare_for_inference_ runs on the final weights. Previously there
# were two passes, and the EMA weights arrived only after the first preparation.
for model, ckpt_path in zip(models, utils.split_paths(cfg.common_eval.path)):
    model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
    model.eval()
    if use_fp16:
        model.half()
    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
        model.cuda()
    model.prepare_for_inference_(cfg)

# Build the generator over the fully prepared models.
generator = task.build_generator(models, cfg.generation)
|
|
|
|
|
|
|
|
|
|
|
|
| |
| from torchvision import transforms |
# Image preprocessing pipeline:
#   1. convert to RGB,
#   2. resize to a square of the task's patch size,
#   3. convert to a tensor,
#   4. normalise each channel with mean/std 0.5.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

patch_resize_transform = transforms.Compose([
    # Uploaded images may be RGBA / greyscale / palette; force 3 channels.
    lambda image: image.convert("RGB"),
    # Use torchvision's documented InterpolationMode enum rather than the PIL
    # integer constant. Some torchvision releases deprecate the integer form.
    transforms.Resize(
        (cfg.task.patch_image_size, cfg.task.patch_image_size),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
|
|
| |
# Special-symbol ids from the source dictionary, pre-wrapped as one-element
# LongTensors so encode_text can torch.cat them onto token sequences.
bos_item, eos_item = (
    torch.LongTensor([symbol_id])
    for symbol_id in (task.src_dict.bos(), task.src_dict.eos())
)
# Padding id; construct_sample uses it to count non-padding tokens.
pad_idx = task.src_dict.pad()
|
|
| |
def pre_question(question, max_ques_words):
    """Normalise a raw question string and cap its length in words.

    The text is lower-cased and leading punctuation (any of ``,.!?*#:;~``) is
    removed. Hyphens and slashes become spaces. Runs of two or more whitespace
    characters collapse to one space, and a trailing newline plus surrounding
    spaces are trimmed.

    Args:
        question: raw question text.
        max_ques_words: keep at most this many space-separated words.

    Returns:
        The cleaned (and possibly truncated) question.
    """
    cleaned = question.lower().lstrip(",.!?*#:;~")
    for separator in ('-', '/'):
        cleaned = cleaned.replace(separator, ' ')
    cleaned = re.sub(r"\s{2,}", ' ', cleaned).rstrip('\n').strip(' ')

    words = cleaned.split(' ')
    if len(words) <= max_ques_words:
        return cleaned
    return ' '.join(words[:max_ques_words])
|
|
def encode_text(text, length=None, append_bos=False, append_eos=False):
    """Turn raw text into a 1-D LongTensor of dictionary ids.

    The text is BPE-encoded with ``task.bpe`` and then mapped through
    ``task.tgt_dict``. Unknown symbols are not added to the dictionary, and the
    dictionary's own EOS is not appended. BOS/EOS are attached here on request
    instead.

    Args:
        text: input string.
        length: if given, keep at most this many ids. The cap is applied before
            BOS/EOS are added.
        append_bos: prepend ``bos_item``.
        append_eos: append ``eos_item``.

    Returns:
        torch.LongTensor of token ids.
    """
    ids = task.tgt_dict.encode_line(
        line=task.bpe.encode(text),
        add_if_not_exist=False,
        append_eos=False,
    ).long()
    if length is not None:
        ids = ids[:length]

    pieces = ([bos_item] if append_bos else []) + [ids] + ([eos_item] if append_eos else [])
    return ids if len(pieces) == 1 else torch.cat(pieces)
|
|
| |
def construct_sample(image: Image, question: str):
    """Package one (image, question) pair as a batch-of-one fairseq sample.

    Args:
        image: input image; it goes through ``patch_resize_transform``.
        question: raw question text; it is normalised with ``pre_question`` and
            ends with ``?``.

    Returns:
        dict with ``id``, ``net_input`` (``src_tokens``, ``src_lengths``,
        ``patch_images``, ``patch_masks``) and ``ref_dict`` entries, in the
        layout passed to ``zero_shot_step``.
    """
    # Image side: transformed tensor with a leading batch dimension of 1.
    patch_image = patch_resize_transform(image).unsqueeze(0)
    patch_mask = torch.tensor([True])

    # Text side: normalise, make sure it reads as a question, wrap in BOS/EOS.
    question = pre_question(question, task.cfg.max_src_length)
    if not question.endswith('?'):
        question = question + '?'
    src_text = encode_text(' {}'.format(question), append_bos=True, append_eos=True).unsqueeze(0)

    # Count non-padding tokens in each row (there is a single row here).
    src_length = src_text.ne(pad_idx).long().sum(dim=1)

    return {
        "id": np.array(['42']),
        "net_input": {
            "src_tokens": src_text,
            "src_lengths": src_length,
            "patch_images": patch_image,
            "patch_masks": patch_mask,
        },
        # Dummy reference answer. NOTE(review): presumably only read by scoring
        # code, not needed to produce a prediction — confirm.
        "ref_dict": np.array([{'yes': 1.0}]),
    }
|
|
| |
| |
def apply_half(t):
    """Cast a float32 tensor to float16; return any other tensor unchanged.

    Used with ``utils.apply_to_sample`` to convert a whole sample when fp16
    inference is enabled.
    """
    return t.to(dtype=torch.half) if t.dtype == torch.float32 else t
|
|
|
|
| |
def open_domain_vqa(Image, Question):
    """Answer a question about an image (Gradio callback).

    Args:
        Image: the image from the Gradio image input (``type='pil'``). Inside
            this function the name shadows the module-level PIL ``Image``
            import; it is kept for backward compatibility.
        Question: question text from the Gradio textbox.

    Returns:
        The ``'answer'`` field of the first result from ``zero_shot_step``, or
        a short prompt asking for the missing input.
    """
    # An unset image input presumably arrives as None (and an empty textbox as
    # "" or None). A None image would otherwise crash in the image transform
    # (image.convert), so give the user a message instead.
    if Image is None:
        return "Please upload an image."
    if not Question or not Question.strip():
        return "Please enter a question."

    sample = construct_sample(Image, Question)
    if use_cuda:
        sample = utils.move_to_cuda(sample)
    if use_fp16:
        sample = utils.apply_to_sample(apply_half, sample)

    with torch.no_grad():
        result, _scores = zero_shot_step(task, generator, models, sample)
    return result[0]['answer']
|
|
|
|
# ---- Gradio UI ---------------------------------------------------------------
title = "TimeMachine-Visual_Question_Answering"
description = (
    "TimeMachine-Visual_Question_Answering. Upload your own pair of image (a pair of images for comparison) or click any one of the examples, and click "
    "\"Submit\" and then wait for OFA's answer. "
)
article = (
    "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github "
    "Repo</a></p> "
)

# Each example is an [image path, question] pair, matching the two inputs below.
examples = [
    [image_file, prompt]
    for image_file, prompt in (
        ('test5.jpg', "Which side of the two images has building under construction?"),
        ('test2.jpg', "Which side of the two images has building under construction?"),
        ('test.jpg', "Which side of the two images has better pedestrian crossing?"),
        ('test4.jpg', "Which side of the two images has more building?"),
    )
]

io = gr.Interface(
    fn=open_domain_vqa,
    inputs=[gr.inputs.Image(type='pil'), "textbox"],
    outputs=gr.outputs.Textbox(label="Answer"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging=False,
    allow_screenshot=False,
)
# cache_examples=True asks Gradio to cache the outputs for the examples above.
io.launch(cache_examples=True)