Spaces:
Build error
Build error
fix curse issue
Browse files
app.py
CHANGED
|
@@ -4,7 +4,6 @@ import time
|
|
| 4 |
import random
|
| 5 |
import torch
|
| 6 |
import torchvision.transforms as transforms
|
| 7 |
-
#import requests
|
| 8 |
import gradio as gr
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
|
|
@@ -12,16 +11,13 @@ from models import get_model
|
|
| 12 |
from dotmap import DotMap
|
| 13 |
from PIL import Image
|
| 14 |
|
| 15 |
-
|
| 16 |
-
os.environ['
|
| 17 |
-
os.environ['TERMINFO'] = '/etc/terminfo'
|
| 18 |
-
|
| 19 |
|
| 20 |
# args
|
| 21 |
args = DotMap()
|
| 22 |
args.deploy = 'vanilla'
|
| 23 |
args.arch = 'dino_small_patch16'
|
| 24 |
-
#args.resume = '/fast_scratch/hushell/fluidstack/FS125_few-shot-transformer/outputs/dinosmall_1e-4/best_converted.pth'
|
| 25 |
args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
|
| 26 |
args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
|
| 27 |
args.cx = '06d75168141bc47f1'
|
|
@@ -31,7 +27,6 @@ args.cx = '06d75168141bc47f1'
|
|
| 31 |
device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 32 |
model = get_model(args)
|
| 33 |
model.to(device)
|
| 34 |
-
#checkpoint = torch.load(args.resume, map_location='cpu')
|
| 35 |
checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
|
| 36 |
model.load_state_dict(checkpoint['model'], strict=True)
|
| 37 |
|
|
@@ -63,6 +58,12 @@ def denormalize(x, mean, std):
|
|
| 63 |
# Google image search
|
| 64 |
from google_images_search import GoogleImagesSearch
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
# define search params
|
| 67 |
# option for commonly used search param are shown below for easy reference.
|
| 68 |
# For param marked with '##':
|
|
@@ -90,7 +91,6 @@ def inference(query, labels, n_supp=10):
|
|
| 90 |
labels = labels.split(',')
|
| 91 |
n_supp = int(n_supp)
|
| 92 |
|
| 93 |
-
#print(f'#rows={len(labels)}, #cols={n_supp}')
|
| 94 |
fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
|
| 95 |
|
| 96 |
with torch.no_grad():
|
|
@@ -102,26 +102,24 @@ def inference(query, labels, n_supp=10):
|
|
| 102 |
|
| 103 |
# search support images
|
| 104 |
for idx, y in enumerate(labels):
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
supp_x.append(x_im)
|
| 124 |
-
supp_y.append(idx)
|
| 125 |
|
| 126 |
print('Searching for support images is done.')
|
| 127 |
|
|
@@ -148,7 +146,6 @@ gr.Interface(fn=inference,
|
|
| 148 |
inputs=[
|
| 149 |
gr.inputs.Image(label="Image to classify", type="pil"),
|
| 150 |
gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
|
| 151 |
-
#gr.inputs.Number(default=1, label="Number of support examples from Google")
|
| 152 |
gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples from Google")
|
| 153 |
],
|
| 154 |
theme="grass",
|
|
|
|
| 4 |
import random
|
| 5 |
import torch
|
| 6 |
import torchvision.transforms as transforms
|
|
|
|
| 7 |
import gradio as gr
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
|
|
|
|
| 11 |
from dotmap import DotMap
|
| 12 |
from PIL import Image
|
| 13 |
|
| 14 |
+
#os.environ['TERM'] = 'linux'
|
| 15 |
+
#os.environ['TERMINFO'] = '/etc/terminfo'
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# args
|
| 18 |
args = DotMap()
|
| 19 |
args.deploy = 'vanilla'
|
| 20 |
args.arch = 'dino_small_patch16'
|
|
|
|
| 21 |
args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
|
| 22 |
args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
|
| 23 |
args.cx = '06d75168141bc47f1'
|
|
|
|
| 27 |
device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 28 |
model = get_model(args)
|
| 29 |
model.to(device)
|
|
|
|
| 30 |
checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
|
| 31 |
model.load_state_dict(checkpoint['model'], strict=True)
|
| 32 |
|
|
|
|
| 58 |
# Google image search
|
| 59 |
from google_images_search import GoogleImagesSearch
|
| 60 |
|
| 61 |
+
class MyGIS(GoogleImagesSearch):
|
| 62 |
+
def __enter__(self):
|
| 63 |
+
return self
|
| 64 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 65 |
+
return
|
| 66 |
+
|
| 67 |
# define search params
|
| 68 |
# option for commonly used search param are shown below for easy reference.
|
| 69 |
# For param marked with '##':
|
|
|
|
| 91 |
labels = labels.split(',')
|
| 92 |
n_supp = int(n_supp)
|
| 93 |
|
|
|
|
| 94 |
fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
|
| 95 |
|
| 96 |
with torch.no_grad():
|
|
|
|
| 102 |
|
| 103 |
# search support images
|
| 104 |
for idx, y in enumerate(labels):
|
| 105 |
+
gis = GoogleImagesSearch(args.api_key, args.cx)
|
| 106 |
+
_search_params['q'] = y
|
| 107 |
+
_search_params['num'] = n_supp
|
| 108 |
+
gis.search(search_params=_search_params, custom_image_name='my_image')
|
| 109 |
+
gis._custom_image_name = 'my_image'
|
| 110 |
+
|
| 111 |
+
for j, x in enumerate(gis.results()):
|
| 112 |
+
x.download('./')
|
| 113 |
+
x_im = Image.open(x.path)
|
| 114 |
+
|
| 115 |
+
# vis
|
| 116 |
+
axs[idx, j].imshow(x_im)
|
| 117 |
+
axs[idx, j].set_title(f'{y}{j}')
|
| 118 |
+
axs[idx, j].axis('off')
|
| 119 |
+
|
| 120 |
+
x_im = preprocess(x_im) # (3, H, W)
|
| 121 |
+
supp_x.append(x_im)
|
| 122 |
+
supp_y.append(idx)
|
|
|
|
|
|
|
| 123 |
|
| 124 |
print('Searching for support images is done.')
|
| 125 |
|
|
|
|
| 146 |
inputs=[
|
| 147 |
gr.inputs.Image(label="Image to classify", type="pil"),
|
| 148 |
gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
|
|
|
|
| 149 |
gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples from Google")
|
| 150 |
],
|
| 151 |
theme="grass",
|