| import torch |
| import torch.nn as nn |
| import cv2 |
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| import transformers |
| from transformers import RobertaModel, RobertaTokenizer |
| import timm |
| import pandas as pd |
| import matplotlib.pyplot as plt |
| from timm.data import resolve_data_config |
| from timm.data.transforms_factory import create_transform |
|
|
| from model import Model |
| from output import visualize_output |
|
|
|
|
| |
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Pretrained ViT backbone created without a classification head
# (num_classes=0) and with global pooling disabled (global_pool='').
vit = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=0,
    global_pool='',
)
vit = vit.to(device)

# Pretrained RoBERTa tokenizer and encoder for the textual query.
tokenizer = RobertaTokenizer.from_pretrained(
    'roberta-base',
    truncation=True,
    do_lower_case=True,
)
roberta = RobertaModel.from_pretrained("roberta-base")

# Wrap both backbones in the project's Model and switch it to eval mode.
model = Model(vit, roberta, tokenizer, device)
model = model.to(device)
model.eval()

# Restore the trained weights; the checkpoint is deserialized onto the CPU and
# load_state_dict copies the tensors into the model's existing parameters.
state = torch.load('saved_model', map_location=torch.device('cpu'))
model.load_state_dict(state['val_model_dict'])

# Base preprocessing config resolved from the ViT, with augmentation disabled
# and bilinear interpolation forced.
config = resolve_data_config({}, model=vit)
config.update(no_aug=True, interpolation='bilinear')
|
|
| |
def query_image(input_img, query, binarize, eval_threshold, crop_mode, crop_pct):
    """Run the text-conditioned detector on one image and return the visualization.

    Args:
        input_img: Image data accepted by ``PIL.Image.fromarray`` in "RGB" mode
            (presumably the H x W x 3 uint8 array produced by the Gradio image
            input -- confirm against the Gradio version in use).
        query: Textual query forwarded to the model together with the image.
        binarize: Forwarded unchanged to ``visualize_output``.
        eval_threshold: Forwarded unchanged to ``visualize_output``.
        crop_mode: Crop mode for the timm eval transform. ``'center'`` is
            translated to ``None`` before it is handed to ``create_transform``.
        crop_pct: Crop percentage for the timm eval transform.

    Returns:
        Whatever ``visualize_output`` returns for the transformed image and the
        model output.

    Raises:
        ValueError: If no image was supplied.
    """
    if input_img is None:
        # Fail with a clear message instead of an opaque error from inside PIL.
        raise ValueError("query_image requires an input image, got None")

    # 'center' is passed on as None, i.e. the transform factory's default crop.
    if crop_mode == 'center':
        crop_mode = None

    # Build a per-call copy of the base config rather than writing into the
    # shared module-level dict, so the crop settings of one request can never
    # leak into (or be overwritten by) another request.
    transform_config = dict(config, crop_pct=crop_pct, crop_mode=crop_mode)
    transform = create_transform(**transform_config)

    pil_image = Image.fromarray(input_img, "RGB")
    # Add a leading batch dimension and move the tensor to the model's device.
    img = torch.unsqueeze(transform(pil_image), 0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(img, query)

    return visualize_output(img, output, binarize, eval_threshold)
|
|
| |
# User-facing text shown by the Gradio interface (passed to gr.Interface as
# `description=`). It contains HTML markup; the "column", "left" and "right"
# classes used here are styled by the `css` argument of the interface. The
# "\n\n" sequences are real escape sequences (two newlines each), because this
# is not a raw string.
description = """
Gradio demo for an object detection architecture, introduced in my bachelor thesis (link will be added).
\n\n
You can use this architecture to detect objects using textual queries. To use it, simply upload an image and enter any query you want.
The model is trained to recognize only 80 categories (classes) from the COCO Detection 2017 dataset.
Refer to <a href="https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/">this</a> website
or the original <a href="https://arxiv.org/pdf/1405.0312.pdf">COCO</a> paper to see the full list of categories.
\n\n
Best results are obtained using one of these sentences, which were used during training:
<div class="row">
<div class="column left">
<ul>
<li>Find a {class}.</li>
<li>Find me a {class}</li>
<li>Where is the {class}?</li>
<li>Mark a {class}?</li>
<li>Can you mark a {class}?</li>
<li>Could you mark a {class}?</li>
<li>Detect a {class}.</li>
</ul>
</div>
<div class="column right">
<ul>
<li>Could you detect a {class}?</li>
<li>Where is the {class} located?</li>
<li>Where is the {class} positioned?</li>
<li>Is there a {class}?</li>
<li>Look for a {class}.</li>
<li>Where can I find a {class}?</li>
<li>Could you pinpoint a {class}?</li>
</ul>
</div>
</div>
\n\n
When the binarize option is turned off, model will output propabilities of requested {class} for each patch. When the binarize option is turned on
the model will binarize each propability based on set eval_threshold.
\n\n
Each input image is transformed to size 224x224 so it can be processed by ViT. During this transformation, different
crop_modes and crop_percentages can be selected. No image is lost if crop_pct = 1.0 and crop_mode='squash' or 'border'. The model was trained using crop_mode='center' and crop_pct = 0.9.
For explanation of different crop_modes, please refer to
<a href="https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py">this</a> website, lines 155-172.
"""
# Input widgets, listed in the positional order of query_image's parameters:
# input_img, query, binarize, eval_threshold, crop_mode, crop_pct.
_demo_inputs = [
    "image",
    "text",
    "checkbox",
    gr.Slider(0, 1, value=0.25),
    gr.Radio(["center", "squash", "border"], value='squash', label='crop_mode'),
    gr.Slider(0.7, 1, value=1, step=0.01),
]

# Clickable examples; each row supplies values for the first four inputs only
# (image path, query, binarize flag, threshold).
_demo_examples = [
    ["examples/imga.jpeg", "Find a person.", True, 0.45],
    ["examples/imgb.jpeg", "Could you mark a horse?", False, 0.25],
    ["examples/imgc.jpeg", "There should be a cat in this picture, where?", True, 0.25],
    ["examples/imgd.jpeg", "Mark a tv in this image.", False, 0.1],
    ["examples/imge.jpeg", "Is there a zebra in this picture?", True, 0.4],
    ["examples/imgf.jpeg", "Look for a stop sign.", True, 0.5],
]

# Styles for the two-column sentence list embedded in the description HTML.
_demo_css = """
.column {
float: left;
padding: 10px;
}

.left {
width: 25%;
}

.right {
width: 75%;
}
"""

demo = gr.Interface(
    fn=query_image,
    inputs=_demo_inputs,
    outputs="image",
    title="Text-Based Object Detection",
    description=description,
    examples=_demo_examples,
    cache_examples=False,
    allow_flagging="never",
    css=_demo_css,
)
demo.launch()
|
|
|
|