ayushmodi001 committed on
Commit · 2eedf0d
Parent(s): ab1217c
updated distractor generator
Browse files:
- app/distractor_generator.py +51 -255
- requirements.txt +2 -1
app/distractor_generator.py
CHANGED
@@ -1,272 +1,68 @@
-'''
-import google.generativeai as genai
-from .config import GOOGLE_API_KEY, GEMINI_MODEL_NAME, DISTRACTOR_BATCH_SIZE, MAX_PARALLEL_DISTRACTOR_BATCHES
-from .utils import get_logger, parse_json_from_response
-import json
-import random
-import asyncio
-import time
-
-logger = get_logger(__name__)
-
-genai.configure(api_key=GOOGLE_API_KEY)
-
-async def generate_distractors(qa_pairs):
-    logger.info(f"Generating distractors for {len(qa_pairs)} QA pairs.")
-
-    if not qa_pairs:
-        logger.warning("No QA pairs provided to generate distractors.")
-        return []
-
-    try:
-        model = genai.GenerativeModel(GEMINI_MODEL_NAME)
-
-        example_input = [{
-            "question": "What is the primary goal of the A* search algorithm?",
-            "correct_answer": "To find the shortest path between a start and end node",
-            "context": "The A* search algorithm is widely used in pathfinding and graph traversal to find the shortest path between a start and end node."
-        }]
-        example_output = [{
-            "question": "What is the primary goal of the A* search algorithm?",
-            "correct_answer": "To find the shortest path between a start and end node",
-            "distractors": [
-                "To find the longest path in a graph",
-                "To visit every node in a graph"
-            ]
-        }]
-
-        # Optimize batch size based on number of QA pairs
-        if len(qa_pairs) <= 5:
-            BATCH_SIZE = 2
-        elif len(qa_pairs) <= 10:
-            BATCH_SIZE = 3
-        else:
-            BATCH_SIZE = DISTRACTOR_BATCH_SIZE
-
-        batches = []
-        all_distractors_data = []
-
-        # Create batches of questions
-        for i in range(0, len(qa_pairs), BATCH_SIZE):
-            batches.append(qa_pairs[i:i + BATCH_SIZE])
-
-        logger.info(f"Processing distractor generation in {len(batches)} parallel batches")
-
-        # Define an async function to process a single batch
-        async def process_batch(batch_index, batch_data):
-            start_time = time.time()
-            batch_id = f"batch-{batch_index+1}"
-            logger.info(f"Starting distractor generation for {batch_id} with {len(batch_data)} questions")
-
-            prompt = f"""Generate three plausible but incorrect distractors for each of the following question-answer pairs.
-The distractors must be conceptually related to the correct answer but clearly wrong. They should be of the same entity type as the correct answer. For example, if the answer is a specific algorithm, the distractors should be other algorithm names.
-
-Return the output as a single, valid JSON array of objects. Do NOT include any text or markdown formatting outside of the JSON array. Each object must have three keys: "question", "correct_answer", and "distractors" (a list of 3 strings).
-
-EXAMPLE INPUT:
-{json.dumps(example_input, indent=2)}
-
-EXAMPLE OUTPUT:
-{json.dumps(example_output, indent=2)}
-
-ACTUAL INPUT:
-{json.dumps(batch_data, indent=2)}
-"""
-            try:
-                response = await model.generate_content_async(prompt)
-                distractors_data = await parse_json_from_response(response.text, logger)
-                processing_time = time.time() - start_time
-                logger.info(f"Completed distractor generation for {batch_id} in {processing_time:.2f}s")
-                return distractors_data or []
-            except Exception as e:
-                logger.error(f"Error processing distractor batch {batch_id}: {e}", exc_info=True)
-                return []
-
-        # Process all batches in parallel, but limit concurrency to avoid overloading
-        semaphore = asyncio.Semaphore(MAX_PARALLEL_DISTRACTOR_BATCHES)
-
-        async def process_with_semaphore(idx, batch):
-            async with semaphore:
-                return await process_batch(idx, batch)
-
-        batch_results = await asyncio.gather(*[process_with_semaphore(i, batches[i]) for i in range(len(batches))], return_exceptions=True)
-
-        # Combine results
-        for result in batch_results:
-            if isinstance(result, Exception):
-                logger.error(f"Exception during distractor generation: {result}", exc_info=True)
-                continue
-            all_distractors_data.extend(result)
-
-        if not all_distractors_data:
-            raise ValueError("Failed to parse any JSON from Gemini responses.")
-
-        distractor_map = {item['question']: item['distractors'] for item in all_distractors_data if 'question' in item and 'distractors' in item}
-
-        all_correct_answers = [qa['correct_answer'] for qa in qa_pairs]
-        final_qa_with_distractors = []
-
-        for qa in qa_pairs:
-            question = qa['question']
-            correct_answer = qa['correct_answer']
-            # Preserve the original topic information
-            topic = qa.get('topic', 'Unknown')
-
-            distractors = distractor_map.get(question, [])
-
-            distractors = list(set(d for d in distractors if d.lower() != correct_answer.lower()))
-
-            if len(distractors) < 3:
-                logger.warning(f"Not enough unique distractors for question: '{question}'. Using fallback.")
-                potential_fallback_distractors = [ans for ans in all_correct_answers if ans.lower() != correct_answer.lower()]
-                random.shuffle(potential_fallback_distractors)
-
-                needed = 3 - len(distractors)
-                distractors.extend(potential_fallback_distractors[:needed])
-
-            # Keep the topic when adding distractors - preserve both topic fields
-            qa['distractors'] = distractors[:3]
-
-            # Ensure topic is preserved
-            topic = qa.get('topic', 'Unknown')
-            qa['topic'] = topic  # Keep the original topic
-
-            # Log topic information for debugging
-            logger.debug(f"Preserving topic '{topic}' for question: '{question[:30]}...'")
-
-            final_qa_with_distractors.append(qa)
-
-        logger.info("Successfully generated and processed distractors.")
-        return final_qa_with_distractors
-
-    except Exception as e:
-        logger.error(f"Error generating distractors: {e}", exc_info=True)
-        return [dict(qa, distractors=[]) for qa in qa_pairs]
-'''
-# --- Local GloVe+KNN distractor generation implementation with FlashText filtering ---
 import os
 import random
 from .utils import get_logger

-logger = get_logger(__name__)
-
-try:
-    import gensim
-    from gensim.models import KeyedVectors
-except ImportError:
-    gensim = None
-    KeyedVectors = None
-    logger.warning("gensim is not installed. Please install it with 'pip install gensim' and download GloVe vectors.")
-
 try:
-    from flashtext import KeywordProcessor
-except ImportError:
-    KeywordProcessor = None
-    logger.warning("flashtext is not installed. Please install it with 'pip install flashtext'.")

-
-# This path navigates up from the script's location (/app/app) to the parent (/app) and then into the models directory
-GLOVE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models', 'glove.6B.100d.txt'))
-_glove_model = None
-
-def load_glove_model():
-    """Loads the GloVe model from the specified path."""
-    global _glove_model
-    if _glove_model is None:
-        if not gensim or not KeyedVectors:
-            raise ImportError("gensim is not installed. Please install it with 'pip install gensim'.")
-        if not os.path.exists(GLOVE_PATH):
-            # This check is important for debugging within the container
-            raise FileNotFoundError(f"GloVe file not found at {GLOVE_PATH}. Ensure it was downloaded and extracted during the Docker build.")
-        logger.info(f"Loading GloVe vectors from {GLOVE_PATH}...")
-        _glove_model = KeyedVectors.load_word2vec_format(GLOVE_PATH, no_header=True)
-        logger.info("GloVe vectors loaded successfully.")
-    return _glove_model

-
-def get_knn_distractors(answer, topn=10):
     """
-
-    This version is improved to handle multi-word answers by focusing on the last word.
     """
-    model = load_glove_model()
-    answer_lower = answer.lower()
-
-    # Try to find synonyms for the last word of the answer phrase
-    last_word = answer_lower.split()[-1]
-    if last_word in model:
-        distractors = [w for w, _ in model.most_similar(last_word, topn=topn + 10) if w.isalpha() and w not in answer_lower.split()]
-        if len(distractors) >= topn:
-            return distractors[:topn]
-
-    # Fallback to using the whole answer phrase if the last-word approach fails
-    if answer_lower in model:
-        distractors = [w for w, _ in model.most_similar(answer_lower, topn=topn + 5) if w.isalpha() and w != answer_lower]
-        return distractors[:topn]
-
-    return []

-# FlashText filtering helper (body elided in the extracted diff; signature inferred from the call site below)
-def filter_distractors_with_flashtext(distractors, correct_answer, question):
-    ...

 async def generate_distractors(qa_pairs):
     """
-    Generates distractors for MCQ pipeline using local GloVe+KNN with FlashText filtering.
-    Returns a list of QA dicts with a 'distractors' key (list of 3 strings).
     """
-    logger.info(f"Generating distractors for {len(qa_pairs)} QA pairs.")
     if not qa_pairs:
-        logger.warning("No QA pairs provided to generate distractors.")
         return []
-    try:
-        all_correct_answers = [qa['correct_answer'] for qa in qa_pairs]
-        final_qa_with_distractors = []
-        for qa in qa_pairs:
-            question = qa['question']
-            correct_answer = qa['correct_answer']
-            topic = qa.get('topic', 'Unknown')
-            # Fetch more candidates from GloVe to improve the chances of finding good distractors
-            distractors = get_knn_distractors(correct_answer, topn=30)
-            # Filter with FlashText for quality
-            distractors = filter_distractors_with_flashtext(distractors, correct_answer, question)
-            # Remove duplicates and the correct answer
-            distractors = list(dict.fromkeys([d for d in distractors if d.lower() != correct_answer.lower()]))
-            # Fallback: use other correct answers from the same topic as distractors
-            if len(distractors) < 3:
-                logger.warning(f"Not enough unique distractors for question: '{question}'. Using fallback from same topic.")
-                potential_fallback_distractors = [
-                    qa_other['correct_answer']
-                    for qa_other in qa_pairs
-                    if qa_other['topic'] == topic and qa_other['correct_answer'].lower() != correct_answer.lower()
-                ]
-                if not potential_fallback_distractors:
-                    # If no other answers from the same topic, use any other answer
-                    logger.warning(f"No other answers in topic '{topic}'. Using fallback from all topics.")
-                    potential_fallback_distractors = [ans for ans in all_correct_answers if ans.lower() != correct_answer.lower()]

-    ...  # (remaining removed lines elided in the extracted diff)
 import os
 import random
+import nltk
+from nltk.tokenize import sent_tokenize, word_tokenize
+from nltk.util import ngrams
 from .utils import get_logger

+# Download NLTK data if not already present
 try:
+    nltk.data.find('tokenizers/punkt')
+except LookupError:  # nltk.data.find raises LookupError when the resource is missing
+    nltk.download('punkt')

+logger = get_logger(__name__)

+def generate_distractors_from_context(correct_answer, context, num_distractors=3):
     """
+    Generates phrase-based distractors from the given context.
     """
+    if not context:
+        return []

+    # Tokenize the context into sentences and the answer into words
+    sentences = sent_tokenize(context)
+    answer_words = word_tokenize(correct_answer.lower())
+    n = len(answer_words)
+
+    distractors = []
+    for sentence in sentences:
+        if correct_answer.lower() not in sentence.lower():
+            words = word_tokenize(sentence.lower())
+            # Generate n-grams of the same length as the answer
+            for ngram in ngrams(words, n):
+                distractor_phrase = " ".join(ngram)
+                # Avoid adding duplicates
+                if distractor_phrase not in distractors:
+                    distractors.append(distractor_phrase)
+
+    # Return a random sample of the generated distractors
+    random.shuffle(distractors)
+    return distractors[:num_distractors]

 async def generate_distractors(qa_pairs):
     """
+    Generates distractors for the MCQ pipeline using a context-aware approach only.
     """
+    logger.info(f"Generating distractors for {len(qa_pairs)} QA pairs with context-aware strategy.")
     if not qa_pairs:
         return []

+    final_qa_with_distractors = []
+
+    for qa in qa_pairs:
+        correct_answer = qa['correct_answer']
+        context = qa.get('context', '')
+
+        # Generate distractors from context
+        distractors = generate_distractors_from_context(correct_answer, context, num_distractors=3)
+
+        # If not enough distractors are found, the list will be shorter than 3
+        if len(distractors) < 3:
+            logger.warning(f"Could not generate 3 distractors for '{correct_answer}'. Only found {len(distractors)}.")
+
+        qa['distractors'] = distractors
+        final_qa_with_distractors.append(qa)
+
+    logger.info("Successfully generated and processed distractors.")
+    return final_qa_with_distractors
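For context, a minimal usage sketch of the updated generator. It assumes the package is importable as app.distractor_generator (inferred from the file path); the QA pair below is invented for illustration:

import asyncio

from app.distractor_generator import generate_distractors

# Invented sample; real QA pairs come from the upstream question-generation
# step and may carry extra keys such as 'topic'.
qa_pairs = [{
    "question": "What is the primary goal of the A* search algorithm?",
    "correct_answer": "shortest path",
    "context": (
        "The A* search algorithm finds the shortest path between two nodes. "
        "Dijkstra's algorithm explores every reachable node uniformly. "
        "Greedy best-first search follows the heuristic estimate alone."
    ),
}]

result = asyncio.run(generate_distractors(qa_pairs))
# Each entry now carries a 'distractors' list: up to 3 phrases of the same
# token length as the answer, drawn from context sentences that do not
# contain the correct answer.
print(result[0]["distractors"])

Because the distractors are raw n-grams over tokenized sentences, they can be stopword-heavy phrases; the commit trades the old semantic filtering for speed and zero API or model-file dependencies.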
requirements.txt
CHANGED
@@ -13,4 +13,5 @@ tqdm==4.67.1
 gunicorn==23.0.0
 pydantic-settings==2.3.4
 gensim
-huggingface_hub==0.34.0
+huggingface_hub==0.34.0
+nltk
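Note that installing nltk from requirements.txt pulls in the library only; the Punkt tokenizer data is fetched lazily by the try/except at the top of app/distractor_generator.py. For an offline or read-only deployment (e.g. during a Docker build, as the removed GloVe code did for its vectors), the data can be baked in ahead of time; a minimal sketch (newer NLTK releases, 3.9+, may additionally want the separate punkt_tab resource):

import nltk

# Pre-fetch the Punkt models used by sent_tokenize/word_tokenize so the
# Space never has to download NLTK data at request time.
nltk.download('punkt')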