# MCQ validation utilities: parse fine-tuned model output into structured
# MCQ JSON and validate/correct it (and topic-match results) via GPT.
| import json | |
| import re | |
| from openai import OpenAI | |
# Prompt for validating/correcting a parsed MCQ via GPT.
# Placeholders: {grade}, {difficulty}, {mcq_json}.
# NOTE: the original text contained mojibake ("β") where em-dashes belong;
# restored here so the prompt reads correctly.
PROMPT_TEMPLATE = """
You are a math MCQ validator for Grade {grade} students at {difficulty} difficulty level.
Context:
- Grade {grade} students have age-appropriate math knowledge.
- Difficulty level is {difficulty} — ensure the question and answer reflect this level.
Task:
Validate that the correct answer matches the choices, and that the question
is appropriate for Grade {grade} at {difficulty} difficulty.
Rules:
1. Check the value of the correct answer.
2. Verify that this value exists in the choices.
3. Ensure the question complexity suits Grade {grade} at {difficulty} difficulty.
Cases:
Case 1 — Valid:
If the correct answer value exists in the choices AND the correct_answer letter points to that value,
return the JSON unchanged.
Case 2 — Wrong answer letter:
If the correct answer value exists in the choices BUT the correct_answer letter is incorrect,
update correct_answer to the letter that corresponds to the correct value.
Case 3 — Correct value missing:
If the correct answer value does NOT exist in the choices,
replace one incorrect choice with the correct value and assign the correct_answer letter to that choice.
Case 4 — Question and answer both incorrect:
If the question is not understandable or not appropriate for Grade {grade} at {difficulty} difficulty,
rewrite it to be a clear, grade-appropriate question, set the correct answer, and ensure it exists in the choices.
Constraints:
- Keep exactly four choices (A, B, C, D).
- Choices must remain numbers.
- Return ONLY valid JSON.
Input JSON:
{mcq_json}
"""
class MCQValidator:
    """Turn raw model output text into structured MCQ JSON, then have GPT
    validate and correct the result."""

    def __init__(self, key_string: str, model: str = "gpt-5-nano"):
        # One OpenAI client per validator; the model name is kept so every
        # completion call in this class reuses the same configuration.
        self.model = model
        self.client = OpenAI(api_key=key_string)
