ayushmodi001 committed on
Commit
2eedf0d
·
1 Parent(s): ab1217c

updated distractor generator

Browse files
Files changed (2) hide show
  1. app/distractor_generator.py +51 -255
  2. requirements.txt +2 -1
app/distractor_generator.py CHANGED
@@ -1,272 +1,68 @@
1
- '''
2
- import google.generativeai as genai
3
- from .config import GOOGLE_API_KEY, GEMINI_MODEL_NAME, DISTRACTOR_BATCH_SIZE, MAX_PARALLEL_DISTRACTOR_BATCHES
4
- from .utils import get_logger, parse_json_from_response
5
- import json
6
- import random
7
- import asyncio
8
- import time
9
-
10
- logger = get_logger(__name__)
11
-
12
- genai.configure(api_key=GOOGLE_API_KEY)
13
-
14
- async def generate_distractors(qa_pairs):
15
- logger.info(f"Generating distractors for {len(qa_pairs)} QA pairs.")
16
-
17
- if not qa_pairs:
18
- logger.warning("No QA pairs provided to generate distractors.")
19
- return []
20
-
21
- try:
22
- model = genai.GenerativeModel(GEMINI_MODEL_NAME)
23
-
24
- example_input = [{
25
- "question": "What is the primary goal of the A* search algorithm?",
26
- "correct_answer": "To find the shortest path between a start and end node",
27
- "context": "The A* search algorithm is widely used in pathfinding and graph traversal to find the shortest path between a start and end node."
28
- }]
29
- example_output = [{
30
- "question": "What is the primary goal of the A* search algorithm?",
31
- "correct_answer": "To find the shortest path between a start and end node",
32
- "distractors": [
33
- "To find the longest path in a graph",
34
- "To visit every node in a graph"
35
- ]
36
- }]
37
-
38
- # Optimize batch size based on number of QA pairs
39
- if len(qa_pairs) <= 5:
40
- BATCH_SIZE = 2
41
- elif len(qa_pairs) <= 10:
42
- BATCH_SIZE = 3
43
- else:
44
- BATCH_SIZE = DISTRACTOR_BATCH_SIZE
45
-
46
- batches = []
47
- all_distractors_data = []
48
-
49
- # Create batches of questions
50
- for i in range(0, len(qa_pairs), BATCH_SIZE):
51
- batches.append(qa_pairs[i:i + BATCH_SIZE])
52
-
53
- logger.info(f"Processing distractor generation in {len(batches)} parallel batches")
54
-
55
- # Define an async function to process a single batch
56
- async def process_batch(batch_index, batch_data):
57
- start_time = time.time()
58
- batch_id = f"batch-{batch_index+1}"
59
- logger.info(f"Starting distractor generation for {batch_id} with {len(batch_data)} questions")
60
-
61
- prompt = f"""Generate three plausible but incorrect distractors for each of the following question-answer pairs.
62
- The distractors must be conceptually related to the correct answer but clearly wrong. They should be of the same entity type as the correct answer. For example, if the answer is a specific algorithm, the distractors should be other algorithm names.
63
-
64
- Return the output as a single, valid JSON array of objects. Do NOT include any text or markdown formatting outside of the JSON array. Each object must have three keys: "question", "correct_answer", and "distractors" (a list of 3 strings).
65
-
66
- EXAMPLE INPUT:
67
- {json.dumps(example_input, indent=2)}
68
-
69
- EXAMPLE OUTPUT:
70
- {json.dumps(example_output, indent=2)}
71
-
72
- ACTUAL INPUT:
73
- {json.dumps(batch_data, indent=2)}
74
- """
75
- try:
76
- response = await model.generate_content_async(prompt)
77
- distractors_data = await parse_json_from_response(response.text, logger)
78
- processing_time = time.time() - start_time
79
- logger.info(f"Completed distractor generation for {batch_id} in {processing_time:.2f}s")
80
- return distractors_data or []
81
- except Exception as e:
82
- logger.error(f"Error processing distractor batch {batch_id}: {e}", exc_info=True)
83
- return []
84
-
85
- # Process all batches in parallel, but limit concurrency to avoid overloading
86
- semaphore = asyncio.Semaphore(MAX_PARALLEL_DISTRACTOR_BATCHES)
87
-
88
- async def process_with_semaphore(idx, batch):
89
- async with semaphore:
90
- return await process_batch(idx, batch)
91
-
92
- batch_results = await asyncio.gather(*[process_with_semaphore(i, batches[i]) for i in range(len(batches))], return_exceptions=True)
93
-
94
- # Combine results
95
- for result in batch_results:
96
- if isinstance(result, Exception):
97
- logger.error(f"Exception during distractor generation: {result}", exc_info=True)
98
- continue
99
- all_distractors_data.extend(result)
100
-
101
- if not all_distractors_data:
102
- raise ValueError("Failed to parse any JSON from Gemini responses.")
103
-
104
- distractor_map = {item['question']: item['distractors'] for item in all_distractors_data if 'question' in item and 'distractors' in item}
105
-
106
- all_correct_answers = [qa['correct_answer'] for qa in qa_pairs]
107
- final_qa_with_distractors = []
108
-
109
- for qa in qa_pairs:
110
- question = qa['question']
111
- correct_answer = qa['correct_answer']
112
- # Preserve the original topic information
113
- topic = qa.get('topic', 'Unknown')
114
-
115
- distractors = distractor_map.get(question, [])
116
-
117
- distractors = list(set(d for d in distractors if d.lower() != correct_answer.lower()))
118
-
119
- if len(distractors) < 3:
120
- logger.warning(f"Not enough unique distractors for question: '{question}'. Using fallback.")
121
- potential_fallback_distractors = [ans for ans in all_correct_answers if ans.lower() != correct_answer.lower()]
122
- random.shuffle(potential_fallback_distractors)
123
-
124
- needed = 3 - len(distractors)
125
- distractors.extend(potential_fallback_distractors[:needed])
126
-
127
- # Keep the topic when adding distractors - preserve both topic fields
128
- qa['distractors'] = distractors[:3]
129
-
130
- # Ensure topic is preserved
131
- topic = qa.get('topic', 'Unknown')
132
- qa['topic'] = topic # Keep the original topic
133
-
134
- # Log topic information for debugging
135
- logger.debug(f"Preserving topic '{topic}' for question: '{question[:30]}...'")
136
-
137
- final_qa_with_distractors.append(qa)
138
-
139
- logger.info("Successfully generated and processed distractors.")
140
- return final_qa_with_distractors
141
-
142
- except Exception as e:
143
- logger.error(f"Error generating distractors: {e}", exc_info=True)
144
- return [dict(qa, distractors=[]) for qa in qa_pairs]
145
- '''
146
- # --- Local GloVe+KNN distractor generation implementation with FlashText filtering ---
147
  import os
