Spaces: Runtime error
Avik Rao committed · Commit a81bb0d
1 Parent(s): cbde07b
Skip NLP model when tag is already in training space
nlp/nlp.py +10 -4
nlp/nlp.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoTokenizer, AutoModel
 from sklearn.metrics.pairwise import cosine_similarity
 import streamlit as st
 
+
 # FUNCTIONS
 # create embeddings
 def get_embeddings(text: str, token_length: int, tokenizer, model):
@@ -17,7 +18,7 @@ def get_embeddings(text: str, token_length: int, tokenizer, model):
     output = model(torch.tensor(tokens.input_ids).unsqueeze(0),
                    attention_mask=torch.tensor(
                        tokens.attention_mask
-
+                   ).unsqueeze(0)).hidden_states[-1]
     return torch.mean(output, axis=1).detach().numpy()
 
 
@@ -26,19 +27,24 @@ def nearest_doc(doc_list: List[str],
                 query: str,
                 tokenizer,
                 model,
-                token_length: int =
+                token_length: int = 10):
+
+    # if query is already in doc list, return query
+    if query in doc_list:
+        return query
+
     # get embeddings for each document
     outs = [
         get_embeddings(doc, token_length, tokenizer, model) for doc in doc_list
     ]
     # get embeddings for query
-    query_embeddings = get_embeddings(query, token_length
+    query_embeddings = get_embeddings(query, token_length, tokenizer, model)
     # get similarity of each document embedding to query embedding
     sims = [cosine_similarity(out, query_embeddings)[0][0] for out in outs]
     return max(zip(sims, doc_list))[1]
 
 
-# MAIN
+# MAIN
 def get_nearest_tags(user_tags: List[str]):
     st.write("function called")
     # download pretrained model
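
For reference, the two functions touched by this commit reassemble into the minimal sketch below. It is an approximation, not the Space's exact code: the removed side of several lines is truncated in the diff, the tokenizer call inside get_embeddings sits above the changed hunks and is therefore assumed, and "bert-base-uncased" is a placeholder checkpoint (the Space's real model is not named here). Note that .hidden_states[-1] is only populated when the model is loaded with output_hidden_states=True.

from typing import List

import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel


def get_embeddings(text: str, token_length: int, tokenizer, model):
    # assumed: tokenize to a fixed length so every embedding has the same shape
    tokens = tokenizer(text, max_length=token_length,
                       truncation=True, padding="max_length")
    # take the last layer of hidden states (needs output_hidden_states=True)
    output = model(torch.tensor(tokens.input_ids).unsqueeze(0),
                   attention_mask=torch.tensor(
                       tokens.attention_mask
                   ).unsqueeze(0)).hidden_states[-1]
    # mean-pool token vectors into a single (1, hidden_size) embedding
    return torch.mean(output, axis=1).detach().numpy()


def nearest_doc(doc_list: List[str],
                query: str,
                tokenizer,
                model,
                token_length: int = 10):
    # the commit's short-circuit: an exact match needs no model call
    if query in doc_list:
        return query
    # embed every candidate document, then the query
    outs = [
        get_embeddings(doc, token_length, tokenizer, model) for doc in doc_list
    ]
    query_embeddings = get_embeddings(query, token_length, tokenizer, model)
    # rank candidates by cosine similarity to the query embedding
    sims = [cosine_similarity(out, query_embeddings)[0][0] for out in outs]
    return max(zip(sims, doc_list))[1]


if __name__ == "__main__":
    # placeholder checkpoint, not the Space's actual model
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased",
                                      output_hidden_states=True)
    tags = ["python", "computer-vision", "nlp"]
    print(nearest_doc(tags, "nlp", tokenizer, model))          # exact hit: model skipped
    print(nearest_doc(tags, "text mining", tokenizer, model))  # falls back to embeddings

The gain from this commit is the early "if query in doc_list" check: a tag that already exists in the training space is returned directly, skipping one embedding pass over every document plus the query, which is what the commit message means by skipping the NLP model.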