# Source: gen-question/src/services/AI/false_ans_generator.py
# Provenance: linhnguyen02 — "set up to deploy in hugging face" (commit 42cffde)
"""This module generates false answers within same context.
@Author: Karthick T. Sharma
"""
import os
import random
import urllib.request
import tarfile
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from sense2vec import Sense2Vec
from src.utils.text_process import change_format
import tempfile
class FalseAnswerGenerator:
    """Generate false answers (distractors) within the same context.

    Implemented as a singleton: the underlying embedding models are
    expensive to load, so every call to ``FalseAnswerGenerator()`` returns
    the same fully-initialized instance.
    """

    # Shared singleton instance, created lazily in __new__.
    _instance = None

    def __new__(cls):
        """Return the shared instance, loading models on first use."""
        if cls._instance is None:
            cls._instance = super(FalseAnswerGenerator, cls).__new__(cls)
            cls._instance._init_models()
        return cls._instance

    def _init_models(self):
        """Load every model needed for distractor generation."""
        self.__init_sentence_transformer()
        self.__init_sense2vec()

    def __init_sentence_transformer(self):
        """Initialize the sentence embedding model.

        https://www.sbert.net/
        """
        self._sentence_model = SentenceTransformer('all-MiniLM-L12-v2')

    def __init_sense2vec(self):
        """Initialize word vectors used to look up similar words.

        Downloads and extracts the ``s2v_reddit_2015_md`` archive into the
        current working directory on first run.
        https://github.com/explosion/sense2vec
        """
        if not os.path.isdir(os.path.join(os.getcwd(), 's2v_old')):
            s2v_url = "https://github.com/explosion/sense2vec/releases/download/"
            s2v_ver_url = s2v_url + "v1.0.0/s2v_reddit_2015_md.tar.gz"
            temp_file_path = None
            try:
                # Save the download to a temp file first so extraction
                # always runs from a complete archive on disk.
                with urllib.request.urlopen(s2v_ver_url) as req:
                    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                        temp_file.write(req.read())
                        temp_file_path = temp_file.name

                with tarfile.open(temp_file_path, mode='r:gz') as file:
                    def is_within_directory(directory, target):
                        # True only if `target` resolves inside `directory`.
                        abs_directory = os.path.abspath(directory)
                        abs_target = os.path.abspath(target)
                        prefix = os.path.commonprefix([abs_directory, abs_target])
                        return prefix == abs_directory

                    def safe_extract(tar, path=".", members=None, *,
                                     numeric_owner=False):
                        # extractall() guarded against tar path traversal
                        # (CVE-2007-4559): reject members escaping `path`.
                        for member in tar.getmembers():
                            member_path = os.path.join(path, member.name)
                            if not is_within_directory(path, member_path):
                                raise Exception("Attempted Path Traversal in Tar File")
                        tar.extractall(path, members, numeric_owner=numeric_owner)

                    safe_extract(file)
            finally:
                # Bug fix: the downloaded archive was previously never
                # removed, leaking a large file in the temp directory on
                # every cold start.
                if temp_file_path is not None and os.path.exists(temp_file_path):
                    os.remove(temp_file_path)
        self._s2v = Sense2Vec().from_disk("s2v_old")

    def __get_embedding(self, answer, distractors):
        """Return sentence-model embeddings of answer and distractors.

        Args:
            answer (str): correct answer.
            distractors (list[str]): false answers.

        Returns:
            tuple: embeddings of ``[answer]`` and of ``distractors``.
        """
        return (self._sentence_model.encode([answer]),
                self._sentence_model.encode(distractors))

    def get_embedding_list_word(self, word_list: list[str]):
        """Return the sentence-model embedding for ``word_list``.

        NOTE(review): the list is wrapped in another list before encoding
        (``encode([word_list])``), so the model receives one item built
        from the whole list rather than one embedding per word. Kept as-is
        because callers may depend on the current output shape — confirm
        whether ``encode(word_list)`` was intended.
        """
        return self._sentence_model.encode([word_list])

    def filter_output(self, orig, dummies):
        """Filter out final answers.

        Args:
            orig (str): correct answer.
            dummies (list[str]): false answers generated from the answer.

        Returns:
            list[str]: capitalized answers with low mutual similarity.
        """
        ans_embedded, dis_embedded = self.__get_embedding(orig, dummies)
        # Filter with MMR so the chosen distractors are diverse.
        dist = self.__mmr(ans_embedded, dis_embedded, dummies)
        # Each entry is (word, similarity score); keep only the word.
        return [word.capitalize() for word, _score in dist]

    def __mmr(self, doc_embedding, word_embedding, words, diversity=0.9):
        """Select diverse words using MMR - Maximal Marginal Relevance.

        Args:
            doc_embedding: sentence embedding of the correct answer.
            word_embedding: sentence embeddings of the candidate words.
            words (list[str]): candidate false answers.
            diversity (float, optional): diversity coefficient. Defaults to 0.9.

        Returns:
            list[tuple[str, float]]: (word, similarity-to-answer) pairs for
            the seed keyword plus up to three diverse picks.
        """
        # Similarity of every candidate to the answer, and to each other.
        word_doc_similarity = cosine_similarity(word_embedding, doc_embedding)
        word_similarity = cosine_similarity(word_embedding)

        # Seed with the candidate most similar to the answer.
        kw_idx = [np.argmax(word_doc_similarity)]
        dist_idx = [i for i in range(len(words)) if i != kw_idx[0]]

        for _ in range(3):
            if not dist_idx:
                # Robustness fix: fewer than 4 candidates — stop early
                # instead of calling argmax on an empty selection.
                break
            dist_similarities = word_doc_similarity[dist_idx, :]
            target_similarities = np.max(
                word_similarity[dist_idx][:, kw_idx], axis=1
            )
            # MMR trades relevance to the answer against redundancy with
            # the words already selected.
            mmr = (1 - diversity) * dist_similarities - \
                diversity * target_similarities.reshape(-1, 1)
            mmr_idx = dist_idx[np.argmax(mmr)]
            # Promote the best candidate and remove it from the pool.
            kw_idx.append(mmr_idx)
            dist_idx.remove(mmr_idx)

        return [(words[idx], round(float(word_doc_similarity.reshape(1, -1)[0][idx]), 4))
                for idx in kw_idx]

    def __generate_answer(self, query):
        """Generate false answers from the correct answer.

        Args:
            query (str): correct answer.

        Returns:
            list[str] | None: list of answers if the word is known to
            sense2vec, else None.
        """
        # Get the best sense for the given word (NOUN, VERB, CARDINAL...).
        query_al = self._s2v.get_best_sense(query.lower().replace(' ', '_'))
        # Explicit membership check instead of assert/except AssertionError:
        # asserts are stripped under `python -O`; result (None) is the same.
        if query_al is None or query_al not in self._s2v:
            return None
        # Get the 20 most similar words (if any).
        temp = self._s2v.most_similar(query_al, n=20)
        formatted_string = change_format(temp)
        formatted_string.insert(0, query)
        # If the answers are numbers we don't need to filter.
        if query_al.split('|')[1] == 'CARDINAL':
            return formatted_string[:4]
        # Otherwise filter, because similar words may be near-duplicates
        # such as US, U.S, USA, AMERICA...
        return self.filter_output(query, formatted_string)

    def get_output(self, filtered_kws):
        """Generate false answers for the whole context.

        Keywords that fail to generate distractors are dropped.

        Args:
            filtered_kws (list[list[str]]): keyword lists per context.

        Returns:
            tuple[list[str], list[str]]: correct answers and a flat list of
            all (shuffled) answers.
        """
        crct_ans = []
        all_answers = []
        for kws in filtered_kws:
            for kwx in kws:
                results = self.__generate_answer(kwx)
                if results is not None:
                    crct_ans.append(kwx.capitalize())
                    # Shuffle so the correct answer is not always first.
                    random.shuffle(results)
                    all_answers.append(results)
        # Flatten the per-keyword lists (sum(lists, []) was quadratic).
        return crct_ans, [ans for group in all_answers for ans in group]

    def generate_distractors_from_synonyms(
        self,
        correct_words: list[str],
        num_distractors: int = 3,
        sim_min: float = 0.35,
        sim_max: float = 0.75
    ):
        """Generate distractors for synonym questions.

        Args:
            correct_words: exactly two correct synonymous words.
            num_distractors: maximum number of distractors to return.
            sim_min: candidates less similar than this to both words are
                dropped (too unrelated).
            sim_max: candidates more similar than this to either word are
                dropped (too close to being synonyms themselves).

        Returns:
            list[str]: capitalized distractors, semantically related but
            not synonyms; empty when nothing qualifies.
        """
        # NOTE: assert is stripped under `python -O`; kept to preserve the
        # AssertionError contract existing callers may rely on.
        assert len(correct_words) == 2, "Require exactly 2 correct synonyms"
        w1, w2 = [w.lower().strip() for w in correct_words]

        # -------- 1. Collect candidates from sense2vec ----------
        candidates = set()
        for w in (w1, w2):
            sense = self._s2v.get_best_sense(w.replace(" ", "_"))
            if sense and sense in self._s2v:
                sims = self._s2v.most_similar(sense, n=30)
                candidates.update(change_format(sims))

        # Remove the original words themselves.
        candidates = {c for c in candidates if c.lower() not in {w1, w2}}
        if not candidates:
            return []
        candidates = list(candidates)

        # -------- 2. Sentence embedding ----------
        emb_correct = self._sentence_model.encode(correct_words)
        emb_candidates = self._sentence_model.encode(candidates)
        # Similarity of every candidate to each correct word.
        sim_1 = cosine_similarity(emb_candidates, emb_correct[0].reshape(1, -1))
        sim_2 = cosine_similarity(emb_candidates, emb_correct[1].reshape(1, -1))

        final_candidates = []
        for idx, word in enumerate(candidates):
            s1 = sim_1[idx][0]
            s2 = sim_2[idx][0]
            # Drop words that are too similar (near-synonyms).
            if max(s1, s2) > sim_max:
                continue
            # Drop words that are too different (unrelated).
            if max(s1, s2) < sim_min:
                continue
            final_candidates.append((word, max(s1, s2)))

        if not final_candidates:
            return []
        chosen = random.sample(
            final_candidates,
            k=min(num_distractors, len(final_candidates))
        )
        return [w.capitalize() for w, _ in chosen]

    def generate_distractors_from_antonyms(
        self,
        correct_words: list[str],
        num_distractors: int = 3,
        sim_min: float = 0.25,
        sim_max: float = 0.8,
        balance_threshold: float = 0.2
    ):
        """Generate distractors for antonym questions.

        Args:
            correct_words: exactly two opposite words.
            num_distractors: maximum number of distractors to return.
            sim_min: candidates less similar than this to both poles are
                dropped (too unrelated).
            sim_max: candidates more similar than this to either pole are
                dropped (too close to one extreme).
            balance_threshold: candidates whose similarity to the two poles
                differs by more than this are dropped (leaning to one side).

        Returns:
            list[str]: capitalized neutral / intermediate distractors;
            empty when nothing qualifies.
        """
        # NOTE: assert is stripped under `python -O`; kept to preserve the
        # AssertionError contract existing callers may rely on.
        assert len(correct_words) == 2, "Require exactly 2 antonyms"
        w1, w2 = [w.lower().strip() for w in correct_words]

        # -------- 1. Collect candidates from both antonyms ----------
        candidates = set()
        for w in (w1, w2):
            sense = self._s2v.get_best_sense(w.replace(" ", "_"))
            if sense and sense in self._s2v:
                sims = self._s2v.most_similar(sense, n=40)
                candidates.update(change_format(sims))

        # Remove the original words themselves.
        candidates = {c for c in candidates if c.lower() not in {w1, w2}}
        if not candidates:
            return []
        candidates = list(candidates)

        # -------- 2. Sentence embedding ----------
        emb_correct = self._sentence_model.encode(correct_words)
        emb_candidates = self._sentence_model.encode(candidates)
        sim_1 = cosine_similarity(emb_candidates, emb_correct[0].reshape(1, -1))
        sim_2 = cosine_similarity(emb_candidates, emb_correct[1].reshape(1, -1))

        final_candidates = []
        for idx, word in enumerate(candidates):
            s1 = sim_1[idx][0]
            s2 = sim_2[idx][0]
            # Too close to one pole -> drop.
            if max(s1, s2) > sim_max:
                continue
            # Too far from both poles -> drop.
            if max(s1, s2) < sim_min:
                continue
            # Unbalanced -> leans toward one side, drop.
            if abs(s1 - s2) > balance_threshold:
                continue
            final_candidates.append((word, (s1 + s2) / 2))

        if not final_candidates:
            return []
        chosen = random.sample(
            final_candidates,
            k=min(num_distractors, len(final_candidates))
        )
        return [w.capitalize() for w, _ in chosen]