Spaces:
Running
Running
| """This module generates false answers within same context. | |
| @Author: Karthick T. Sharma | |
| """ | |
| import os | |
| import random | |
| import urllib.request | |
| import tarfile | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from sentence_transformers import SentenceTransformer | |
| from sense2vec import Sense2Vec | |
| from src.utils.text_process import change_format | |
| import tempfile | |
class FalseAnswerGenerator:
    """Generate false answers (distractors) within the same context.

    Implemented as a singleton: the underlying models (a sentence
    transformer and sense2vec vectors) are expensive to load, so one
    shared instance is created lazily and reused for every call site.
    """

    # Cached singleton instance; populated on first instantiation.
    _instance = None

    def __new__(cls):
        """Return the shared instance, loading models exactly once."""
        if cls._instance is None:
            cls._instance = super(FalseAnswerGenerator, cls).__new__(cls)
            # Heavy model initialization happens only on the first call.
            cls._instance._init_models()
        return cls._instance
    def _init_models(self):
        """Load all heavy model dependencies (called once by __new__).

        Order matters only in that both loads have side effects
        (downloads / disk reads); neither depends on the other.
        """
        self.__init_sentence_transformer()
        self.__init_sense2vec()
| def __init_sentence_transformer(self): | |
| """Initialize sentence embedding. | |
| https://www.sbert.net/ | |
| """ | |
| self._sentence_model = SentenceTransformer('all-MiniLM-L12-v2') | |
| def __init_sense2vec(self): | |
| """Initialize word vectors to get similar words. | |
| https://github.com/explosion/sense2vec | |
| """ | |
| if not os.path.isdir(os.getcwd() + '/s2v_old'): | |
| s2v_url = "https://github.com/explosion/sense2vec/releases/download/" | |
| s2v_ver_url = s2v_url + "v1.0.0/s2v_reddit_2015_md.tar.gz" | |
| with urllib.request.urlopen(s2v_ver_url) as req: | |
| # save downloaded to a temp file first | |
| with tempfile.NamedTemporaryFile(delete=False) as temp_file: | |
| temp_file.write(req.read()) | |
| temp_file_path = temp_file.name | |
| with tarfile.open(temp_file_path, mode='r:gz') as file: | |
| def is_within_directory(directory, target): | |
| abs_directory = os.path.abspath(directory) | |
| abs_target = os.path.abspath(target) | |
| prefix = os.path.commonprefix([abs_directory, abs_target]) | |
| return prefix == abs_directory | |
| def safe_extract(tar, path=".", members=None, *, numeric_owner=False): | |
| for member in tar.getmembers(): | |
| member_path = os.path.join(path, member.name) | |
| if not is_within_directory(path, member_path): | |
| raise Exception("Attempted Path Traversal in Tar File") | |
| tar.extractall(path, members, numeric_owner=numeric_owner) | |
| safe_extract(file) | |
| self._s2v = Sense2Vec().from_disk("s2v_old") | |
| def __get_embedding(self, answer, distractors): | |
| """Returns sentence model embedding of answer and distractors. | |
| Args: | |
| answer (str): correct answer. | |
| distractors (list[str]): false answers. | |
| Returns: | |
| tuple[list[str], list[str]]: sentence model embedding of answer and distractors. | |
| """ | |
| return self._sentence_model.encode([answer]), self._sentence_model.encode(distractors) | |
| def get_embedding_list_word(self, word_list: list[str]): | |
| """ | |
| Returns sentence model embedding of answer and distractors. | |
| """ | |
| return self._sentence_model.encode([word_list]) | |
| def filter_output(self, orig, dummies): | |
| """Filter out final answers. | |
| Args: | |
| orig (str): correct answer. | |
| dummies (list[str]): false answers list generated from correct answer. | |
| Returns: | |
| list[str]: list of final answer which has low similarity. | |
| """ | |
| ans_embedded, dis_embedded = self.__get_embedding(orig, dummies) | |
| # filter using MMMR | |
| dist = self.__mmr(ans_embedded, dis_embedded, dummies) | |
| filtered_dist = [] | |
| for dis in dist: | |
| # 0 -> word, 1 -> confidence / probability | |
| filtered_dist.append(dis[0].capitalize()) | |
| return filtered_dist | |
| def __mmr(self, doc_embedding, word_embedding, words, diversity=0.9): | |
| """Word diversity using MMR - Maximal Marginal Relevance. | |
| Args: | |
| doc_embedding (list[str]): sentence embedding of correct answer. | |
| word_embedding (list[str]): sentence embedding of false answer. | |
| words (list[str]): false answers. | |
| diversity (float, optional): diversity coefficient. Defaults to 0.9. | |
| Returns: | |
| list[str]: list of final answers. | |
| """ | |
| # extract similarity between words and docs | |
| word_doc_similarity = cosine_similarity(word_embedding, doc_embedding) | |
| word_similarity = cosine_similarity(word_embedding) | |
| kw_idx = [np.argmax(word_doc_similarity)] # NumPy 2.0.2 vẫn hỗ trợ np.argmax() | |
| dist_idx = [i for i in range(len(words)) if i != kw_idx[0]] | |
| for _ in range(3): | |
| dist_similarities = word_doc_similarity[dist_idx, :] | |
| target_similarities = np.max( | |
| word_similarity[dist_idx][:, kw_idx], axis=1 | |
| ) | |
| # calculate MMR | |
| mmr = (1 - diversity) * dist_similarities - \ | |
| diversity * target_similarities.reshape(-1, 1) | |
| mmr_idx = dist_idx[np.argmax(mmr)] # NumPy vẫn hỗ trợ np.argmax() | |
| # update kw | |
| kw_idx.append(mmr_idx) | |
| dist_idx.remove(mmr_idx) | |
| return [(words[idx], round(float(word_doc_similarity.reshape(1, -1)[0][idx]), 4)) | |
| for idx in kw_idx] | |
| def __generate_answer(self, query): | |
| """Generate false answers from correct answer. | |
| Args: | |
| query (str): correct answer. | |
| Returns: | |
| list(str): list of final answers if input is valid, else None. | |
| """ | |
| # get the best sense for given word (like NOUN, PRONOUN, VERB...) | |
| query_al = self._s2v.get_best_sense(query.lower().replace(' ', '_')) | |
| if query_al is None: | |
| return None | |
| try: | |
| assert query_al in self._s2v | |
| # get most similar 20 words (if any) | |
| temp = self._s2v.most_similar(query_al, n=20) | |
| formatted_string = change_format(temp) | |
| formatted_string.insert(0, query) | |
| # if answers are numbers then we don't need to filter | |
| if query_al.split('|')[1] == 'CARDINAL': | |
| return formatted_string[:4] | |
| # else filter because sometimes similar words will be US, U.S, USA, AMERICA... | |
| return self.filter_output(query, formatted_string) | |
| except AssertionError: | |
| return None | |
| def get_output(self, filtered_kws): | |
| """Generate false answers for whole context. | |
| Filter out keywords that don't generate 3 false answers. | |
| Args: | |
| filtered_kws (list(str)): list of keywords | |
| Returns: | |
| tuple(list(str), list(list(str))): tuple of correct answers and list of all answers. | |
| """ | |
| crct_ans = [] | |
| all_answers = [] | |
| for kws in filtered_kws: | |
| for kwx in kws: | |
| results = self.__generate_answer(kwx) | |
| if results is not None: | |
| crct_ans.append(kwx.capitalize()) | |
| random.shuffle(results) | |
| all_answers.append(results) | |
| return crct_ans, sum(all_answers, []) | |
| def generate_distractors_from_synonyms( | |
| self, | |
| correct_words: list[str], | |
| num_distractors: int = 3, | |
| sim_min: float = 0.35, | |
| sim_max: float = 0.75 | |
| ): | |
| """ | |
| Generate distractors for synonym questions. | |
| Input: 2 correct synonymous words | |
| Output: distractors semantically related but NOT synonyms | |
| """ | |
| assert len(correct_words) == 2, "Require exactly 2 correct synonyms" | |
| w1, w2 = [w.lower().strip() for w in correct_words] | |
| candidates = set() | |
| # -------- 1. Collect candidates from sense2vec ---------- | |
| for w in [w1, w2]: | |
| sense = self._s2v.get_best_sense(w.replace(" ", "_")) | |
| if sense and sense in self._s2v: | |
| sims = self._s2v.most_similar(sense, n=30) | |
| formatted = change_format(sims) | |
| candidates.update(formatted) | |
| # Remove originals | |
| candidates = { | |
| c for c in candidates | |
| if c.lower() not in {w1, w2} | |
| } | |
| if not candidates: | |
| return [] | |
| candidates = list(candidates) | |
| # -------- 2. Sentence embedding ---------- | |
| emb_correct = self._sentence_model.encode(correct_words) | |
| emb_candidates = self._sentence_model.encode(candidates) | |
| # similarity to each correct word | |
| sim_1 = cosine_similarity(emb_candidates, emb_correct[0].reshape(1, -1)) | |
| sim_2 = cosine_similarity(emb_candidates, emb_correct[1].reshape(1, -1)) | |
| final_candidates = [] | |
| for idx, word in enumerate(candidates): | |
| s1 = sim_1[idx][0] | |
| s2 = sim_2[idx][0] | |
| # loại bỏ các từ quá giống | |
| if max(s1, s2) > sim_max: | |
| continue | |
| # loại bỏ các từ quá khác | |
| if max(s1, s2) < sim_min: | |
| continue | |
| final_candidates.append((word, max(s1, s2))) | |
| chosen = random.sample( | |
| final_candidates, | |
| k=min(num_distractors, len(final_candidates)) | |
| ) | |
| return [w.capitalize() for w, _ in chosen] | |
| def generate_distractors_from_antonyms( | |
| self, | |
| correct_words: list[str], | |
| num_distractors: int = 3, | |
| sim_min: float = 0.25, | |
| sim_max: float = 0.8, | |
| balance_threshold: float = 0.2 | |
| ): | |
| """ | |
| Generate distractors for antonym questions. | |
| Input: 2 opposite words | |
| Output: neutral / intermediate distractors | |
| """ | |
| assert len(correct_words) == 2, "Require exactly 2 antonyms" | |
| w1, w2 = [w.lower().strip() for w in correct_words] | |
| candidates = set() | |
| # -------- 1. Collect candidates from both antonyms ---------- | |
| for w in [w1, w2]: | |
| sense = self._s2v.get_best_sense(w.replace(" ", "_")) | |
| if sense and sense in self._s2v: | |
| sims = self._s2v.most_similar(sense, n=40) | |
| candidates.update(change_format(sims)) | |
| # Remove originals | |
| candidates = { | |
| c for c in candidates | |
| if c.lower() not in {w1, w2} | |
| } | |
| if not candidates: | |
| return [] | |
| candidates = list(candidates) | |
| # -------- 2. Sentence embedding ---------- | |
| emb_correct = self._sentence_model.encode(correct_words) | |
| emb_candidates = self._sentence_model.encode(candidates) | |
| sim_1 = cosine_similarity(emb_candidates, emb_correct[0].reshape(1, -1)) | |
| sim_2 = cosine_similarity(emb_candidates, emb_correct[1].reshape(1, -1)) | |
| final_candidates = [] | |
| for idx, word in enumerate(candidates): | |
| s1 = sim_1[idx][0] | |
| s2 = sim_2[idx][0] | |
| # quá gần một cực → loại | |
| if max(s1, s2) > sim_max: | |
| continue | |
| # quá xa cả hai → loại | |
| if max(s1, s2) < sim_min: | |
| continue | |
| # không cân bằng → nghiêng hẳn về 1 phía | |
| if abs(s1 - s2) > balance_threshold: | |
| continue | |
| final_candidates.append( | |
| (word, (s1 + s2) / 2) | |
| ) | |
| if not final_candidates: | |
| return [] | |
| chosen = random.sample( | |
| final_candidates, | |
| k=min(num_distractors, len(final_candidates)) | |
| ) | |
| return [w.capitalize() for w, _ in chosen] | |