Spaces:
Running
Running
| import re | |
| from typing import Any | |
| from graphgen.bases import BaseGenerator | |
| from graphgen.templates import TF_GENERATION_PROMPT | |
| from graphgen.utils import detect_main_language, logger | |
class TrueFalseGenerator(BaseGenerator):
    """Generate true/false questions with an LLM and parse its reply.

    ``build_prompt`` renders the node and edge descriptions of a batch into
    the language-specific ``TF_GENERATION_PROMPT`` template, and
    ``parse_response`` turns the XML-tagged reply into question/answer dicts.
    """

    def __init__(self, llm_client, num_of_questions) -> None:
        """Store the generation settings.

        :param llm_client: Client forwarded unchanged to
            ``BaseGenerator.__init__``.
        :param num_of_questions: Value substituted for ``num_of_questions``
            when the prompt template is formatted.
        """
        super().__init__(llm_client)
        self.num_of_questions = num_of_questions

    # Declared static: the function takes no ``self``, so without the
    # decorator any call through an instance would raise ``TypeError``.
    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """
        Parse true/false QA pairs from the LLM response.

        Every ``<qa_pair>`` block must contain a ``<question>`` and an
        ``<answer>`` tag. Blocks with a missing tag, or with an answer that
        is not "true"/"false" (compared case-insensitively), are logged and
        skipped.

        :param response: The LLM response containing XML-formatted QA pairs
        :return: List of dicts, one per valid pair, each with a "question"
            key (the statement text) and an "answer" key holding the
            canonical string "True" or "False". Empty when nothing could
            be parsed.
        """
        qa_pairs: list[dict[str, str]] = []

        # Extract all QA pair blocks
        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)
        if not qa_blocks:
            logger.warning("No QA pairs found in response: %s", response)
            return qa_pairs

        for block in qa_blocks:
            # Extract and clean question text (surrounding quotes removed)
            q_match = re.search(r"<question>(.*?)</question>", block, re.DOTALL)
            if not q_match:
                logger.warning("Failed to parse question from block: %s", block)
                continue
            question = q_match.group(1).strip().strip('"').strip("'")

            # Extract and clean answer text
            ans_match = re.search(r"<answer>(.*?)</answer>", block, re.DOTALL)
            if not ans_match:
                logger.warning("Failed to parse answer from block: %s", block)
                continue
            answer = ans_match.group(1).strip().strip('"').strip("'")

            # Only a boolean literal is a valid answer; casing is ignored
            # for validation and normalised before storing.
            normalized = answer.lower()
            if normalized not in ("true", "false"):
                logger.warning("Invalid answer '%s' in block: %s", answer, block)
                continue

            qa_pairs.append(
                {
                    "question": question,
                    "answer": normalized.capitalize(),  # "True" or "False"
                }
            )
            logger.debug("Successfully parsed TF question: %s", question[:50])

        if not qa_pairs:
            logger.error("Failed to parse any valid true/false pairs from response")
        return qa_pairs

    # pylint: disable=W0221
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Build the true/false generation prompt for one batch.

        :param batch: ``(nodes, edges)``. Each node is a tuple whose first
            item is rendered as its label and whose second item is a dict
            with a "description" key. Each edge is a tuple whose first two
            items are rendered as its endpoints and whose third item is a
            dict with a "description" key.
        :return: ``TF_GENERATION_PROMPT`` entry for the detected language,
            formatted with the numbered context and ``num_of_questions``.
        """
        nodes, edges = batch

        # One numbered line per entity, then one per relationship.
        entities_str = "\n".join(
            f"{number}. {node[0]}: {node[1]['description']}"
            for number, node in enumerate(nodes, start=1)
        )
        relationships_str = "\n".join(
            f"{number}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for number, edge in enumerate(edges, start=1)
        )
        context = entities_str + "\n" + relationships_str

        # The detected language is the key used to select the template.
        language = detect_main_language(entities_str + relationships_str)
        return TF_GENERATION_PROMPT[language].format(
            context=context,
            num_of_questions=self.num_of_questions,
        )