148
  import random
 
 
 
149
  from .utils import get_logger
150
 
151
- logger = get_logger(__name__)
152
-
153
- try:
154
- import gensim
155
- from gensim.models import KeyedVectors
156
- except ImportError:
157
- gensim = None
158
- KeyedVectors = None
159
- logger.warning("gensim is not installed. Please install it with 'pip install gensim' and download GloVe vectors.")
160
-
161
  try:
162
- from flashtext import KeywordProcessor
163
- except ImportError:
164
- KeywordProcessor = None
165
- logger.warning("flashtext is not installed. Please install it with 'pip install flashtext'.")
166
 
167
- # Construct a robust, absolute path to the GloVe model
168
- # This path navigates up from the script's location (/app/app) to the parent (/app) and then into the models directory
169
- GLOVE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'models', 'glove.6B.100d.txt'))
170
- _glove_model = None
171
-
172
- def load_glove_model():
173
- """Loads the GloVe model from the specified path."""
174
- global _glove_model
175
- if _glove_model is None:
176
- if not gensim or not KeyedVectors:
177
- raise ImportError("gensim is not installed. Please install it with 'pip install gensim'.")
178
- if not os.path.exists(GLOVE_PATH):
179
- # This check is important for debugging within the container
180
- raise FileNotFoundError(f"GloVe file not found at {GLOVE_PATH}. Ensure it was downloaded and extracted during the Docker build.")
181
- logger.info(f"Loading GloVe vectors from {GLOVE_PATH}...")
182
- _glove_model = KeyedVectors.load_word2vec_format(GLOVE_PATH, no_header=True)
183
- logger.info("GloVe vectors loaded successfully.")
184
- return _glove_model
185
 
