Spaces:
Sleeping
Sleeping
Fangrui Liu
committed on
Commit
·
0b449a5
1
Parent(s):
b73f599
add selective db / feat / lang
Browse files
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import numpy as np
|
|
| 3 |
import base64
|
| 4 |
from io import BytesIO
|
| 5 |
from multilingual_clip import pt_multilingual_clip
|
| 6 |
-
from transformers import CLIPTokenizerFast, AutoTokenizer
|
| 7 |
import torch
|
| 8 |
import logging
|
| 9 |
from os import environ
|
|
@@ -12,30 +12,22 @@ environ['TOKENIZERS_PARALLELISM'] = 'true'
|
|
| 12 |
|
| 13 |
|
| 14 |
db_name_map = {
|
| 15 |
-
"Unsplash Photos 25K": "mqdb_demo.
|
| 16 |
-
"RSICD: Remote Sensing Images 11K": "mqdb_demo.
|
| 17 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
|
| 20 |
-
MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
|
| 21 |
DIMS = 512
|
| 22 |
# Ignore some bad links (broken in the dataset already)
|
| 23 |
BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8',
|
| 24 |
'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}
|
| 25 |
|
| 26 |
|
| 27 |
-
@st.experimental_singleton(show_spinner=False)
|
| 28 |
-
def init_clip():
|
| 29 |
-
""" Initialize CLIP Model
|
| 30 |
-
|
| 31 |
-
Returns:
|
| 32 |
-
Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
|
| 33 |
-
"""
|
| 34 |
-
clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
|
| 35 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 36 |
-
return tokenizer, clip
|
| 37 |
-
|
| 38 |
-
|
| 39 |
@st.experimental_singleton(show_spinner=False)
|
| 40 |
def init_db():
|
| 41 |
""" Initialize the Database Connection
|
|
@@ -82,15 +74,15 @@ def query(xq, top_k=10):
|
|
| 82 |
# Using PREWHERE allows you to do column filter before vector search
|
| 83 |
xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
|
| 84 |
distance('topK={top_k}')(vector, {xq_s}) AS dist\
|
| 85 |
-
FROM {db_name_map[st.session_state.db_name_ref]} \
|
| 86 |
PREWHERE id NOT IN ({exclude_list})")
|
| 87 |
else:
|
| 88 |
xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
|
| 89 |
distance('topK={top_k}')(vector, {xq_s}) AS dist\
|
| 90 |
-
FROM {db_name_map[st.session_state.db_name_ref]}")
|
| 91 |
real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
|
| 92 |
distance('topK={top_k}')(vector, {xq_s}) AS dist\
|
| 93 |
-
FROM {db_name_map[st.session_state.db_name_ref]}")
|
| 94 |
top_k = real_xc
|
| 95 |
xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
|
| 96 |
st.session_state.meta[xi['id']] < 1]
|
|
@@ -166,38 +158,6 @@ class NormalizingLayer(torch.nn.Module):
|
|
| 166 |
return x / torch.norm(x, dim=-1, keepdim=True)
|
| 167 |
|
| 168 |
|
| 169 |
-
def prompt2vec(prompt: str):
|
| 170 |
-
""" Convert prompt into a computational vector
|
| 171 |
-
|
| 172 |
-
Args:
|
| 173 |
-
prompt (str): Text to be tokenized
|
| 174 |
-
|
| 175 |
-
Returns:
|
| 176 |
-
xq: vector from the tokenizer, representing the original prompt
|
| 177 |
-
"""
|
| 178 |
-
# inputs = tokenizer(prompt, return_tensors='pt')
|
| 179 |
-
# out = clip.get_text_features(**inputs)
|
| 180 |
-
out = clip.forward(prompt, tokenizer)
|
| 181 |
-
xq = out.squeeze(0).cpu().detach().numpy().tolist()
|
| 182 |
-
return xq
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
def pil_to_bytes(img):
|
| 186 |
-
""" Convert a Pillow image into base64
|
| 187 |
-
|
| 188 |
-
Args:
|
| 189 |
-
img (PIL.Image): Pillow (PIL) Image
|
| 190 |
-
|
| 191 |
-
Returns:
|
| 192 |
-
img_bin: image in base64 format
|
| 193 |
-
"""
|
| 194 |
-
with BytesIO() as buf:
|
| 195 |
-
img.save(buf, format='jpeg')
|
| 196 |
-
img_bin = buf.getvalue()
|
| 197 |
-
img_bin = base64.b64encode(img_bin).decode('utf-8')
|
| 198 |
-
return img_bin
|
| 199 |
-
|
| 200 |
-
|
| 201 |
def card(i, url):
|
| 202 |
return f'<img id="img{i}" src="{url}" width="200px;">'
|
| 203 |
|
|
@@ -286,6 +246,63 @@ def delete_element(element):
|
|
| 286 |
del element
|
| 287 |
|
| 288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
st.markdown("""
|
| 290 |
<link
|
| 291 |
rel="stylesheet"
|
|
@@ -323,13 +340,23 @@ messages = [
|
|
| 323 |
"""
|
| 324 |
]
|
| 325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
with st.spinner("Connecting DB..."):
|
| 327 |
st.session_state.meta, st.session_state.index = init_db()
|
| 328 |
|
| 329 |
with st.spinner("Loading Models..."):
|
| 330 |
# Initialize CLIP model
|
| 331 |
if 'xq' not in st.session_state:
|
| 332 |
-
|
|
|
|
|
|
|
| 333 |
st.session_state.query_num = 0
|
| 334 |
|
| 335 |
if 'xq' not in st.session_state:
|
|
@@ -347,8 +374,15 @@ if 'xq' not in st.session_state:
|
|
| 347 |
start = [st.empty(), st.empty(), st.empty(), st.empty(),
|
| 348 |
st.empty(), st.empty(), st.empty()]
|
| 349 |
start[0].info(msg)
|
| 350 |
-
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
prompt = start[2].text_input(
|
| 353 |
"Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
|
| 354 |
if len(prompt) > 0:
|
|
@@ -388,7 +422,8 @@ if 'xq' not in st.session_state:
|
|
| 388 |
else:
|
| 389 |
print(f"Input prompt is {prompt}")
|
| 390 |
# Encode the prompt into a query vector
|
| 391 |
-
|
|
|
|
| 392 |
st.session_state.xq = xq
|
| 393 |
st.session_state.orig_xq = xq
|
| 394 |
_ = [elem.empty() for elem in start]
|
|
|
|
| 3 |
import base64
|
| 4 |
from io import BytesIO
|
| 5 |
from multilingual_clip import pt_multilingual_clip
|
| 6 |
+
from transformers import CLIPTokenizerFast, AutoTokenizer, CLIPModel
|
| 7 |
import torch
|
| 8 |
import logging
|
| 9 |
from os import environ
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
db_name_map = {
|
| 15 |
+
"Unsplash Photos 25K": lambda feat: f"mqdb_demo.unsplash_25k_{feat}_indexer",
|
| 16 |
+
"RSICD: Remote Sensing Images 11K": lambda feat: f"mqdb_demo.rsicd_{feat}_b_32",
|
| 17 |
}
|
| 18 |
+
feat_name_map = {
|
| 19 |
+
'Vanilla CLIP': "clip",
|
| 20 |
+
'CLIP finetuned on RSICD': "cliprsicd"
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
|
| 24 |
DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
|
|
|
|
| 25 |
DIMS = 512
|
| 26 |
# Ignore some bad links (broken in the dataset already)
|
| 27 |
BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8',
|
| 28 |
'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
@st.experimental_singleton(show_spinner=False)
|
| 32 |
def init_db():
|
| 33 |
""" Initialize the Database Connection
|
|
|
|
| 74 |
# Using PREWHERE allows you to do column filter before vector search
|
| 75 |
xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
|
| 76 |
distance('topK={top_k}')(vector, {xq_s}) AS dist\
|
| 77 |
+
FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
|
| 78 |
PREWHERE id NOT IN ({exclude_list})")
|
| 79 |
else:
|
| 80 |
xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
|
| 81 |
distance('topK={top_k}')(vector, {xq_s}) AS dist\
|
| 82 |
+
FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
|
| 83 |
real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
|
| 84 |
distance('topK={top_k}')(vector, {xq_s}) AS dist\
|
| 85 |
+
FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
|
| 86 |
top_k = real_xc
|
| 87 |
xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
|
| 88 |
st.session_state.meta[xi['id']] < 1]
|
|
|
|
| 158 |
return x / torch.norm(x, dim=-1, keepdim=True)
|
| 159 |
|
| 160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
def card(i, url):
    """ Build the HTML <img> tag for one result image

    Args:
        i: value appended to "img" to form the tag's id attribute
        url: value placed in the tag's src attribute

    Returns:
        str: the <img> tag, with a fixed width attribute of "200px;"
    """
    img_tag = '<img id="img{}" src="{}" width="200px;">'
    return img_tag.format(i, url)
|
| 163 |
|
|
|
|
| 246 |
del element
|
| 247 |
|
| 248 |
|
| 249 |
+
@st.experimental_singleton(show_spinner=False)
def init_clip_mlang():
    """ Load the multilingual CLIP text encoder and its tokenizer

    Both objects are created with from_pretrained on the
    'M-CLIP/XLM-Roberta-Large-Vit-B-32' checkpoint.

    Returns:
        tuple: (tokenizer, clip) -- the AutoTokenizer instance and the
            MultilingualCLIP model, in that order
    """
    checkpoint = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
    # The model is loaded before the tokenizer.
    text_encoder = pt_multilingual_clip.MultilingualCLIP.from_pretrained(checkpoint)
    return AutoTokenizer.from_pretrained(checkpoint), text_encoder
|
| 260 |
+
|
| 261 |
+
@st.experimental_singleton(show_spinner=False)
def init_clip_vanilla():
    """ Load the original OpenAI CLIP model and its tokenizer

    Both objects are created with from_pretrained on the
    "openai/clip-vit-base-patch32" checkpoint.

    Returns:
        tuple: (tokenizer, clip) -- the CLIPTokenizerFast instance and the
            CLIPModel, in that order
    """
    checkpoint = "openai/clip-vit-base-patch32"
    # The tokenizer is loaded before the model.
    text_tokenizer = CLIPTokenizerFast.from_pretrained(checkpoint)
    return text_tokenizer, CLIPModel.from_pretrained(checkpoint)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@st.experimental_singleton(show_spinner=False)
def init_clip_rsicd():
    """ Load the "flax-community/clip-rsicd" CLIP model and its tokenizer

    Both objects are created with from_pretrained on the same checkpoint.

    Returns:
        tuple: (tokenizer, clip) -- the CLIPTokenizerFast instance and the
            CLIPModel, in that order
    """
    checkpoint = "flax-community/clip-rsicd"
    # The tokenizer is loaded before the model.
    text_tokenizer = CLIPTokenizerFast.from_pretrained(checkpoint)
    return text_tokenizer, CLIPModel.from_pretrained(checkpoint)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def prompt2vec_mlang(prompt: str, tokenizer, clip):
    """ Convert a prompt into a query vector using the multilingual text encoder

    Args:
        prompt (str): Text to be embedded
        tokenizer: Tokenizer passed through to ``clip.forward`` unchanged
        clip: Model whose ``forward(prompt, tokenizer)`` returns a tensor

    Returns:
        xq: the model output with its leading dimension squeezed, converted
            to a plain Python list
    """
    # Inference only: do not record an autograd graph. The output is
    # detached below either way, so the returned values are unchanged.
    with torch.no_grad():
        out = clip.forward(prompt, tokenizer)
    xq = out.squeeze(0).cpu().detach().numpy().tolist()
    return xq
|
| 299 |
+
|
| 300 |
+
def prompt2vec_vanilla(prompt: str, tokenizer, clip):
    """ Convert a prompt into a query vector using a CLIP text tower

    Args:
        prompt (str): Text to be embedded
        tokenizer: Callable tokenizer; invoked with ``return_tensors='pt'``
            and ``truncation=True``, and expected to return a mapping that
            can be unpacked as keyword arguments
        clip: Model exposing ``get_text_features(**inputs)`` that returns a
            tensor

    Returns:
        xq: the text features with their leading dimension squeezed,
            converted to a plain Python list
    """
    # Truncate to the tokenizer's maximum length so that an over-long prompt
    # is shortened instead of failing inside the text model.
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True)
    # Inference only: do not record an autograd graph. The output is
    # detached below either way, so the returned values are unchanged.
    with torch.no_grad():
        out = clip.get_text_features(**inputs)
    xq = out.squeeze(0).cpu().detach().numpy().tolist()
    return xq
