Spaces:
Running
Running
File size: 4,375 Bytes
5cb2e67 0bd1b0f 5cb2e67 0bd1b0f 5cb2e67 0bd1b0f 5cb2e67 0bd1b0f 5cb2e67 0bd1b0f 5cb2e67 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import re
from typing import Any
from graphgen.bases import BaseGenerator
from graphgen.templates import MCQ_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger
class MultiChoiceGenerator(BaseGenerator):
    """
    Build multiple-choice-question prompts and parse the LLM's replies.

    ``build_prompt`` renders a batch of graph nodes and edges into the MCQ
    generation prompt; ``parse_response`` extracts the validated questions
    from the ``<qa_pair>`` blocks of the response text.
    """

    # Patterns are compiled once at class creation instead of being looked up
    # again for every block and every option line inside the parsing loops.
    _QA_BLOCK_RE = re.compile(r"<qa_pair>(.*?)</qa_pair>", re.DOTALL)
    _QUESTION_RE = re.compile(r"<question>(.*?)</question>", re.DOTALL)
    _OPTIONS_RE = re.compile(r"<options>(.*?)</options>", re.DOTALL)
    _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
    # Option line: a letter A-D followed by ".", ")" or ":" (ASCII or
    # full-width), by "、", or by plain whitespace -- e.g. "A. text",
    # "B) text", "C: text", "D text".
    _OPTION_LINE_RE = re.compile(r"^([A-D])(?:[.):．）：、]|\s)\s*(.*)$")
    # Answer label: a single letter (either case), optionally opened by a
    # bracket, optionally followed by a delimiter and trailing text -- e.g.
    # "A", "a", "(A)", "A.", "A. Paris". Something like "A and B" is rejected
    # because no delimiter directly follows the letter.
    _ANSWER_LABEL_RE = re.compile(
        r"^[(（]?([A-Da-d])(?:[.):．）：、].*)?$", re.DOTALL
    )

    def __init__(self, llm_client, num_of_questions) -> None:
        """
        :param llm_client: Client object forwarded unchanged to ``BaseGenerator``
        :param num_of_questions: Value substituted into the prompt template as
            ``num_of_questions`` by ``build_prompt``
        """
        super().__init__(llm_client)
        self.num_of_questions = num_of_questions

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """
        Parse multiple choice QA pairs from the LLM response.

        Each ``<qa_pair>`` block must contain a ``<question>``, an ``<options>``
        section with exactly four options labelled A-D (one per line), and an
        ``<answer>`` naming one of those options. Blocks that fail any of
        these checks are logged and skipped.

        :param response: The LLM response containing XML-formatted QA pairs
        :return: List of parsed questions in response order; each item is a
            dict with "question" (str), "options" (dict mapping "A".."D" to
            the option text) and "answer" (the upper-case option letter).
            The list is empty when nothing could be parsed.
        """
        patterns = MultiChoiceGenerator
        qa_pairs = []

        # Extract all QA pair blocks
        qa_blocks = patterns._QA_BLOCK_RE.findall(response)
        if not qa_blocks:
            logger.warning("No QA pairs found in response: %s", response)
            return qa_pairs

        for block in qa_blocks:
            # Extract and clean question text
            q_match = patterns._QUESTION_RE.search(block)
            if not q_match:
                logger.warning("Failed to parse question from block: %s", block)
                continue
            question = q_match.group(1).strip().strip('"').strip("'")

            # Extract and parse options (A, B, C, D)
            opt_match = patterns._OPTIONS_RE.search(block)
            if not opt_match:
                logger.warning("Failed to parse options from block: %s", block)
                continue

            options = {}
            options_text = opt_match.group(1).strip()
            for line in options_text.split("\n"):
                line = line.strip()
                if not line:
                    continue
                if m := patterns._OPTION_LINE_RE.match(line):
                    letter, text = m.groups()
                    options[letter] = text.strip()

            # Validate options count
            if len(options) != 4:
                logger.warning(
                    "Expected 4 options, found %d: %s", len(options), options_text
                )
                continue

            # Extract and validate answer
            ans_match = patterns._ANSWER_RE.search(block)
            if not ans_match:
                logger.warning("Failed to parse answer from block: %s", block)
                continue
            raw_answer = ans_match.group(1).strip().strip('"').strip("'")

            # Reduce decorated answers ("a", "(B)", "C. text") to the bare
            # upper-case letter; anything unrecognised is kept verbatim so the
            # membership check below rejects it and logs what was received.
            label = patterns._ANSWER_LABEL_RE.match(raw_answer)
            answer = label.group(1).upper() if label else raw_answer

            # Ensure answer exists in options
            if answer not in options:
                logger.warning(
                    "Answer '%s' not found in options: %s",
                    raw_answer,
                    list(options.keys()),
                )
                continue

            qa_pairs.append(
                {
                    "question": question,
                    "options": options,  # Dict like {"A": "text", "B": "text", ...}
                    "answer": answer,  # Single letter: "A", "B", "C", or "D"
                }
            )
            logger.debug("Successfully parsed MCQ: %s", question[:50])

        if not qa_pairs:
            logger.error("Failed to parse any valid MCQ pairs from response")
        return qa_pairs

    # pylint: disable=W0221
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Render a batch of nodes and edges into the MCQ generation prompt.

        Nodes become numbered "<name>: <description>" lines and edges become
        numbered "<source> -- <target>: <description>" lines; both lists are
        numbered from 1. The template is selected by the language detected
        from the combined text.

        :param batch: Pair ``(nodes, edges)``; every node and edge carries a
            dict as its last element that must contain a "description" key
        :return: The formatted prompt string
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes, start=1)
        )
        relationships_str = "\n".join(
            f"{index}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges, start=1)
        )
        context = entities_str + "\n" + relationships_str
        # Language detection runs on the two sections joined without the
        # newline separator, exactly as the text is passed here.
        language = detect_main_language(entities_str + relationships_str)
        return MCQ_GENERATION_PROMPT[language].format(
            context=context,
            num_of_questions=self.num_of_questions,
        )
|