Vivien
committed on
Commit
·
0779f15
1
Parent(s):
ee0cebf
Add eval and torch.no_grad (because inference only)
Browse files
app.py
CHANGED
|
@@ -2,13 +2,14 @@ from html import escape
|
|
| 2 |
import re
|
| 3 |
import streamlit as st
|
| 4 |
import pandas as pd, numpy as np
|
|
|
|
| 5 |
from transformers import CLIPProcessor, CLIPModel
|
| 6 |
from st_clickable_images import clickable_images
|
| 7 |
|
| 8 |
MODEL_NAMES = [
|
| 9 |
-
# "base-patch32",
|
| 10 |
-
# "base-patch16",
|
| 11 |
-
# "large-patch14",
|
| 12 |
"large-patch14-336"
|
| 13 |
]
|
| 14 |
|
|
@@ -20,7 +21,7 @@ def load():
|
|
| 20 |
processors = {}
|
| 21 |
embeddings = {}
|
| 22 |
for name in MODEL_NAMES:
|
| 23 |
-
models[name] = CLIPModel.from_pretrained(f"openai/clip-vit-{name}")
|
| 24 |
processors[name] = CLIPProcessor.from_pretrained(f"openai/clip-vit-{name}")
|
| 25 |
embeddings[name] = {
|
| 26 |
0: np.load(f"embeddings-vit-{name}.npy"),
|
|
@@ -39,7 +40,8 @@ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
|
|
| 39 |
|
| 40 |
def compute_text_embeddings(list_of_strings, name):
|
| 41 |
inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
|
| 42 |
-
|
|
|
|
| 43 |
return result / np.linalg.norm(result, axis=1, keepdims=True)
|
| 44 |
|
| 45 |
|
|
@@ -158,9 +160,9 @@ def main():
|
|
| 158 |
st.sidebar.markdown(description)
|
| 159 |
with st.sidebar.expander("Advanced use"):
|
| 160 |
st.markdown(howto)
|
| 161 |
-
#mode = st.sidebar.selectbox(
|
| 162 |
# "", ["Results for ViT-L/14@336px", "Comparison of 2 models"], index=0
|
| 163 |
-
#)
|
| 164 |
|
| 165 |
_, c, _ = st.columns((1, 3, 1))
|
| 166 |
if "query" in st.session_state:
|
|
@@ -176,7 +178,7 @@ def main():
|
|
| 176 |
"ViT-L/14@336px (slower)": "large-patch14-336",
|
| 177 |
}
|
| 178 |
|
| 179 |
-
if False
|
| 180 |
c1, c2 = st.columns((1, 1))
|
| 181 |
selection1 = c1.selectbox("", models_dict.keys(), index=0)
|
| 182 |
selection2 = c2.selectbox("", models_dict.keys(), index=2)
|
|
@@ -187,7 +189,7 @@ def main():
|
|
| 187 |
|
| 188 |
if len(query) > 0:
|
| 189 |
results1 = image_search(query, corpus, name1)
|
| 190 |
-
if False
|
| 191 |
with c1:
|
| 192 |
clicked1 = clickable_images(
|
| 193 |
[result[0] for result in results1],
|
|
@@ -225,7 +227,7 @@ def main():
|
|
| 225 |
if change_query:
|
| 226 |
if clicked1 >= 0:
|
| 227 |
st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
|
| 228 |
-
#elif clicked2 >= 0:
|
| 229 |
# st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
|
| 230 |
st.experimental_rerun()
|
| 231 |
|
|
|
|
| 2 |
import re
|
| 3 |
import streamlit as st
|
| 4 |
import pandas as pd, numpy as np
|
| 5 |
+
import torch
|
| 6 |
from transformers import CLIPProcessor, CLIPModel
|
| 7 |
from st_clickable_images import clickable_images
|
| 8 |
|
| 9 |
MODEL_NAMES = [
|
| 10 |
+
# "base-patch32",
|
| 11 |
+
# "base-patch16",
|
| 12 |
+
# "large-patch14",
|
| 13 |
"large-patch14-336"
|
| 14 |
]
|
| 15 |
|
|
|
|
| 21 |
processors = {}
|
| 22 |
embeddings = {}
|
| 23 |
for name in MODEL_NAMES:
|
| 24 |
+
models[name] = CLIPModel.from_pretrained(f"openai/clip-vit-{name}").eval()
|
| 25 |
processors[name] = CLIPProcessor.from_pretrained(f"openai/clip-vit-{name}")
|
| 26 |
embeddings[name] = {
|
| 27 |
0: np.load(f"embeddings-vit-{name}.npy"),
|
|
|
|
| 40 |
|
| 41 |
def compute_text_embeddings(list_of_strings, name):
|
| 42 |
inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
|
| 43 |
+
with torch.no_grad():
|
| 44 |
+
result = models[name].get_text_features(**inputs).detach().numpy()
|
| 45 |
return result / np.linalg.norm(result, axis=1, keepdims=True)
|
| 46 |
|
| 47 |
|
|
|
|
| 160 |
st.sidebar.markdown(description)
|
| 161 |
with st.sidebar.expander("Advanced use"):
|
| 162 |
st.markdown(howto)
|
| 163 |
+
# mode = st.sidebar.selectbox(
|
| 164 |
# "", ["Results for ViT-L/14@336px", "Comparison of 2 models"], index=0
|
| 165 |
+
# )
|
| 166 |
|
| 167 |
_, c, _ = st.columns((1, 3, 1))
|
| 168 |
if "query" in st.session_state:
|
|
|
|
| 178 |
"ViT-L/14@336px (slower)": "large-patch14-336",
|
| 179 |
}
|
| 180 |
|
| 181 |
+
if False: # "Comparison" in mode:
|
| 182 |
c1, c2 = st.columns((1, 1))
|
| 183 |
selection1 = c1.selectbox("", models_dict.keys(), index=0)
|
| 184 |
selection2 = c2.selectbox("", models_dict.keys(), index=2)
|
|
|
|
| 189 |
|
| 190 |
if len(query) > 0:
|
| 191 |
results1 = image_search(query, corpus, name1)
|
| 192 |
+
if False: # "Comparison" in mode:
|
| 193 |
with c1:
|
| 194 |
clicked1 = clickable_images(
|
| 195 |
[result[0] for result in results1],
|
|
|
|
| 227 |
if change_query:
|
| 228 |
if clicked1 >= 0:
|
| 229 |
st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
|
| 230 |
+
# elif clicked2 >= 0:
|
| 231 |
# st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
|
| 232 |
st.experimental_rerun()
|
| 233 |
|