import re
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import TF_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger


class TrueFalseGenerator(BaseGenerator):
    """Generate true/false (statement verification) QA pairs from graph context.

    Builds a prompt from entity/relationship descriptions and parses the
    LLM's XML-tagged response into question/answer dicts.
    """

    def __init__(self, llm_client, num_of_questions) -> None:
        """
        :param llm_client: LLM client passed through to BaseGenerator
        :param num_of_questions: number of questions to request per prompt
        """
        super().__init__(llm_client)
        self.num_of_questions = num_of_questions

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """
        Parse true/false QA pairs from the LLM response.

        Each QA pair contains a statement question and a "True"/"False" answer.

        :param response: The LLM response containing XML-formatted QA pairs,
            i.e. ``<qa_pair><question>...</question><answer>...</answer></qa_pair>``
            blocks (assumed output format of TF_GENERATION_PROMPT — confirm
            against the template).
        :return: List of dicts, each with "question" and "answer" keys; the
            answer is the model's original "True"/"False" text. Blocks that
            fail to parse or carry an invalid answer are skipped with a warning.
        """
        qa_pairs: list[dict[str, str]] = []

        # Extract all QA pair blocks.
        # NOTE: the original patterns had lost their tag delimiters
        # (r"(.*?)"), which can only ever match empty strings; the tags
        # below are restored from the surrounding comments/log messages.
        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
        if not qa_blocks:
            logger.warning("No QA pairs found in response: %s", response)
            return qa_pairs

        for block in qa_blocks:
            # Extract and clean question text
            q_match = re.search(r"<question>(.*?)</question>", block, re.DOTALL)
            if not q_match:
                logger.warning("Failed to parse question from block: %s", block)
                continue
            question = q_match.group(1).strip().strip('"').strip("'")

            # Extract and validate answer
            ans_match = re.search(r"<answer>(.*?)</answer>", block, re.DOTALL)
            if not ans_match:
                logger.warning("Failed to parse answer from block: %s", block)
                continue
            answer = ans_match.group(1).strip().strip('"').strip("'")

            # Only accept case-insensitive "true"/"false"; keep original casing.
            if answer.lower() not in ["true", "false"]:
                logger.warning("Invalid answer '%s' in block: %s", answer, block)
                continue

            qa_pairs.append(
                {
                    "question": question,
                    "answer": answer,  # "True" or "False"
                }
            )
            logger.debug("Successfully parsed TF question: %s", question[:50])

        if not qa_pairs:
            logger.error("Failed to parse any valid true/false pairs from response")

        return qa_pairs

    # pylint: disable=W0221
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the TF generation prompt from a batch of graph nodes and edges.

        :param batch: (nodes, edges) where each node is (name, data_dict) and
            each edge is (src, dst, data_dict); data dicts must carry a
            "description" key.
        :return: Prompt string, localized by the detected main language of
            the combined descriptions.
        """
        nodes, edges = batch

        # Number entities and relationships so the template can reference them.
        entities_str = "\n".join(
            [
                f"{index + 1}. {node[0]}: {node[1]['description']}"
                for index, node in enumerate(nodes)
            ]
        )
        relationships_str = "\n".join(
            [
                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
                for index, edge in enumerate(edges)
            ]
        )

        context = entities_str + "\n" + relationships_str
        # Pick the prompt variant matching the dominant language of the context.
        language = detect_main_language(entities_str + relationships_str)
        prompt = TF_GENERATION_PROMPT[language].format(
            context=context,
            num_of_questions=self.num_of_questions,
        )
        return prompt