Avik Rao committed on
Commit
a81bb0d
·
1 Parent(s): cbde07b

Skip NLP model when tag is already in training space

Browse files
Files changed (1) hide show
  1. nlp/nlp.py +10 -4
nlp/nlp.py CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoTokenizer, AutoModel
9
  from sklearn.metrics.pairwise import cosine_similarity
10
  import streamlit as st
11
 
 
12
  # FUNCTIONS
13
  # create embeddings
14
  def get_embeddings(text: str, token_length: int, tokenizer, model):
@@ -17,7 +18,7 @@ def get_embeddings(text: str, token_length: int, tokenizer, model):
17
  output = model(torch.tensor(tokens.input_ids).unsqueeze(0),
18
  attention_mask=torch.tensor(
19
  tokens.attention_mask
20
- ).unsqueeze(0)).hidden_states[-1]
21
  return torch.mean(output, axis=1).detach().numpy()
22
 
23
 
@@ -26,19 +27,24 @@ def nearest_doc(doc_list: List[str],
26
  query: str,
27
  tokenizer,
28
  model,
29
- token_length: int = 50):
 
 
 
 
 
30
  # get embeddings for each document
31
  outs = [
32
  get_embeddings(doc, token_length, tokenizer, model) for doc in doc_list
33
  ]
34
  # get embeddings for query
35
- query_embeddings = get_embeddings(query, token_length=token_length)
36
  # get similarity of each document embedding to query embedding
37
  sims = [cosine_similarity(out, query_embeddings)[0][0] for out in outs]
38
  return max(zip(sims, doc_list))[1]
39
 
40
 
41
- # MAIN
42
  def get_nearest_tags(user_tags: List[str]):
43
  st.write("function called")
44
  # download pretrained model
 
9
  from sklearn.metrics.pairwise import cosine_similarity
10
  import streamlit as st
11
 
12
+
13
  # FUNCTIONS
14
  # create embeddings
15
  def get_embeddings(text: str, token_length: int, tokenizer, model):
 
18
  output = model(torch.tensor(tokens.input_ids).unsqueeze(0),
19
  attention_mask=torch.tensor(
20
  tokens.attention_mask
21
+ ).unsqueeze(0)).hidden_states[-1]
22
  return torch.mean(output, axis=1).detach().numpy()
23
 
24
 
 
27
  query: str,
28
  tokenizer,
29
  model,
30
+ token_length: int = 10):
31
+
32
+ # if query is already in doc list, return query
33
+ if query in doc_list:
34
+ return query
35
+
36
  # get embeddings for each document
37
  outs = [
38
  get_embeddings(doc, token_length, tokenizer, model) for doc in doc_list
39
  ]
40
  # get embeddings for query
41
+ query_embeddings = get_embeddings(query, token_length, tokenizer, model)
42
  # get similarity of each document embedding to query embedding
43
  sims = [cosine_similarity(out, query_embeddings)[0][0] for out in outs]
44
  return max(zip(sims, doc_list))[1]
45
 
46
 
47
+ # MAIN
48
  def get_nearest_tags(user_tags: List[str]):
49
  st.write("function called")
50
  # download pretrained model