GraphGen / graphgen /models /generator /true_false_generator.py
github-actions[bot]
Auto-sync from demo at Thu Jan 29 12:51:48 UTC 2026
0bd1b0f
import re
from typing import Any
from graphgen.bases import BaseGenerator
from graphgen.templates import TF_GENERATION_PROMPT
from graphgen.utils import detect_main_language, logger
# Tag patterns for the LLM's XML-like output, compiled once. DOTALL lets a
# tag's contents span multiple lines; the lazy ``.*?`` stops at the first
# matching closing tag so consecutive pairs are not merged.
_QA_PAIR_RE = re.compile(r"<qa_pair>(.*?)</qa_pair>", re.DOTALL)
_QUESTION_RE = re.compile(r"<question>(.*?)</question>", re.DOTALL)
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

# Lower-cased accepted answers -> canonical spelling returned to callers.
_TF_ANSWERS = {"true": "True", "false": "False"}


def _clean_field(text: str) -> str:
    """Strip surrounding whitespace, then surrounding double, then single quotes."""
    return text.strip().strip('"').strip("'")


class TrueFalseGenerator(BaseGenerator):
    """Generate true/false QA pairs (a statement plus a True/False verdict).

    ``build_prompt`` renders a batch of graph nodes and edges into
    ``TF_GENERATION_PROMPT`` and ``parse_response`` turns the LLM's
    ``<qa_pair>`` output back into ``{"question", "answer"}`` dicts.
    """

    def __init__(self, llm_client, num_of_questions) -> None:
        """
        :param llm_client: Client forwarded to ``BaseGenerator.__init__``
        :param num_of_questions: Value substituted for ``{num_of_questions}``
            in the generation prompt
        """
        super().__init__(llm_client)
        self.num_of_questions = num_of_questions

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        """
        Parse true/false QA pairs from the LLM response.

        The response is expected to contain one or more blocks of the form
        ``<qa_pair><question>...</question><answer>...</answer></qa_pair>``.
        A block is skipped (with a warning) when its question is missing or
        empty, its answer is missing, or its answer is not "true"/"false"
        (case-insensitive, after stripping whitespace and quotes).

        :param response: The LLM response containing XML-formatted QA pairs
        :return: List of dicts, in response order, each with keys
            "question" (the statement text) and "answer" (normalized to
            exactly "True" or "False"); empty if nothing valid was parsed
        """
        qa_pairs: list[dict[str, str]] = []

        qa_blocks = _QA_PAIR_RE.findall(response)
        if not qa_blocks:
            logger.warning("No QA pairs found in response: %s", response)
            return qa_pairs

        for block in qa_blocks:
            # Extract and clean the statement text.
            q_match = _QUESTION_RE.search(block)
            if not q_match:
                logger.warning("Failed to parse question from block: %s", block)
                continue
            question = _clean_field(q_match.group(1))
            if not question:
                # An empty statement cannot be judged true or false.
                logger.warning("Empty question in block: %s", block)
                continue

            # Extract the verdict.
            ans_match = _ANSWER_RE.search(block)
            if not ans_match:
                logger.warning("Failed to parse answer from block: %s", block)
                continue
            answer = _clean_field(ans_match.group(1))

            # Only true/false verdicts are valid; normalize the casing so
            # consumers always see exactly "True" or "False".
            canonical = _TF_ANSWERS.get(answer.lower())
            if canonical is None:
                logger.warning("Invalid answer '%s' in block: %s", answer, block)
                continue

            qa_pairs.append(
                {
                    "question": question,
                    "answer": canonical,  # "True" or "False"
                }
            )
            logger.debug("Successfully parsed TF question: %s", question[:50])

        if not qa_pairs:
            logger.error("Failed to parse any valid true/false pairs from response")
        return qa_pairs

    # pylint: disable=W0221
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the true/false generation prompt for one subgraph batch.

        :param batch: ``(nodes, edges)``. Each node is indexed as
            ``(name, data)`` and each edge as ``(source, target, data)``;
            every ``data`` dict must contain a "description" key.
        :return: ``TF_GENERATION_PROMPT[language]`` formatted with a
            1-based numbered list of entities followed by a numbered list of
            relationships, plus ``self.num_of_questions``. ``language`` is
            whatever ``detect_main_language`` reports for the descriptions.
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes, start=1)
        )
        relationships_str = "\n".join(
            f"{index}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges, start=1)
        )
        context = entities_str + "\n" + relationships_str
        language = detect_main_language(entities_str + relationships_str)
        return TF_GENERATION_PROMPT[language].format(
            context=context,
            num_of_questions=self.num_of_questions,
        )