Spaces:
Sleeping
Sleeping
linhnguyen02
commited on
Commit
·
8e72e5b
1
Parent(s):
48a5849
sen2vec false ans
Browse files
src/factories/gen_question/types/antonym_question.py
CHANGED
|
@@ -5,11 +5,17 @@ from src.factories.gen_question.types.base import Question, nltk_words
|
|
| 5 |
from src.enums import QuestionTypeEnum
|
| 6 |
|
| 7 |
from src.loaders.elastic import Elastic
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
|
| 11 |
class AntonymsQuestion(Question):
|
| 12 |
INDEX = "vocabulary"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
def generate_questions(self, list_words: List[str] = None, num_question: int = 1,
|
| 15 |
num_ans_per_question: int = 4, cefr: int = 3):
|
|
@@ -22,48 +28,55 @@ class AntonymsQuestion(Question):
|
|
| 22 |
used_choices = set()
|
| 23 |
|
| 24 |
for _ in range(num_question):
|
| 25 |
-
|
| 26 |
-
question_word, correct_answer, antonym_set = \
|
| 27 |
self._pick_question_word(list_unique_words, used_words, cefr)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
}
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
choices = [correct_answer]
|
| 40 |
|
| 41 |
-
while len(choices) < num_ans_per_question and max_loop > 0:
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
candidate in used_words or
|
| 52 |
-
candidate in antonym_set or
|
| 53 |
-
candidate == question_word or
|
| 54 |
-
candidate == correct_answer
|
| 55 |
-
):
|
| 56 |
-
continue
|
| 57 |
|
| 58 |
-
# Loại distractor có nghĩa trùng với đáp án
|
| 59 |
-
syns = set(self.get_list_antonym(candidate))
|
| 60 |
-
if correct_answer in syns:
|
| 61 |
-
continue
|
| 62 |
|
| 63 |
-
choices.append(candidate)
|
| 64 |
-
used_choices.add(candidate)
|
| 65 |
-
max_loop -= 1
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
random.shuffle(choices)
|
| 68 |
|
| 69 |
final_choices = []
|
|
@@ -122,7 +135,7 @@ class AntonymsQuestion(Question):
|
|
| 122 |
continue
|
| 123 |
|
| 124 |
correct = random.choice(valid_syns)
|
| 125 |
-
return source, correct
|
| 126 |
|
| 127 |
# FALLBACK ES
|
| 128 |
while True:
|
|
@@ -140,4 +153,4 @@ class AntonymsQuestion(Question):
|
|
| 140 |
if not valid_syns:
|
| 141 |
continue
|
| 142 |
|
| 143 |
-
return source, random.choice(valid_syns)
|
|
|
|
| 5 |
from src.enums import QuestionTypeEnum
|
| 6 |
|
| 7 |
from src.loaders.elastic import Elastic
|
| 8 |
+
from src.services.AI.false_ans_generator import FalseAnswerGenerator
|
| 9 |
|
| 10 |
|
| 11 |
|
| 12 |
class AntonymsQuestion(Question):
|
| 13 |
INDEX = "vocabulary"
|
| 14 |
+
false_ans_gen: FalseAnswerGenerator = None
|
| 15 |
+
|
| 16 |
+
def __init__(self):
    """Initialise the question type and make sure a distractor generator exists.

    The generator is stored on the class (not on the instance) so that it is
    built once and then reused by every later instance; the class-level
    ``false_ans_gen = None`` default is the "not built yet" marker.
    """
    # Keep whatever initialisation the base class performs.
    super().__init__()

    # Assigning through ``self`` would only create an instance attribute and
    # leave the class attribute at None, forcing a rebuild on every instance.
    cls = type(self)
    if cls.false_ans_gen is None:
        cls.false_ans_gen = FalseAnswerGenerator()
