File size: 2,154 Bytes
98dc5b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa19b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98dc5b0
aa19b06
 
98dc5b0
aa19b06
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from sense2vec import Sense2Vec
from sentence_transformers import SentenceTransformer
import wget
import os
from .mmr import mmr

# Location of the sense2vec v1.0.0 release archive fetched by S2V.__init__.
url = (
    "https://github.com/explosion/sense2vec/releases/download/"
    "v1.0.0/s2v_reddit_2015_md.tar.gz"
)
# Shell command template used to unpack that archive; "{}" receives the
# archive file name.
cmd = "tar -xvf {}"

class S2V:
  """Suggest distractor words for an answer using sense2vec and MMR.

  Candidates come from the sense2vec ``most_similar`` table, are
  de-duplicated, embedded with a sentence-transformer model and finally
  re-ranked with ``mmr`` (Maximal Marginal Relevance).
  """

  def __init__(self):
    """Load the sentence encoder and the sense2vec vectors.

    The vectors are read from the ``s2v_old`` directory (relative to the
    current working directory). The archive at ``url`` is downloaded and
    unpacked only when that directory is not present yet, so repeated
    instantiation does not fetch the archive again.
    """
    import shlex  # local import: only needed for the one-off unpack step

    self.model = SentenceTransformer('all-MiniLM-L12-v2')
    # NOTE(review): assumes the archive unpacks into ./s2v_old, which is the
    # path loaded below - confirm against the release contents.
    if not os.path.isdir('s2v_old'):
      filename = wget.download(url)
      # Quote the name so spaces or shell metacharacters in it cannot
      # break (or alter) the tar command line.
      os.system(cmd.format(shlex.quote(filename)))
    self.s2v = Sense2Vec().from_disk('s2v_old')

  def removeDuplicates(self, most_similar, originalword):
    """Return unique, human-readable candidate words.

    Args:
      most_similar: iterable of items whose first element is a sense2vec
        key of the form ``"some_word|SENSE"``.
      originalword: the answer word; it is never returned as a candidate.

    Returns:
      Words in first-seen order with the ``|SENSE`` suffix removed and
      underscores turned into spaces. Comparison is case-insensitive, so
      case variants of the answer or of an earlier candidate are dropped.
    """
    # Seed the seen-set with the answer so it is filtered like a duplicate.
    seen = {originalword.lower()}
    distractors = []
    for each_word in most_similar:
      append_word = each_word[0].split("|")[0].replace("_", " ")
      key = append_word.lower()
      if key not in seen:
        seen.add(key)
        distractors.append(append_word)
    return distractors

  def get_answer_and_distractor_embeddings(self, answer, candidate_distractors):
    """Encode the answer and the candidates with ``self.model``.

    Args:
      answer: the answer string; encoded as a one-element batch.
      candidate_distractors: list of candidate strings.

    Returns:
      Tuple ``(answer_embedding, distractor_embeddings)`` as returned by
      ``self.model.encode``.
    """
    answer_embedding = self.model.encode([answer])
    distractor_embeddings = self.model.encode(candidate_distractors)
    return answer_embedding, distractor_embeddings

  def execute(self, originalword):
    """Return distractors for ``originalword``.

    Args:
      originalword: the answer word or phrase.

    Returns:
      A list of distractor strings, or ``[]`` when no sense or candidate is
      found or when any step fails (best-effort behaviour).
    """
    try:
      # sense2vec keys are looked up lowercased, with "_" instead of spaces.
      word = originalword.lower().replace(" ", "_")
      # Best-matching sense for the word based on the available senses and
      # frequency counts.
      sense = self.s2v.get_best_sense(word)
      if sense is None:
        # No sense available for this word: nothing to look up.
        return []
      # Get the most similar entries in the table.
      most_similar = self.s2v.most_similar(sense, n=20)
      distractors = self.removeDuplicates(most_similar, originalword)
      if not distractors:
        return []
      # The answer goes first so that it is ranked together with the
      # candidates.
      candidates = [originalword] + distractors
      answer_embedd, distractor_embedds = self.get_answer_and_distractor_embeddings(
          originalword, candidates)
      # Maximal Marginal Relevance origin:
      # https://maartengr.github.io/KeyBERT/api/mmr.html
      final_distractors = mmr(answer_embedd, distractor_embedds, candidates, 5)
      # Each result is indexed with [0] to get the word - presumably
      # (word, score) pairs; confirm against .mmr.
      ranked_words = [dist[0] for dist in final_distractors]
      # The first entry is treated as the answer itself and dropped.
      return ranked_words[1:]
    except Exception:
      # Deliberate best-effort: any failure yields no distractors, but
      # KeyboardInterrupt/SystemExit are no longer swallowed.
      return []