Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,11 +4,13 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
| 4 |
import streamlit as st
|
| 5 |
import torch
|
| 6 |
import pickle
|
|
|
|
| 7 |
|
| 8 |
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
|
| 9 |
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
|
| 10 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
| 11 |
text = st.text_input("Enter word or key-phrase")
|
|
|
|
| 12 |
exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')")
|
| 13 |
|
| 14 |
exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
|
|
@@ -17,12 +19,17 @@ k = st.number_input("Top k nearest key-phrases",1,10,5)
|
|
| 17 |
with st.sidebar:
|
| 18 |
diversify_box = st.checkbox("Diversify results",True)
|
| 19 |
if diversify_box:
|
| 20 |
-
|
| 21 |
|
|
|
|
| 22 |
with open("kp_dict_merged.pickle",'rb') as handle:
|
| 23 |
kp_dict = pickle.load(handle)
|
| 24 |
for key in kp_dict.keys():
|
| 25 |
kp_dict[key] = kp_dict[key].detach().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5):
|
| 28 |
sim_dict = {}
|
|
@@ -65,10 +72,30 @@ def pool_embeddings(out, tok):
|
|
| 65 |
mean_pooled = summed / summed_mask
|
| 66 |
return mean_pooled
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
if text:
|
| 69 |
new_tokens = concat_tokens([text])
|
| 70 |
new_tokens.pop("KPS")
|
| 71 |
with torch.no_grad():
|
| 72 |
outputs = model(**new_tokens)
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import streamlit as st
|
| 5 |
import torch
|
| 6 |
import pickle
|
| 7 |
+
import itertools
|
| 8 |
|
| 9 |
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
|
| 10 |
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
|
| 11 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
| 12 |
text = st.text_input("Enter word or key-phrase")
|
| 13 |
+
|
| 14 |
exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')")
|
| 15 |
|
| 16 |
exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
|
|
|
|
| 19 |
with st.sidebar:
|
| 20 |
diversify_box = st.checkbox("Diversify results",True)
|
| 21 |
if diversify_box:
|
| 22 |
+
k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
|
| 23 |
|
| 24 |
+
#load kp dict
|
| 25 |
with open("kp_dict_merged.pickle",'rb') as handle:
|
| 26 |
kp_dict = pickle.load(handle)
|
| 27 |
for key in kp_dict.keys():
|
| 28 |
kp_dict[key] = kp_dict[key].detach().numpy()
|
| 29 |
+
|
| 30 |
+
#load cosine distances of kp dict
|
| 31 |
+
with open("cosine_kp.pickle",'rb') as handle:
|
| 32 |
+
cosine_kp = pickle.load(handle)
|
| 33 |
|
| 34 |
def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5):
|
| 35 |
sim_dict = {}
|
|
|
|
| 72 |
mean_pooled = summed / summed_mask
|
| 73 |
return mean_pooled
|
| 74 |
|
| 75 |
+
def extract_idxs(top_dict, kp_dict):
|
| 76 |
+
idxs = []
|
| 77 |
+
c = 0
|
| 78 |
+
for i in list(kp_dict.keys()):
|
| 79 |
+
if i in top_dict.keys():
|
| 80 |
+
idxs.append(c)
|
| 81 |
+
c+=1
|
| 82 |
+
return idxs
|
| 83 |
+
|
| 84 |
if text:
|
| 85 |
new_tokens = concat_tokens([text])
|
| 86 |
new_tokens.pop("KPS")
|
| 87 |
with torch.no_grad():
|
| 88 |
outputs = model(**new_tokens)
|
| 89 |
+
if not diversify_box:
|
| 90 |
+
sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
|
| 91 |
+
st.json(sim_dict)
|
| 92 |
+
else:
|
| 93 |
+
sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k_diversify)
|
| 94 |
+
idxs = extract_idxs(sim_dict, kp_dict)
|
| 95 |
+
distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
|
| 96 |
+
min_sim = np.inf
|
| 97 |
+
candidate = None
|
| 98 |
+
for combination in itertools.combinations(range(len(idxs)), k):
|
| 99 |
+
sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
|
| 100 |
+
|
| 101 |
+
|