Spaces:
Sleeping
Sleeping
Fangrui Liu
committed on
Commit
·
aee10cf
1
Parent(s):
0b449a5
refined layout
Browse files
app.py
CHANGED
|
@@ -258,6 +258,7 @@ def init_clip_mlang():
|
|
| 258 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 259 |
return tokenizer, clip
|
| 260 |
|
|
|
|
| 261 |
@st.experimental_singleton(show_spinner=False)
|
| 262 |
def init_clip_vanilla():
|
| 263 |
""" Initialize CLIP Model
|
|
@@ -297,11 +298,13 @@ def prompt2vec_mlang(prompt: str, tokenizer, clip):
|
|
| 297 |
xq = out.squeeze(0).cpu().detach().numpy().tolist()
|
| 298 |
return xq
|
| 299 |
|
|
|
|
| 300 |
def prompt2vec_vanilla(prompt: str, tokenizer, clip):
|
| 301 |
inputs = tokenizer(prompt, return_tensors='pt')
|
| 302 |
out = clip.get_text_features(**inputs)
|
| 303 |
xq = out.squeeze(0).cpu().detach().numpy().tolist()
|
| 304 |
-
return xq
|
|
|
|
| 305 |
|
| 306 |
st.markdown("""
|
| 307 |
<link
|
|
@@ -345,7 +348,7 @@ text_model_map = {
|
|
| 345 |
'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
|
| 346 |
'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
|
| 347 |
}
|
| 348 |
-
|
| 349 |
|
| 350 |
|
| 351 |
with st.spinner("Connecting DB..."):
|
|
@@ -354,9 +357,11 @@ with st.spinner("Connecting DB..."):
|
|
| 354 |
with st.spinner("Loading Models..."):
|
| 355 |
# Initialize CLIP model
|
| 356 |
if 'xq' not in st.session_state:
|
| 357 |
-
text_model_map['Multi Lingual']['Vanilla CLIP'].append(
|
|
|
|
| 358 |
text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
|
| 359 |
-
text_model_map['English']['CLIP finetuned on RSICD'].append(
|
|
|
|
| 360 |
st.session_state.query_num = 0
|
| 361 |
|
| 362 |
if 'xq' not in st.session_state:
|
|
@@ -372,30 +377,34 @@ if 'xq' not in st.session_state:
|
|
| 372 |
del st.session_state.prompt
|
| 373 |
st.title("Visual Dataset Explorer")
|
| 374 |
start = [st.empty(), st.empty(), st.empty(), st.empty(),
|
| 375 |
-
st.empty(), st.empty(), st.empty()]
|
| 376 |
start[0].info(msg)
|
| 377 |
start_col = start[1].columns(3)
|
| 378 |
-
st.session_state.db_name_ref = start_col[0].selectbox(
|
| 379 |
-
|
| 380 |
-
st.session_state.
|
|
|
|
|
|
|
| 381 |
list(text_model_map[st.session_state.lang].keys()))
|
| 382 |
if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
|
| 383 |
-
|
| 384 |
try to use prompt "An aerial photograph of <your-real-query>" \
|
| 385 |
to obtain best search experience!')
|
| 386 |
-
prompt = start[2].text_input(
|
| 387 |
-
"Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
|
| 388 |
if len(prompt) > 0:
|
| 389 |
st.session_state.prompt = prompt.replace(' ', '_')
|
| 390 |
-
start[
|
| 391 |
'<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
|
| 392 |
<p>🌟 We also support multi-language search. Type any language you know to search! ⌨️ </p>',
|
| 393 |
unsafe_allow_html=True)
|
| 394 |
-
upld_model = start[
|
| 395 |
"Or you can upload your previous run!", type='onnx')
|
| 396 |
-
upld_btn = start[
|
| 397 |
-
"
|
| 398 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
col = st.columns(8)
|
| 400 |
has_no_prompt = (len(prompt) == 0 and upld_model is None)
|
| 401 |
prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
|
|
@@ -418,7 +427,8 @@ if 'xq' not in st.session_state:
|
|
| 418 |
assert len(weights) == 1
|
| 419 |
xq = numpy_helper.to_array(weights[0]).tolist()
|
| 420 |
assert len(xq) == DIMS
|
| 421 |
-
st.session_state.prompt = upld_model.name.split(".onnx")[
|
|
|
|
| 422 |
else:
|
| 423 |
print(f"Input prompt is {prompt}")
|
| 424 |
# Tokenize the vectors
|
|
|
|
| 258 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 259 |
return tokenizer, clip
|
| 260 |
|
| 261 |
+
|
| 262 |
@st.experimental_singleton(show_spinner=False)
|
| 263 |
def init_clip_vanilla():
|
| 264 |
""" Initialize CLIP Model
|
|
|
|
| 298 |
xq = out.squeeze(0).cpu().detach().numpy().tolist()
|
| 299 |
return xq
|
| 300 |
|
| 301 |
+
|
| 302 |
def prompt2vec_vanilla(prompt: str, tokenizer, clip):
|
| 303 |
inputs = tokenizer(prompt, return_tensors='pt')
|
| 304 |
out = clip.get_text_features(**inputs)
|
| 305 |
xq = out.squeeze(0).cpu().detach().numpy().tolist()
|
| 306 |
+
return xq
|
| 307 |
+
|
| 308 |
|
| 309 |
st.markdown("""
|
| 310 |
<link
|
|
|
|
| 348 |
'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
|
| 349 |
'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
|
| 350 |
}
|
| 351 |
+
}
|
| 352 |
|
| 353 |
|
| 354 |
with st.spinner("Connecting DB..."):
|
|
|
|
| 357 |
with st.spinner("Loading Models..."):
|
| 358 |
# Initialize CLIP model
|
| 359 |
if 'xq' not in st.session_state:
|
| 360 |
+
text_model_map['Multi Lingual']['Vanilla CLIP'].append(
|
| 361 |
+
init_clip_mlang())
|
| 362 |
text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
|
| 363 |
+
text_model_map['English']['CLIP finetuned on RSICD'].append(
|
| 364 |
+
init_clip_rsicd())
|
| 365 |
st.session_state.query_num = 0
|
| 366 |
|
| 367 |
if 'xq' not in st.session_state:
|
|
|
|
| 377 |
del st.session_state.prompt
|
| 378 |
st.title("Visual Dataset Explorer")
|
| 379 |
start = [st.empty(), st.empty(), st.empty(), st.empty(),
|
| 380 |
+
st.empty(), st.empty(), st.empty(), st.empty()]
|
| 381 |
start[0].info(msg)
|
| 382 |
start_col = start[1].columns(3)
|
| 383 |
+
st.session_state.db_name_ref = start_col[0].selectbox(
|
| 384 |
+
"Select Database:", list(db_name_map.keys()))
|
| 385 |
+
st.session_state.lang = start_col[1].selectbox(
|
| 386 |
+
"Select Language:", list(text_model_map.keys()))
|
| 387 |
+
st.session_state.feat_name = start_col[2].selectbox("Select Image Feature:",
|
| 388 |
list(text_model_map[st.session_state.lang].keys()))
|
| 389 |
if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
|
| 390 |
+
start[2].warning('If you are searching for Remote Sensing Images, \
|
| 391 |
try to use prompt "An aerial photograph of <your-real-query>" \
|
| 392 |
to obtain best search experience!')
|
|
|
|
|
|
|
| 393 |
if len(prompt) > 0:
|
| 394 |
st.session_state.prompt = prompt.replace(' ', '_')
|
| 395 |
+
start[4].markdown(
|
| 396 |
'<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
|
| 397 |
<p>🌟 We also support multi-language search. Type any language you know to search! ⌨️ </p>',
|
| 398 |
unsafe_allow_html=True)
|
| 399 |
+
upld_model = start[6].file_uploader(
|
| 400 |
"Or you can upload your previous run!", type='onnx')
|
| 401 |
+
upld_btn = start[7].button(
|
| 402 |
+
"Use Loaded Weights", disabled=upld_model is None)
|
| 403 |
+
prompt = start[3].text_input(
|
| 404 |
+
"Prompt:",
|
| 405 |
+
value="An aerial photograph of "if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K" else "",
|
| 406 |
+
placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...",)
|
| 407 |
+
with start[5]:
|
| 408 |
col = st.columns(8)
|
| 409 |
has_no_prompt = (len(prompt) == 0 and upld_model is None)
|
| 410 |
prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
|
|
|
|
| 427 |
assert len(weights) == 1
|
| 428 |
xq = numpy_helper.to_array(weights[0]).tolist()
|
| 429 |
assert len(xq) == DIMS
|
| 430 |
+
st.session_state.prompt = upld_model.name.split(".onnx")[
|
| 431 |
+
0].replace(' ', '_')
|
| 432 |
else:
|
| 433 |
print(f"Input prompt is {prompt}")
|
| 434 |
# Tokenize the vectors
|