Jay-10020 commited on
Commit
3d6be84
·
1 Parent(s): c8cb4df

MCQ test1

Browse files
Files changed (2) hide show
  1. api/main.py +32 -1
  2. mcq/generator.py +34 -2
api/main.py CHANGED
@@ -473,8 +473,39 @@ async def generate_mcqs(request: MCQGenerateRequest):
473
  else:
474
  raise HTTPException(status_code=400, detail="Invalid source_type")
475
 
476
- # Filter valid MCQs
477
  valid_mcqs = [mcq for mcq in mcqs if mcq_validator.validate_mcq(mcq)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
  return {
480
  "status": "success",
 
473
  else:
474
  raise HTTPException(status_code=400, detail="Invalid source_type")
475
 
476
+ # Filter valid MCQs first.
477
  valid_mcqs = [mcq for mcq in mcqs if mcq_validator.validate_mcq(mcq)]
478
+
479
+ # If strict validation drops too many questions, top up with normalized
480
+ # parsed MCQs so caller still gets requested count.
481
+ if len(valid_mcqs) < request.num_questions:
482
+ for mcq in mcqs:
483
+ if len(valid_mcqs) >= request.num_questions:
484
+ break
485
+ if mcq in valid_mcqs:
486
+ continue
487
+
488
+ question = str(mcq.get("question", "")).strip()
489
+ options_map = mcq.get("options", {}) or {}
490
+ correct = str(mcq.get("correct_answer", "A")).strip().upper()
491
+
492
+ normalized = {
493
+ "question": question,
494
+ "options": {
495
+ "A": str(options_map.get("A", "Option A")),
496
+ "B": str(options_map.get("B", "Option B")),
497
+ "C": str(options_map.get("C", "Option C")),
498
+ "D": str(options_map.get("D", "Option D")),
499
+ },
500
+ "correct_answer": correct if correct in ["A", "B", "C", "D"] else "A",
501
+ "explanation": str(mcq.get("explanation", "Based on the provided context.")),
502
+ "difficulty": str(mcq.get("difficulty", request.difficulty or "medium")).lower(),
503
+ }
504
+
505
+ if normalized["question"]:
506
+ valid_mcqs.append(normalized)
507
+
508
+ valid_mcqs = valid_mcqs[:request.num_questions]
509
 
510
  return {
511
  "status": "success",
mcq/generator.py CHANGED
@@ -39,8 +39,23 @@ class MCQGenerator:
39
 
40
  # Parse MCQs from response
41
  mcqs = self._parse_mcqs_improved(response, text, num_questions)
42
-
43
- # Ensure we return the requested number or fewer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  return mcqs[:num_questions]
45
 
46
  def generate_from_document(
@@ -194,6 +209,7 @@ Now generate {num_questions} questions:
194
  question = lines[0].rstrip('?')
195
  if question.endswith(':'):
196
  question = question[:-1]
 
197
 
198
  if not question or len(question) < 5:
199
  return None
@@ -239,6 +255,22 @@ Now generate {num_questions} questions:
239
  }
240
 
241
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  def _parse_mcqs_fallback(self, response: str) -> List[Dict]:
244
  """Fallback parsing for various formats"""
 
39
 
40
  # Parse MCQs from response
41
  mcqs = self._parse_mcqs_improved(response, text, num_questions)
42
+
43
+ # Retry once with stricter generation if count is short.
44
+ if len(mcqs) < num_questions:
45
+ retry_prompt = prompt + "\nIMPORTANT: Return EXACTLY the requested number of questions in the specified format."
46
+ retry_response = self.llm.generate(
47
+ prompt=retry_prompt,
48
+ max_new_tokens=min(tokens_needed + 250, 1400),
49
+ temperature=0.6
50
+ )
51
+ retry_mcqs = self._parse_mcqs_improved(retry_response, text, num_questions)
52
+ mcqs = self._merge_unique_mcqs(mcqs, retry_mcqs)
53
+
54
+ # Last-resort synthetic top-up so API returns requested count.
55
+ if len(mcqs) < num_questions:
56
+ synthetic = self._generate_synthetic_mcqs(text, num_questions - len(mcqs))
57
+ mcqs = self._merge_unique_mcqs(mcqs, synthetic)
58
+
59
  return mcqs[:num_questions]
60
 
61
  def generate_from_document(
 
209
  question = lines[0].rstrip('?')
210
  if question.endswith(':'):
211
  question = question[:-1]
212
+ question = re.sub(r'^\s*(Q|Question)\s*\d+\s*[:.)-]\s*', '', question, flags=re.IGNORECASE).strip()
213
 
214
  if not question or len(question) < 5:
215
  return None
 
255
  }
256
 
257
  return None
258
+
259
+ def _merge_unique_mcqs(self, base: List[Dict], extra: List[Dict]) -> List[Dict]:
260
+ """Merge MCQ lists and keep unique questions by normalized text."""
261
+ merged = []
262
+ seen = set()
263
+
264
+ for item in (base + extra):
265
+ question = str(item.get('question', '')).strip()
266
+ key = re.sub(r'^\s*(Q|Question)\s*\d+\s*[:.)-]\s*', '', question, flags=re.IGNORECASE).lower()
267
+ if not key or key in seen:
268
+ continue
269
+ seen.add(key)
270
+ item['question'] = re.sub(r'^\s*(Q|Question)\s*\d+\s*[:.)-]\s*', '', question, flags=re.IGNORECASE).strip()
271
+ merged.append(item)
272
+
273
+ return merged
274
 
275
  def _parse_mcqs_fallback(self, response: str) -> List[Dict]:
276
  """Fallback parsing for various formats"""