186
- # Get KNN-based distractors from GloVe
187
- def get_knn_distractors(answer, topn=10):
188
  """
189
- Gets KNN-based distractors from the loaded GloVe model.
190
- This version is improved to handle multi-word answers by focusing on the last word.
191
  """
192
- model = load_glove_model()
193
- answer_lower = answer.lower()
194
-
195
- # Try to find synonyms for the last word of the answer phrase
196
- last_word = answer_lower.split()[-1]
197
- if last_word in model:
198
- distractors = [w for w, _ in model.most_similar(last_word, topn=topn + 10) if w.isalpha() and w not in answer_lower.split()]
199
- if len(distractors) >= topn:
200
- return distractors[:topn]
201
-
202
- # Fallback to using the whole answer phrase if the last-word approach fails
203
- if answer_lower in model:
204
- distractors = [w for w, _ in model.most_similar(answer_lower, topn=topn + 5) if w.isalpha() and w != answer_lower]
205
- return distractors[:topn]
206
-
207
- return []
208
 
209
- # Use FlashText to filter out distractors that are too similar to the answer or question
210
- def filter_distractors_with_flashtext(distractors, answer, question):
211
- """
212
- Filters distractors to remove any that contain the correct answer.
213
- The filtering of question words has been removed to be less aggressive.
214
- """
215
- if not KeywordProcessor:
216
- return [d for d in distractors if d.lower() != answer.lower() and d.lower() not in question.lower()]
217
- kp = KeywordProcessor(case_sensitive=False)
218
- kp.add_keyword(answer)
219
- # The aggressive filtering of question words has been removed to improve the number of distractors found.
220
- # for word in question.split():
221
- # kp.add_keyword(word)
222
- filtered = [d for d in distractors if not kp.extract_keywords(d)]
223
- return filtered
 
 
 
 
224
 
225
  async def generate_distractors(qa_pairs):
226
  """
227
- Generates distractors for MCQ pipeline using local GloVe embeddings and FlashText filtering.
228
- Returns a list of QA dicts with a 'distractors' key (list of 3 strings).
229
  """
230
- logger.info(f"Generating distractors for {len(qa_pairs)} QA pairs (local GloVe + FlashText mode).")
231
  if not qa_pairs:
232
- logger.warning("No QA pairs provided to generate distractors.")
233
  return []
234
- try:
235
- all_correct_answers = [qa['correct_answer'] for qa in qa_pairs]
236
- final_qa_with_distractors = []
237
- for qa in qa_pairs:
238
- question = qa['question']
239
- correct_answer = qa['correct_answer']
240
- topic = qa.get('topic', 'Unknown')
241
- # Fetch more candidates from GloVe to improve the chances of finding good distractors
242
- distractors = get_knn_distractors(correct_answer, topn=30)
243
- # Filter with FlashText for quality
244
- distractors = filter_distractors_with_flashtext(distractors, correct_answer, question)
245
- # Remove duplicates and the correct answer
246
- distractors = list(dict.fromkeys([d for d in distractors if d.lower() != correct_answer.lower()]))
247
- # Fallback: use other correct answers from the same topic as distractors
248
- if len(distractors) < 3:
249
- logger.warning(f"Not enough unique distractors for question: '{question}'. Using fallback from same topic.")
250
- potential_fallback_distractors = [
251
- qa_other['correct_answer']
252
- for qa_other in qa_pairs
253
- if qa_other['topic'] == topic and qa_other['correct_answer'].lower() != correct_answer.lower()
254
- ]
255
- if not potential_fallback_distractors:
256
- # If no other answers from the same topic, use any other answer
257
- logger.warning(f"No other answers in topic '{topic}'. Using fallback from all topics.")
258
- potential_fallback_distractors = [ans for ans in all_correct_answers if ans.lower() != correct_answer.lower()]
259
 