| # ------------------------------------------------------------------ | |
| # STEP 1: Parse raw model output into structured dict | |
| # ------------------------------------------------------------------ | |
| def _extract_choices(self, text: str) -> dict: | |
| """ | |
| Try multiple strategies to extract choices. | |
| Normalizes all keys to uppercase A-D and values to rounded floats. | |
| Supported formats: | |
| - Formatted list: A) 2 or A) 0.33 | |
| - JSON letter key: "A": 2.0 | |
| - Verbose key: "choice a": 0.418... | |
| """ | |
| letter_map = { | |
| 'a': 'A', 'b': 'B', 'c': 'C', 'd': 'D', | |
| 'A': 'A', 'B': 'B', 'C': 'C', 'D': 'D', | |
| } | |
| # Strategy 1: formatted list "A) 2" | |
| matches = re.findall(r"([A-D])\)\s*(-?\d+(?:\.\d+)?)", text) | |
| if len(matches) >= 2: | |
| return {letter_map[k]: round(float(v), 4) for k, v in matches} | |
| # Strategy 2: JSON letter key "A": 2.0 | |
| matches = re.findall(r'"([A-Da-d])"\s*:\s*(-?\d+(?:\.\d+)?)', text) | |
| if len(matches) >= 2: | |
| return {letter_map[k]: round(float(v), 4) for k, v in matches} | |
| # Strategy 3: verbose key "choice a": 0.418... | |
| matches = re.findall(r'"[\s]*choice\s+([a-dA-D])"\s*:\s*(-?\d+(?:\.\d+)?)', text) | |
| if len(matches) >= 2: | |
| return {letter_map[k]: round(float(v), 4) for k, v in matches} | |
| return {} | |
| def _extract_correct_answer(self, text: str, choices: dict = None) -> str | None: | |
| """ | |
| Extract correct_answer letter, handling formats: | |
| - "correct_answer": "D" | |
| - "correct_answer": "choice b" | |
| - Fallback: evaluate arithmetic expression in question and match to choices | |
| (handles truncated output where correct_answer field is missing) | |
| Always returns uppercase A-D or None. | |
| """ | |
| letter_map = {'a': 'A', 'b': 'B', 'c': 'C', 'd': 'D'} | |
| # Format 1: "correct_answer": "D" | |
| m = re.search(r'"correct_answer"\s*:\s*"([A-D])"', text) | |
| if m: | |
| return m.group(1) | |
| # Format 2: "correct_answer": "choice b" | |
| m = re.search(r'"correct_answer"\s*:\s*"[\s]*choice\s+([a-dA-D])"', text) | |
| if m: | |
| return letter_map.get(m.group(1).lower()) | |
| # Format 3: Fallback for truncated output β correct_answer field never appeared. | |
| # Evaluate the arithmetic expression in the question and match to choices. | |
| if choices: | |
| q_match = re.search(r'"question":\s*"(.*?)"', text) | |
| if q_match: | |
| question_text = q_match.group(1) | |
| expr = re.search(r'(\d+)\s*([\+\-\*\/])\s*(\d+)', question_text) | |
| if expr: | |
| a, op, b = expr.groups() | |
| computed = None | |
| if op == '+': | |
| computed = int(a) + int(b) | |
| elif op == '-': | |
| computed = int(a) - int(b) | |
| elif op == '*': | |
| computed = int(a) * int(b) | |
| elif op == '/' and int(b) != 0: | |
| computed = round(int(a) / int(b), 4) | |
| if computed is not None: | |
| for letter, val in choices.items(): | |
| if round(float(val), 4) == round(float(computed), 4): | |
| return letter | |
| return None | |
| def parse_raw_output(self, text: str) -> dict: | |
| """ | |
| Extract question, choices, and correct_answer from raw model output string. | |
| Returns a dict ready for validation. | |
| Raises ValueError if any field cannot be extracted. | |
| """ | |
| # Extract question from JSON block | |
| question_match = re.search(r'"question":\s*"(.*?)"', text) | |
| if not question_match: | |
| raise ValueError("Could not extract 'question' from model output.") | |
| question = question_match.group(1).strip() | |
| # Extract choices using multi-strategy parser | |
| choices = self._extract_choices(text) | |
| if len(choices) < 2: | |
| raise ValueError(f"Could not extract choices. Found: {choices}") | |
| # Extract correct answer β pass choices for arithmetic fallback | |
| correct = self._extract_correct_answer(text, choices=choices) | |
| if not correct: | |
| raise ValueError("Could not extract 'correct_answer' from model output.") | |
| return { | |
| "question": question, | |
| "choices": choices, | |
| "correct_answer": correct | |
| } | |
| def validate_with_gpt(self, mcq_dict: dict, grade: int = 3, difficulty: str = "easy") -> dict: | |
| """ | |
| Send the parsed MCQ dict to GPT for validation and correction. | |
| Returns a validated/corrected MCQ dict. | |
| """ | |
| mcq_json_str = json.dumps(mcq_dict, indent=2) | |
| prompt = PROMPT_TEMPLATE.format(mcq_json=mcq_json_str, grade=grade, difficulty=difficulty) | |
| response = self.client.chat.completions.create( | |
| model=self.model, | |
| messages=[{"role": "user", "content": prompt}], | |
| ) | |
| raw_response = response.choices[0].message.content.strip() | |
| return self._parse_gpt_response(raw_response, fallback=mcq_dict) | |
| def _parse_gpt_response(self, text: str, fallback: dict) -> dict: | |
| """ | |
| Robustly parse GPT JSON response. | |
| Falls back to the original parsed MCQ if GPT output cannot be parsed. | |
| """ | |
| # Strip markdown fences if present | |
| text = re.sub(r"```json|```", "", text).strip() | |
| try: | |
| parsed = json.loads(text) | |
| if "choices" in parsed: | |
| parsed["choices"] = {k: round(float(v), 4) for k, v in parsed["choices"].items()} | |
| return parsed | |
| except json.JSONDecodeError: | |
| pass | |
| # Try extracting JSON block | |
| try: | |
| start = text.find('{') | |
| end = text.rfind('}') + 1 | |
| if start != -1 and end > start: | |
| parsed = json.loads(text[start:end]) | |
| if "choices" in parsed: | |
| parsed["choices"] = {k: round(float(v), 4) for k, v in parsed["choices"].items()} | |
| return parsed | |
| except Exception: | |
| pass | |
| return fallback | |
    # Prompt used when the fine-tuned model's output could not be parsed at
    # all; asks GPT to generate a replacement MCQ in the same JSON shape the
    # parser expects. Placeholders: {grade}, {difficulty}, {topic}. The
    # doubled braces ({{ }}) survive .format() as literal braces in the JSON
    # example.
    GPT_FALLBACK_PROMPT = """
You are a math MCQ generator for Grade {grade} students at {difficulty} difficulty level.
The fine-tuned model failed to produce a parseable question for the topic: {topic}.
Generate ONE valid math MCQ appropriate for Grade {grade} at {difficulty} difficulty.
Rules:
- The question must directly test the topic: {topic}
- Keep the question simple and age-appropriate for Grade {grade}
- Provide exactly 4 numeric answer choices labeled A, B, C, D
- Only one choice must be the correct answer
- Return ONLY valid JSON in exactly this format:
{{
"question": "<question text>",
"choices": {{"A": <number>, "B": <number>, "C": <number>, "D": <number>}},
"correct_answer": "<letter>"
}}
"""
| def _generate_fallback_question(self, topic: str, grade: int, difficulty: str) -> dict: | |
| """ | |
| Called when parsing fails. Generates a fresh MCQ via GPT | |
| in the same format as the fine-tuned model output. | |
| """ | |
| prompt = self.GPT_FALLBACK_PROMPT.format( | |
| topic=topic, | |
| grade=grade, | |
| difficulty=difficulty, | |
| ) | |
| response = self.client.chat.completions.create( | |
| model=self.model, | |
| messages=[{"role": "user", "content": prompt}], | |
| ) | |
| raw = response.choices[0].message.content.strip() | |
| return self._parse_gpt_response(raw, fallback={ | |
| "error": "GPT fallback also failed to produce valid JSON", | |
| "question": f"What is a basic {topic} problem?", | |
| "choices": {"A": 1, "B": 2, "C": 3, "D": 4}, | |
| "correct_answer": "A", | |
| }) | |
| def validate(self, raw_model_output: str, grade: int = 3, difficulty: str = "easy", topic: str = "") -> dict: | |
| """ | |
| Full pipeline: | |
| 1. Parse raw model output text into structured MCQ dict | |
| 2. Validate and correct via GPT using grade and difficulty context | |
| 3. Return final validated MCQ dict | |
| Args: | |
| raw_model_output: The string returned by infer_question_gen() | |
| grade: Grade level (e.g. 3, 4, 5) | |
| difficulty: Difficulty level (e.g. "easy", "medium", "hard") | |
| Returns: | |
| Validated MCQ dict with keys: question, choices, correct_answer | |
| """ | |
| try: | |
| mcq_dict = self.parse_raw_output(raw_model_output) | |
| except ValueError: | |
| # Parsing failed β fine-tuned model output was unusable. | |
| # Fall back to GPT to generate a fresh question for the same topic/grade/difficulty. | |
| return self._generate_fallback_question(topic=topic, grade=grade, difficulty=difficulty) | |
| validated = self.validate_with_gpt(mcq_dict, grade=grade, difficulty=difficulty) | |
| return validated | |
# Prompt used by TopicMatchValidator to have GPT write the missing
# "improvements" explanation for a non-perfect matching score.
# Placeholders: {topic}, {grade}, {question}, {score}.
# NOTE: the original text contained mojibake ("β") where arrows belong in
# the score-style examples; restored here so the prompt reads correctly.
TOPIC_IMPROVEMENTS_PROMPT = """
You are an educational content reviewer for elementary school mathematics.
A question-topic matching model returned a score but gave no explanation.
Your job is to write the improvements list explaining why the question does not perfectly match the topic.
Use this exact style from examples:
- Score 0.75 → "This question is somewhat related to {topic} but does not focus on its core concept. It [what the question actually tests] rather than {topic}."
- Score 0.5 → "This question is partially related to the topic. It [what the question actually tests] rather than addressing {topic} directly."
- Score 0.0 → "This question does not match the topic because it [what the question actually tests] instead of {topic}."
Rules:
- Write 1 clear improvement sentence explaining the mismatch.
- Be specific about what the question actually tests vs what the topic expects.
- Return ONLY valid JSON in this exact format:
{{"improvements": ["<your improvement sentence here>"]}}
Input:
Topic: {topic}
Grade: {grade}
Question: {question}
Matching Score: {score}
"""
class TopicMatchValidator:
    """
    Validates output from the question-topic matching model.

    When the score is below 1.0 but the improvements list came back empty,
    GPT is asked to fill in the explanation.
    """

    def __init__(self, key_string: str, model: str = "gpt-5-nano"):
        # Dedicated OpenAI client plus the model name for later calls.
        self.model = model
        self.client = OpenAI(api_key=key_string)
