File size: 3,303 Bytes
caa2d9c
 
 
 
 
0bd1b0f
caa2d9c
 
 
 
 
 
 
 
0bd1b0f
caa2d9c
 
 
 
 
 
 
 
0bd1b0f
caa2d9c
 
 
 
 
 
0bd1b0f
caa2d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bd1b0f
 
 
 
 
 
caa2d9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import re
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import TF_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger


class TrueFalseGenerator(BaseGenerator):
    """
    Build true/false question prompts from a batch of graph nodes and edges,
    and parse the XML-tagged QA pairs out of the LLM response.
    """

    def __init__(self, llm_client, num_of_questions) -> None:
        """
        :param llm_client: Client object forwarded unchanged to the base generator
        :param num_of_questions: Value substituted for ``num_of_questions`` in the prompt template
        """
        super().__init__(llm_client)
        self.num_of_questions = num_of_questions

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """
        Parse true/false QA pairs from the LLM response.

        Each ``<qa_pair>...</qa_pair>`` block must contain a ``<question>`` and an
        ``<answer>`` element. Surrounding whitespace and quote characters are
        stripped from both. Blocks with a missing or empty question, a missing
        answer, or an answer other than true/false (case-insensitive) are
        skipped with a warning.

        :param response: The LLM response containing XML-formatted QA pairs
        :return: List of dicts, one per valid pair, each with a "question" key
                 (the statement text) and an "answer" key normalized to exactly
                 "True" or "False". Empty when nothing valid was found.
        """
        qa_pairs: list[dict[str, str]] = []

        # DOTALL lets a single pair span several lines of the response.
        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)

        if not qa_blocks:
            logger.warning("No QA pairs found in response: %s", response)
            return qa_pairs

        for block in qa_blocks:
            # Extract and clean question text
            q_match = re.search(r"<question>(.*?)</question>", block, re.DOTALL)
            if not q_match:
                logger.warning("Failed to parse question from block: %s", block)
                continue
            question = q_match.group(1).strip().strip('"').strip("'")
            if not question:
                # A statement with no text cannot be judged true or false.
                logger.warning("Empty question in block: %s", block)
                continue

            # Extract and validate answer
            ans_match = re.search(r"<answer>(.*?)</answer>", block, re.DOTALL)
            if not ans_match:
                logger.warning("Failed to parse answer from block: %s", block)
                continue
            answer = ans_match.group(1).strip().strip('"').strip("'")

            # Accept any casing from the model, but only the two boolean words.
            normalized = answer.lower()
            if normalized not in ("true", "false"):
                logger.warning("Invalid answer '%s' in block: %s", answer, block)
                continue

            qa_pairs.append(
                {
                    "question": question,
                    # Store the canonical spelling so consumers always see
                    # exactly "True" or "False", whatever casing the model used.
                    "answer": normalized.capitalize(),
                }
            )

            logger.debug("Successfully parsed TF question: %s", question[:50])

        if not qa_pairs:
            logger.error("Failed to parse any valid true/false pairs from response")

        return qa_pairs

    # pylint: disable=W0221
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Render the true/false generation prompt for one batch.

        :param batch: A ``(nodes, edges)`` pair. Each node is indexed as
                      ``(name, data)`` and each edge as ``(source, target, data)``;
                      every ``data`` dict must have a "description" key.
        :return: The prompt template for the detected language, filled with the
                 numbered entity/relationship context and the question count
        :raises KeyError: If a node or edge dict has no "description" key
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes, start=1)
        )

        relationships_str = "\n".join(
            f"{index}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges, start=1)
        )
        context = entities_str + "\n" + relationships_str
        # The detected language selects which prompt template is used.
        language = detect_main_language(entities_str + relationships_str)
        prompt = TF_GENERATION_PROMPT[language].format(
            context=context,
            num_of_questions=self.num_of_questions,
        )
        return prompt