260
- random.shuffle(potential_fallback_distractors)
261
-
262
- needed = 3 - len(distractors)
263
- distractors.extend(potential_fallback_distractors[:needed])
264
- qa['distractors'] = distractors[:3]
265
- qa['topic'] = topic
266
- logger.debug(f"Preserving topic '{topic}' for question: '{question[:30]}...'")
267
- final_qa_with_distractors.append(qa)
268
- logger.info("Successfully generated and processed distractors (local mode, MCQ pipeline ready).")
269
- return final_qa_with_distractors
270
- except Exception as e:
271
- logger.error(f"Error generating distractors: {e}", exc_info=True)
272
- return [dict(qa, distractors=[]) for qa in qa_pairs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import random
3
+ import nltk
4
+ from nltk.tokenize import sent_tokenize, word_tokenize
5
+ from nltk.util import ngrams
6
  from .utils import get_logger
7
 
8
+ # Download NLTK data if not already present
 
 
 
 
 
 
 
 
 
9
# Ensure the 'punkt' sentence/word tokenizer models are available.
# BUG FIX: nltk.data.find() raises LookupError when a resource is missing —
# not nltk.downloader.DownloadError — so LookupError is the exception to
# catch here, otherwise the download never runs and the first tokenize call
# crashes on a machine without the punkt data.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
 
13
 
14
+ logger = get_logger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
def generate_distractors_from_context(correct_answer, context, num_distractors=3):
    """Generate phrase-based distractors mined from the given context.

    Candidate phrases are n-grams — where n is the token length of the
    correct answer — drawn only from sentences that do NOT contain the
    answer, so every distractor matches the answer's length but cannot be
    the answer itself.

    Args:
        correct_answer: The correct answer string; its token count fixes n.
        context: Source text to mine for distractor phrases.
        num_distractors: Maximum number of distractors to return.

    Returns:
        A list of up to ``num_distractors`` distinct lowercase phrases in
        random order; empty when ``context`` or ``correct_answer`` is empty.
    """
    if not context:
        return []

    answer_lower = correct_answer.lower()
    answer_tokens = word_tokenize(answer_lower)
    n = len(answer_tokens)
    # Guard: an empty/whitespace-only answer gives n == 0, and ngrams(words, 0)
    # would yield degenerate empty tuples (i.e. "" distractors).
    if n == 0:
        return []

    seen = set()        # O(1) membership check instead of scanning the list
    candidates = []
    for sentence in sent_tokenize(context):
        # Skip sentences containing the answer so no distractor equals it.
        if answer_lower in sentence.lower():
            continue
        tokens = word_tokenize(sentence.lower())
        # Slide an n-gram window the same length as the answer phrase.
        for gram in ngrams(tokens, n):
            phrase = " ".join(gram)
            if phrase not in seen:
                seen.add(phrase)
                candidates.append(phrase)

    # Shuffle so repeated calls surface different distractors, then cap.
    random.shuffle(candidates)
    return candidates[:num_distractors]
42
 
43
async def generate_distractors(qa_pairs):
    """
    Generates distractors for MCQ pipeline using a context-aware approach only.

    For each QA dict, mines up to three distractor phrases from its own
    'context' text and attaches them under a 'distractors' key. Items whose
    context yields fewer than three candidates are kept (with a warning) and
    carry however many distractors were found.
    """
    logger.info(f"Generating distractors for {len(qa_pairs)} QA pairs with context-aware strategy.")
    if not qa_pairs:
        return []

    enriched = []
    for item in qa_pairs:
        answer = item['correct_answer']
        source_text = item.get('context', '')

        # Mine candidate phrases from this item's own context.
        candidates = generate_distractors_from_context(answer, source_text, num_distractors=3)

        # A short list is allowed, but flag it for visibility.
        if len(candidates) < 3:
            logger.warning(f"Could not generate 3 distractors for '{answer}'. Only found {len(candidates)}.")

        item['distractors'] = candidates
        enriched.append(item)

    logger.info("Successfully generated and processed distractors.")
    return enriched
requirements.txt CHANGED
@@ -13,4 +13,5 @@ tqdm==4.67.1
13
  gunicorn==23.0.0
14
  pydantic-settings==2.3.4
15
  gensim
16
- huggingface_hub==0.34.0
 
 
13
  gunicorn==23.0.0
14
  pydantic-settings==2.3.4
15
  gensim
16
+ huggingface_hub==0.34.0
17
+ nltk