|
| 19 |
|
| 20 |
def generate_questions(self, list_words: List[str] = None, num_question: int = 1,
|
| 21 |
num_ans_per_question: int = 4, cefr: int = 3):
|
|
|
|
| 28 |
used_choices = set()
|
| 29 |
|
| 30 |
for _ in range(num_question):
|
| 31 |
+
question_word, correct_answer = \
|
|
|
|
| 32 |
self._pick_question_word(list_unique_words, used_words, cefr)
|
| 33 |
+
|
| 34 |
+
# max_loop = 100
|
| 35 |
+
# pos = self.get_pos({
|
| 36 |
+
# "bool": {
|
| 37 |
+
# "must": [
|
| 38 |
+
# {"term": {"word.keyword": question_word.lower()}},
|
| 39 |
+
# {"term": {"antonyms.keyword": correct_answer.lower()}}
|
| 40 |
+
# ]
|
| 41 |
+
# }
|
| 42 |
+
# })
|
| 43 |
+
# used_words.update([question_word, correct_answer])
|
| 44 |
|
| 45 |
choices = [correct_answer]
|
| 46 |
|
| 47 |
+
# while len(choices) < num_ans_per_question and max_loop > 0:
|
| 48 |
+
# doc = self.get_random(self.INDEX, None, cefr=cefr, pos=pos)
|
| 49 |
+
# if not doc:
|
| 50 |
+
# continue
|
| 51 |
+
|
| 52 |
+
# candidate = doc["word"]
|
| 53 |
+
|
| 54 |
+
# # Loại trừ điều kiện chung
|
| 55 |
+
# if (
|
| 56 |
+
# candidate in used_choices or
|
| 57 |
+
# candidate in used_words or
|
| 58 |
+
# candidate in antonym_set or
|
| 59 |
+
# candidate == question_word or
|
| 60 |
+
# candidate == correct_answer
|
| 61 |
+
# ):
|
| 62 |
+
# continue
|
| 63 |
|
| 64 |
+
# # Loại distractor có nghĩa trùng với đáp án
|
| 65 |
+
# syns = set(self.get_list_antonym(candidate))
|
| 66 |
+
# if correct_answer in syns:
|
| 67 |
+
# continue
|
| 68 |
|
| 69 |
+
# choices.append(candidate)
|
| 70 |
+
# used_choices.add(candidate)
|
| 71 |
+
# max_loop -= 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
+
distractors = self.false_ans_gen.generate_distractors_from_antonyms(
|
| 76 |
+
target_word=[correct_answer, question_word],
|
| 77 |
+
num_false_answers=num_ans_per_question - 1
|
| 78 |
+
)
|
| 79 |
+
choices.extend(distractors)
|
| 80 |
random.shuffle(choices)
|
| 81 |
|
| 82 |
final_choices = []
|
|
|
|
| 135 |
continue
|
| 136 |
|
| 137 |
correct = random.choice(valid_syns)
|
| 138 |
+
return source, correct
|
| 139 |
|
| 140 |
# FALLBACK ES
|
| 141 |
while True:
|
|
|
|
| 153 |
if not valid_syns:
|
| 154 |
continue
|
| 155 |
|
| 156 |
+
return source, random.choice(valid_syns)
|
src/factories/gen_question/types/synonym_question.py
CHANGED
|
@@ -5,9 +5,16 @@ from src.factories.gen_question.types.base import Question, nltk_words
|
|
| 5 |
from src.enums import QuestionTypeEnum
|
| 6 |
from src.loaders.elastic import Elastic
|
| 7 |
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class SynonymsQuestion(Question):
|
| 10 |
INDEX = "vocabulary"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def generate_questions(self, list_words: List[str] = None, num_question: int = 1,
|
| 13 |
num_ans_per_question: int = 4, cefr: int = 3):
|
|
@@ -20,48 +27,53 @@ class SynonymsQuestion(Question):
|
|
| 20 |
used_choices = set()
|
| 21 |
|
| 22 |
for _ in range(num_question):
|
| 23 |
-
|
| 24 |
-
question_word, correct_answer, synonym_set = \
|
| 25 |
self._pick_question_word(list_unique_words, used_words, cefr)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
}
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
choices = [correct_answer]
|
| 38 |
|
| 39 |
-
while len(choices) < num_ans_per_question and max_loop > 0:
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
random.shuffle(choices)
|
| 66 |
|
| 67 |
final_choices = []
|
|
@@ -120,7 +132,8 @@ class SynonymsQuestion(Question):
|
|
| 120 |
continue
|
| 121 |
|
| 122 |
correct = random.choice(valid_syns)
|
| 123 |
-
return source, correct
|
|
|
|
| 124 |
|
| 125 |
# FALLBACK ES
|
| 126 |
while True:
|
|
@@ -138,4 +151,5 @@ class SynonymsQuestion(Question):
|
|
| 138 |
if not valid_syns:
|
| 139 |
continue
|
| 140 |
|
| 141 |
-
return source, random.choice(valid_syns)
|
|
|
|
|
|
| 5 |
from src.enums import QuestionTypeEnum
|
| 6 |
from src.loaders.elastic import Elastic
|
| 7 |
|
| 8 |
+
from src.services.AI.false_ans_generator import FalseAnswerGenerator
|
| 9 |
+
|
| 10 |
|
| 11 |
class SynonymsQuestion(Question):
|
| 12 |
INDEX = "vocabulary"
|
| 13 |
+
false_ans_gen: FalseAnswerGenerator = None
|
| 14 |
+
|
| 15 |
+
def __init__(self):
    """Initialise the question type and make sure a distractor generator exists.

    The generator is cached on the class rather than on the instance, so the
    first instance builds it and all later instances reuse it; the
    class-level ``false_ans_gen = None`` default marks "not built yet".
    """
    # Keep whatever initialisation the base class performs.
    super().__init__()

    # Writing to ``self.false_ans_gen`` would shadow the class attribute on
    # this instance only, so the None check would succeed again next time.
    cls = type(self)
    if cls.false_ans_gen is None:
        cls.false_ans_gen = FalseAnswerGenerator()
