Spaces:
Build error
Build error
add more options for GIS
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import time
|
|
| 4 |
import random
|
| 5 |
import torch
|
| 6 |
import torchvision.transforms as transforms
|
| 7 |
-
import gradio as gr
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
|
| 10 |
from models import get_model
|
|
@@ -83,7 +83,9 @@ _search_params = {
|
|
| 83 |
|
| 84 |
|
| 85 |
# Gradio UI
|
| 86 |
-
def inference(query, labels, n_supp=10
|
|
|
|
|
|
|
| 87 |
'''
|
| 88 |
query: PIL image
|
| 89 |
labels: list of class names
|
|
@@ -91,6 +93,12 @@ def inference(query, labels, n_supp=10):
|
|
| 91 |
labels = labels.split(',')
|
| 92 |
n_supp = int(n_supp)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
|
| 95 |
|
| 96 |
with torch.no_grad():
|
|
@@ -104,9 +112,8 @@ def inference(query, labels, n_supp=10):
|
|
| 104 |
for idx, y in enumerate(labels):
|
| 105 |
gis = GoogleImagesSearch(args.api_key, args.cx)
|
| 106 |
_search_params['q'] = y
|
| 107 |
-
_search_params['num'] = n_supp
|
| 108 |
gis.search(search_params=_search_params, custom_image_name='my_image')
|
| 109 |
-
gis._custom_image_name = 'my_image'
|
| 110 |
|
| 111 |
for j, x in enumerate(gis.results()):
|
| 112 |
x.download('./')
|
|
@@ -135,9 +142,10 @@ def inference(query, labels, n_supp=10):
|
|
| 135 |
|
| 136 |
|
| 137 |
# DEBUG
|
| 138 |
-
|
|
|
|
| 139 |
##labels = 'dog, cat'
|
| 140 |
-
#labels = 'girl,
|
| 141 |
#output = inference(query, labels, n_supp=2)
|
| 142 |
#print(output)
|
| 143 |
|
|
@@ -146,7 +154,11 @@ gr.Interface(fn=inference,
|
|
| 146 |
inputs=[
|
| 147 |
gr.inputs.Image(label="Image to classify", type="pil"),
|
| 148 |
gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
|
| 149 |
-
gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
],
|
| 151 |
theme="grass",
|
| 152 |
outputs=[
|
|
|
|
| 4 |
import random
|
| 5 |
import torch
|
| 6 |
import torchvision.transforms as transforms
|
| 7 |
+
#import gradio as gr
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
|
| 10 |
from models import get_model
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
# Gradio UI
|
| 86 |
+
def inference(query, labels, n_supp=10,
|
| 87 |
+
file_type='png', rights='cc_publicdomain',
|
| 88 |
+
image_type='photo', color_type='color'):
|
| 89 |
'''
|
| 90 |
query: PIL image
|
| 91 |
labels: list of class names
|
|
|
|
| 93 |
labels = labels.split(',')
|
| 94 |
n_supp = int(n_supp)
|
| 95 |
|
| 96 |
+
_search_params['num'] = n_supp
|
| 97 |
+
_search_params['fileType'] = file_type
|
| 98 |
+
_search_params['rights'] = rights
|
| 99 |
+
_search_params['imgType'] = image_type
|
| 100 |
+
_search_params['imgColorType'] = color_type
|
| 101 |
+
|
| 102 |
fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
|
| 103 |
|
| 104 |
with torch.no_grad():
|
|
|
|
| 112 |
for idx, y in enumerate(labels):
|
| 113 |
gis = GoogleImagesSearch(args.api_key, args.cx)
|
| 114 |
_search_params['q'] = y
|
|
|
|
| 115 |
gis.search(search_params=_search_params, custom_image_name='my_image')
|
| 116 |
+
gis._custom_image_name = 'my_image' # fix: image name sometimes too long
|
| 117 |
|
| 118 |
for j, x in enumerate(gis.results()):
|
| 119 |
x.download('./')
|
|
|
|
| 142 |
|
| 143 |
|
| 144 |
# DEBUG
|
| 145 |
+
##query = Image.open('../labrador-puppy.jpg')
|
| 146 |
+
#query = Image.open('/Users/hushell/Documents/Dan_tr.png')
|
| 147 |
##labels = 'dog, cat'
|
| 148 |
+
#labels = 'girl, sussie'
|
| 149 |
#output = inference(query, labels, n_supp=2)
|
| 150 |
#print(output)
|
| 151 |
|
|
|
|
| 154 |
inputs=[
|
| 155 |
gr.inputs.Image(label="Image to classify", type="pil"),
|
| 156 |
gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
|
| 157 |
+
gr.inputs.Slider(minimum=2, maximum=10, step=1, label="GIS: Number of support examples per class"),
|
| 158 |
+
gr.inputs.Dropdown(['png', 'jpg'], default='png', label='GIS: Image file type'),
|
| 159 |
+
gr.inputs.Dropdown(['cc_publicdomain', 'cc_attribute', 'cc_sharealike', 'cc_noncommercial', 'cc_nonderived'], default='cc_publicdomain', label='GIS: Copy rights'),
|
| 160 |
+
gr.inputs.Dropdown(['clipart', 'face', 'lineart', 'stock', 'photo', 'animated', 'imgTypeUndefined'], default='photo', label='GIS: Image type'),
|
| 161 |
+
gr.inputs.Dropdown(['color', 'gray', 'mono', 'trans', 'imgColorTypeUndefined'], default='color', label='GIS: Image color type'),
|
| 162 |
],
|
| 163 |
theme="grass",
|
| 164 |
outputs=[
|