File size: 4,375 Bytes
5cb2e67
 
 
 
 
0bd1b0f
5cb2e67
 
 
 
 
 
 
 
0bd1b0f
5cb2e67
 
 
 
 
 
 
 
0bd1b0f
5cb2e67
 
 
 
 
 
0bd1b0f
5cb2e67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bd1b0f
 
 
 
 
 
 
5cb2e67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import re
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import MCQ_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger


class MultiChoiceGenerator(BaseGenerator):
    """Generate four-option multiple-choice questions through an LLM.

    ``build_prompt`` renders a batch of nodes and edges into the
    language-specific ``MCQ_GENERATION_PROMPT`` template, and
    ``parse_response`` converts the tagged LLM reply back into a list of
    question dictionaries.
    """

    # Patterns for the tagged reply format. DOTALL lets tag bodies span lines.
    _QA_BLOCK_RE = re.compile(r"<qa_pair>(.*?)</qa_pair>", re.DOTALL)
    _QUESTION_RE = re.compile(r"<question>(.*?)</question>", re.DOTALL)
    _OPTIONS_RE = re.compile(r"<options>(.*?)</options>", re.DOTALL)
    _ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    # One option per line: "A. text", "A text", "A) text", "A: text",
    # "(A) text", plus the full-width punctuation variants.
    _OPTION_RE = re.compile(r"^[(（]?([A-D])[.):\s．）：、]\s*(.*)$")

    # An answer is a bare letter, optionally wrapped/followed by punctuation
    # and the option text ("B", "b", "B.", "(B)", "B. Paris"). A space is
    # deliberately NOT accepted as delimiter so free text such as "A dog"
    # is not mistaken for option A.
    _ANSWER_LETTER_RE = re.compile(
        r"^[(（]?([A-Da-d])(?:[.):．）：、].*)?$", re.DOTALL
    )

    def __init__(self, llm_client, num_of_questions) -> None:
        """
        :param llm_client: Client object forwarded unchanged to ``BaseGenerator``
        :param num_of_questions: Value substituted for ``num_of_questions``
                                 in the prompt template
        """
        super().__init__(llm_client)
        self.num_of_questions = num_of_questions

    @staticmethod
    def _parse_options(options_text: str) -> dict[str, str]:
        """Map option letters to their text; later duplicates overwrite earlier ones."""
        options: dict[str, str] = {}
        for raw_line in options_text.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            if m := MultiChoiceGenerator._OPTION_RE.match(line):
                letter, text = m.groups()
                options[letter] = text.strip()
        return options

    @staticmethod
    def _normalize_answer(answer: str) -> str:
        """Reduce forms like "b", "B." or "(B) text" to "B"; return other input unchanged."""
        if m := MultiChoiceGenerator._ANSWER_LETTER_RE.match(answer):
            return m.group(1).upper()
        return answer

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """
        Parse multiple choice QA pairs from the LLM response.
        Each QA pair contains question text, four options, and the correct answer.

        Blocks that lack a question, do not yield exactly four options, lack
        an answer, or whose answer is not one of the parsed option letters
        are logged and skipped.

        :param response: The LLM response containing XML-formatted QA pairs
        :return: List of dicts, one per valid QA pair, each with the keys
                 "question" (str), "options" (dict mapping "A"-"D" to the
                 option text) and "answer" (one of the option letters).
                 Empty when nothing could be parsed.
        """
        qa_pairs = []

        # Extract all QA pair blocks
        qa_blocks = MultiChoiceGenerator._QA_BLOCK_RE.findall(response)

        if not qa_blocks:
            logger.warning("No QA pairs found in response: %s", response)
            return qa_pairs

        for block in qa_blocks:
            # Extract and clean question text
            q_match = MultiChoiceGenerator._QUESTION_RE.search(block)
            if not q_match:
                logger.warning("Failed to parse question from block: %s", block)
                continue
            question = q_match.group(1).strip().strip('"').strip("'")

            # Extract and parse options (A, B, C, D)
            opt_match = MultiChoiceGenerator._OPTIONS_RE.search(block)
            if not opt_match:
                logger.warning("Failed to parse options from block: %s", block)
                continue
            options_text = opt_match.group(1).strip()
            options = MultiChoiceGenerator._parse_options(options_text)

            # Validate options count
            if len(options) != 4:
                logger.warning(
                    "Expected 4 options, found %d: %s", len(options), options_text
                )
                continue

            # Extract and validate answer
            ans_match = MultiChoiceGenerator._ANSWER_RE.search(block)
            if not ans_match:
                logger.warning("Failed to parse answer from block: %s", block)
                continue
            raw_answer = ans_match.group(1).strip().strip('"').strip("'")
            # Tolerate "b", "B.", "(B)", "B. text" so an otherwise valid
            # question is not discarded over answer formatting.
            answer = MultiChoiceGenerator._normalize_answer(raw_answer)

            # Ensure answer exists in options
            if answer not in options:
                logger.warning(
                    "Answer '%s' not found in options: %s", answer, list(options.keys())
                )
                continue

            qa_pairs.append(
                {
                    "question": question,
                    "options": options,  # Dict like {"A": "text", "B": "text", ...}
                    "answer": answer,  # Single letter: "A", "B", "C", or "D"
                }
            )

            logger.debug("Successfully parsed MCQ: %s", question[:50])

        if not qa_pairs:
            logger.error("Failed to parse any valid MCQ pairs from response")

        return qa_pairs

    # pylint: disable=W0221
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the MCQ generation prompt for one batch.

        :param batch: ``(nodes, edges)``. Each node is ``(name, data)`` and each
                      edge is ``(source, target, data)``; ``data["description"]``
                      is read from every item, so a missing key raises ``KeyError``.
        :return: The template selected by the language detected in the
                 descriptions, filled with a numbered entity list, a numbered
                 relationship list and ``self.num_of_questions``.
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        )

        relationships_str = "\n".join(
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        )
        context = entities_str + "\n" + relationships_str
        # Language detection runs on the raw text without the joining newline,
        # exactly as before.
        language = detect_main_language(entities_str + relationships_str)
        prompt = MCQ_GENERATION_PROMPT[language].format(
            context=context,
            num_of_questions=self.num_of_questions,
        )
        return prompt