|
| 18 |
|
| 19 |
def generate_questions(self, list_words: List[str] = None, num_question: int = 1,
|
| 20 |
num_ans_per_question: int = 4, cefr: int = 3):
|
|
|
|
| 27 |
used_choices = set()
|
| 28 |
|
| 29 |
for _ in range(num_question):
|
| 30 |
+
question_word, correct_answer = \
|
|
|
|
| 31 |
self._pick_question_word(list_unique_words, used_words, cefr)
|
| 32 |
|
| 33 |
+
# max_loop = 100
|
| 34 |
+
# pos = self.get_pos({
|
| 35 |
+
# "bool": {
|
| 36 |
+
# "must": [
|
| 37 |
+
# {"term": {"word.keyword": question_word.lower()}},
|
| 38 |
+
# {"term": {"synonyms.keyword": correct_answer.lower()}}
|
| 39 |
+
# ]
|
| 40 |
+
# }
|
| 41 |
+
# })
|
| 42 |
+
# used_words.update([question_word, correct_answer])
|
| 43 |
|
| 44 |
choices = [correct_answer]
|
| 45 |
|
| 46 |
+
# while len(choices) < num_ans_per_question and max_loop > 0:
|
| 47 |
+
# doc = self.get_random(self.INDEX, None, cefr=cefr, pos=pos)
|
| 48 |
+
# if not doc:
|
| 49 |
+
# continue
|
| 50 |
+
|
| 51 |
+
# candidate = doc["word"]
|
| 52 |
+
|
| 53 |
+
# # Loại trừ điều kiện chung
|
| 54 |
+
# if (
|
| 55 |
+
# candidate in used_choices or
|
| 56 |
+
# candidate in used_words or
|
| 57 |
+
# candidate in synonym_set or
|
| 58 |
+
# candidate == question_word or
|
| 59 |
+
# candidate == correct_answer
|
| 60 |
+
# ):
|
| 61 |
+
# continue
|
| 62 |
+
|
| 63 |
+
# # Loại distractor có nghĩa trùng với đáp án
|
| 64 |
+
# syns = set(self.get_list_synonym(candidate))
|
| 65 |
+
# if correct_answer in syns:
|
| 66 |
+
# continue
|
| 67 |
+
|
| 68 |
+
# choices.append(candidate)
|
| 69 |
+
# used_choices.add(candidate)
|
| 70 |
+
# max_loop -= 1
|
| 71 |
+
|
| 72 |
+
distractors = self.false_ans_gen.generate_distractors_from_synonyms(
|
| 73 |
+
target_word=[correct_answer, question_word],
|
| 74 |
+
num_false_answers=num_ans_per_question - 1
|
| 75 |
+
)
|
| 76 |
+
choices.extend(distractors)
|
| 77 |
random.shuffle(choices)
|
| 78 |
|
| 79 |
final_choices = []
|
|
|
|
| 132 |
continue
|
| 133 |
|
| 134 |
correct = random.choice(valid_syns)
|
| 135 |
+
return source, correct
|
| 136 |
+
# return source, correct, set(syns)
|
| 137 |
|
| 138 |
# FALLBACK ES
|
| 139 |
while True:
|
|
|
|
| 151 |
if not valid_syns:
|
| 152 |
continue
|
| 153 |
|
| 154 |
+
return source, random.choice(valid_syns)
|
| 155 |
+
# return source, random.choice(valid_syns), set(syns)
|
src/factories/gen_question_for_paragraph/types/synthetic.py
CHANGED
|
@@ -26,14 +26,13 @@ class ParagraphQuestion(Question):
|
|
| 26 |
num = question_data.num_question
|
| 27 |
type_to_total_count[qtype] = type_to_total_count.get(qtype, 0) + num
|
| 28 |
|
| 29 |
-
final_output = {}
|
| 30 |
for qtype, total_count in type_to_total_count.items():
|
| 31 |
prompt = type_to_prompt_map.get(qtype)
|
| 32 |
if not prompt:
|
| 33 |
continue
|
| 34 |
|
| 35 |
content_user = (
|
| 36 |
-
f"PARAGRAPH: {data.
|
| 37 |
f"QUESTION_COUNT: {total_count}\n"
|
| 38 |
f"OPTIONS_PER_QUESTION: {data.num_ans_per_question}\n"
|
| 39 |
)
|
|
@@ -60,7 +59,7 @@ class ParagraphQuestion(Question):
|
|
| 60 |
|
| 61 |
for question in data.get("list_questions", []):
|
| 62 |
result.append({
|
| 63 |
-
"
|
| 64 |
"type": qtype,
|
| 65 |
"choices": question.get("choices", []),
|
| 66 |
"answer": question.get("answer"),
|
|
|
|
| 26 |
num = question_data.num_question
|
| 27 |
type_to_total_count[qtype] = type_to_total_count.get(qtype, 0) + num
|
| 28 |
|
|
|
|
| 29 |
for qtype, total_count in type_to_total_count.items():
|
| 30 |
prompt = type_to_prompt_map.get(qtype)
|
| 31 |
if not prompt:
|
| 32 |
continue
|
| 33 |
|
| 34 |
content_user = (
|
| 35 |
+
f"PARAGRAPH: {data.paragraph}\n"
|
| 36 |
f"QUESTION_COUNT: {total_count}\n"
|
| 37 |
f"OPTIONS_PER_QUESTION: {data.num_ans_per_question}\n"
|
| 38 |
)
|
|
|
|
| 59 |
|
| 60 |
for question in data.get("list_questions", []):
|
| 61 |
result.append({
|
| 62 |
+
"content": question.get("question"),
|
| 63 |
"type": qtype,
|
| 64 |
"choices": question.get("choices", []),
|
| 65 |
"answer": question.get("answer"),
|
src/interfaces/question.py
CHANGED
|
@@ -15,7 +15,7 @@ class IQuestionConfig(BaseModel):
|
|
| 15 |
num_question: int = Field(..., ge=1, le=5)
|
| 16 |
|
| 17 |
class ICreateQuestionForParagraph(BaseModel):
|
| 18 |
-
|
| 19 |
num_ans_per_question: int = Field(..., ge=2, le=6)
|
| 20 |
list_create_question: List[IQuestionConfig]
|
| 21 |
|
|
|
|
| 15 |
num_question: int = Field(..., ge=1, le=5)
|
| 16 |
|
| 17 |
class ICreateQuestionForParagraph(BaseModel):
|
| 18 |
+
paragraph: Text
|
| 19 |
num_ans_per_question: int = Field(..., ge=2, le=6)
|
| 20 |
list_create_question: List[IQuestionConfig]
|
| 21 |
|
src/services/AI/false_ans_generator.py
CHANGED
|
@@ -198,3 +198,150 @@ class FalseAnswerGenerator:
|
|
| 198 |
all_answers.append(results)
|
| 199 |
|
| 200 |
return crct_ans, sum(all_answers, [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
all_answers.append(results)
|
| 199 |
|
| 200 |
return crct_ans, sum(all_answers, [])
|
| 201 |
+
|
| 202 |
+
def generate_distractors_from_synonyms(
    self,
    correct_words: list[str] = None,
    num_distractors: int = 3,
    sim_min: float = 0.35,
    sim_max: float = 0.75,
    *,
    target_word: list[str] = None,
    num_false_answers: int = None,
):
    """Generate distractors for a synonym question.

    Candidates are collected from the sense2vec neighbours of both words,
    then kept only when their closest sentence-embedding similarity to
    either word lies inside ``[sim_min, sim_max]`` (related, but not a
    near-synonym).

    Args:
        correct_words: Exactly two synonymous words.
        num_distractors: Maximum number of distractors to return.
        sim_min: Candidates whose best similarity is below this are dropped.
        sim_max: Candidates whose best similarity is above this are dropped.
        target_word: Keyword alias for ``correct_words``; this is the name
            the question factories pass.
        num_false_answers: Keyword alias for ``num_distractors``; this is
            the name the question factories pass.

    Returns:
        A list of at most ``num_distractors`` capitalised words. The list is
        empty when nothing suitable is found or no distractors are requested.

    Raises:
        ValueError: If exactly two words are not supplied.
    """
    # Resolve the alias keywords used by the callers.
    if correct_words is None:
        correct_words = target_word
    if num_false_answers is not None:
        num_distractors = num_false_answers

    # Explicit validation: an ``assert`` would vanish under ``python -O``.
    if correct_words is None or len(correct_words) != 2:
        raise ValueError("Require exactly 2 correct synonyms")

    # Nothing requested (or a negative count, which random.sample rejects).
    if num_distractors <= 0:
        return []

    correct_words = list(correct_words)
    w1, w2 = [w.lower().strip() for w in correct_words]
    originals = {w1, w2}

    # -------- 1. Collect candidates from sense2vec ----------
    raw_candidates = set()
    for w in (w1, w2):
        sense = self._s2v.get_best_sense(w.replace(" ", "_"))
        if sense and sense in self._s2v:
            sims = self._s2v.most_similar(sense, n=30)
            raw_candidates.update(change_format(sims))

    # Drop the two originals and collapse case variants: the result is
    # passed through ``capitalize()``, so "happy"/"Happy" would otherwise
    # surface as duplicate options.
    by_lower = {}
    for cand in raw_candidates:
        key = cand.lower()
        if key not in originals:
            by_lower.setdefault(key, cand)

    if not by_lower:
        return []

    candidates = list(by_lower.values())

    # -------- 2. Sentence embedding ----------
    emb_correct = self._sentence_model.encode(correct_words)
    emb_candidates = self._sentence_model.encode(candidates)

    # Similarity of every candidate to each of the two correct words.
    sim_1 = cosine_similarity(emb_candidates, emb_correct[0].reshape(1, -1))
    sim_2 = cosine_similarity(emb_candidates, emb_correct[1].reshape(1, -1))

    final_candidates = []
    for idx, word in enumerate(candidates):
        closest = max(sim_1[idx][0], sim_2[idx][0])

        # Too similar to a correct word: likely a synonym itself.
        if closest > sim_max:
            continue

        # Too different from both words: not a plausible distractor.
        if closest < sim_min:
            continue

        final_candidates.append(word)

    chosen = random.sample(
        final_candidates,
        k=min(num_distractors, len(final_candidates)),
    )

    return [w.capitalize() for w in chosen]
|
| 270 |
+
|
| 271 |
+
def generate_distractors_from_antonyms(
    self,
    correct_words: list[str] = None,
    num_distractors: int = 3,
    sim_min: float = 0.25,
    sim_max: float = 0.75,
    balance_threshold: float = 0.2,
    *,
    target_word: list[str] = None,
    num_false_answers: int = None,
):
    """Generate distractors for an antonym question.

    Candidates are collected from the sense2vec neighbours of both words and
    kept only when they sit "between" the two opposites: their best
    sentence-embedding similarity lies inside ``[sim_min, sim_max]`` and the
    two similarities differ by at most ``balance_threshold``.

    Args:
        correct_words: Exactly two opposite words.
        num_distractors: Maximum number of distractors to return.
        sim_min: Candidates whose best similarity is below this are dropped.
        sim_max: Candidates whose best similarity is above this are dropped.
        balance_threshold: Maximum allowed gap between a candidate's
            similarity to the first word and to the second word.
        target_word: Keyword alias for ``correct_words``; this is the name
            the question factories pass.
        num_false_answers: Keyword alias for ``num_distractors``; this is
            the name the question factories pass.

    Returns:
        A list of at most ``num_distractors`` capitalised words. The list is
        empty when nothing suitable is found or no distractors are requested.

    Raises:
        ValueError: If exactly two words are not supplied.
    """
    # Resolve the alias keywords used by the callers.
    if correct_words is None:
        correct_words = target_word
    if num_false_answers is not None:
        num_distractors = num_false_answers

    # Explicit validation: an ``assert`` would vanish under ``python -O``.
    if correct_words is None or len(correct_words) != 2:
        raise ValueError("Require exactly 2 antonyms")

    # Nothing requested (or a negative count, which random.sample rejects).
    if num_distractors <= 0:
        return []

    correct_words = list(correct_words)
    w1, w2 = [w.lower().strip() for w in correct_words]
    originals = {w1, w2}

    # -------- 1. Collect candidates from both antonyms ----------
    raw_candidates = set()
    for w in (w1, w2):
        sense = self._s2v.get_best_sense(w.replace(" ", "_"))
        if sense and sense in self._s2v:
            sims = self._s2v.most_similar(sense, n=40)
            raw_candidates.update(change_format(sims))

    # Drop the two originals and collapse case variants: the result is
    # passed through ``capitalize()``, so "hot"/"Hot" would otherwise
    # surface as duplicate options.
    by_lower = {}
    for cand in raw_candidates:
        key = cand.lower()
        if key not in originals:
            by_lower.setdefault(key, cand)

    if not by_lower:
        return []

    candidates = list(by_lower.values())

    # -------- 2. Sentence embedding ----------
    emb_correct = self._sentence_model.encode(correct_words)
    emb_candidates = self._sentence_model.encode(candidates)

    sim_1 = cosine_similarity(emb_candidates, emb_correct[0].reshape(1, -1))
    sim_2 = cosine_similarity(emb_candidates, emb_correct[1].reshape(1, -1))

    final_candidates = []
    for idx, word in enumerate(candidates):
        s1 = sim_1[idx][0]
        s2 = sim_2[idx][0]
        closest = max(s1, s2)

        # Too close to one pole: reject.
        if closest > sim_max:
            continue

        # Too far from both poles: reject.
        if closest < sim_min:
            continue

        # Unbalanced: leans heavily towards one of the two words.
        if abs(s1 - s2) > balance_threshold:
            continue

        final_candidates.append(word)

    if not final_candidates:
        return []

    chosen = random.sample(
        final_candidates,
        k=min(num_distractors, len(final_candidates)),
    )

    return [w.capitalize() for w in chosen]
|
| 347 |
+
|