|
| 305 |
+
|
| 306 |
st.markdown("""
|
| 307 |
<link
|
| 308 |
rel="stylesheet"
|
|
|
|
| 340 |
"""
|
| 341 |
]
|
| 342 |
|
| 343 |
+
text_model_map = {
|
| 344 |
+
'Multi Lingual': {'Vanilla CLIP': [prompt2vec_mlang, ]},
|
| 345 |
+
'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
|
| 346 |
+
'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
|
| 351 |
with st.spinner("Connecting DB..."):
|
| 352 |
st.session_state.meta, st.session_state.index = init_db()
|
| 353 |
|
| 354 |
with st.spinner("Loading Models..."):
|
| 355 |
# Initialize CLIP model
|
| 356 |
if 'xq' not in st.session_state:
|
| 357 |
+
text_model_map['Multi Lingual']['Vanilla CLIP'].append(init_clip_mlang())
|
| 358 |
+
text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
|
| 359 |
+
text_model_map['English']['CLIP finetuned on RSICD'].append(init_clip_rsicd())
|
| 360 |
st.session_state.query_num = 0
|
| 361 |
|
| 362 |
if 'xq' not in st.session_state:
|
|
|
|
| 374 |
start = [st.empty(), st.empty(), st.empty(), st.empty(),
|
| 375 |
st.empty(), st.empty(), st.empty()]
|
| 376 |
start[0].info(msg)
|
| 377 |
+
start_col = start[1].columns(3)
|
| 378 |
+
st.session_state.db_name_ref = start_col[0].selectbox("Select Database:", list(db_name_map.keys()))
|
| 379 |
+
st.session_state.lang = start_col[1].selectbox("Select Language:", list(text_model_map.keys()))
|
| 380 |
+
st.session_state.feat_name = start_col[2].selectbox("Select Image Feature:",
|
| 381 |
+
list(text_model_map[st.session_state.lang].keys()))
|
| 382 |
+
if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
|
| 383 |
+
st.warning('If you are searching for Remote Sensing Images, \
|
| 384 |
+
try to use prompt "An aerial photograph of <your-real-query>" \
|
| 385 |
+
to obtain best search experience!')
|
| 386 |
prompt = start[2].text_input(
|
| 387 |
"Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
|
| 388 |
if len(prompt) > 0:
|
|
|
|
| 422 |
else:
|
| 423 |
print(f"Input prompt is {prompt}")
|
| 424 |
# Encode the prompt into a query vector
|
| 425 |
+
p2v_func, args = text_model_map[st.session_state.lang][st.session_state.feat_name]
|
| 426 |
+
xq = p2v_func(prompt, *args)
|
| 427 |
st.session_state.xq = xq
|
| 428 |
st.session_state.orig_xq = xq
|
| 429 |
_ = [elem.empty() for elem in start]
|