| def _needs_improvements(self, result: dict) -> bool: | |
| """ | |
| Returns True if the model returned a non-perfect score | |
| but left the improvements list empty. | |
| """ | |
| score = result.get("matching_score", 1.0) | |
| improvements = result.get("improvements", []) | |
| return score < 1.0 and (not improvements or improvements == []) | |
| def _generate_improvements(self, topic: str, grade: int, question: str, score: float) -> list: | |
| """ | |
| Calls GPT to generate improvements in the dataset style. | |
| Returns a list with one improvement string. | |
| """ | |
| prompt = TOPIC_IMPROVEMENTS_PROMPT.format( | |
| topic=topic, | |
| grade=grade, | |
| question=question, | |
| score=score, | |
| ) | |
| response = self.client.chat.completions.create( | |
| model=self.model, | |
| messages=[{"role": "user", "content": prompt}], | |
| ) | |
| raw = response.choices[0].message.content.strip() | |
| return self._parse_improvements_response(raw) | |
| def _parse_improvements_response(self, text: str) -> list: | |
| """ | |
| Safely parse GPT improvements response. | |
| Returns a list of improvement strings. | |
| """ | |
| text = re.sub(r"```json|```", "", text).strip() | |
| try: | |
| parsed = json.loads(text) | |
| if "improvements" in parsed and isinstance(parsed["improvements"], list): | |
| return parsed["improvements"] | |
| except json.JSONDecodeError: | |
| pass | |
| # Try extracting JSON block | |
| try: | |
| start = text.find('{') | |
| end = text.rfind('}') + 1 | |
| if start != -1 and end > start: | |
| parsed = json.loads(text[start:end]) | |
| if "improvements" in parsed and isinstance(parsed["improvements"], list): | |
| return parsed["improvements"] | |
| except Exception: | |
| pass | |
| # Last resort: return the raw text as a single improvement | |
| if text: | |
| return [text[:300]] | |
| return ["The question does not sufficiently match the stated topic."] | |
| def validate(self, result: dict, topic: str, grade: int, question: str) -> dict: | |
| """ | |
| Validates the topic match result. If score < 1.0 and improvements | |
| is empty, generates improvements via GPT. | |
| Args: | |
| result: The dict returned by evaluate_question_topic_match() | |
| topic: The topic string from the original request | |
| grade: The grade level from the original request | |
| question: The question string from the original request | |
| Returns: | |
| The result dict, guaranteed to have improvements if score < 1.0 | |
| """ | |
| if self._needs_improvements(result): | |
| score = result.get("matching_score", 0.0) | |
| result["improvements"] = self._generate_improvements(topic, grade, question, score) | |
| return result |