Spaces:
Running
Running
| import re | |
| from typing import Any | |
| from graphgen.bases import BaseGenerator | |
| from graphgen.templates import ATOMIC_GENERATION_PROMPT | |
| from graphgen.utils import compute_content_hash, detect_main_language, logger | |
class AtomicGenerator(BaseGenerator):
    """Generator that builds a prompt from a graph batch and parses one QA pair.

    Both methods are static: neither reads instance or class state, and
    neither declares a ``self`` parameter, so they must not be bound when
    called through an instance.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Build the atomic-generation prompt for one batch of nodes and edges.

        :param batch: a ``(nodes, edges)`` pair. Each node is indexed as
            ``(name, data)`` and each edge as ``(source, target, data)``;
            every ``data`` mapping must contain a ``"description"`` key.
        :return: the template selected by the detected language, formatted
            with the assembled context.
        :raises KeyError: if a ``data`` mapping has no ``"description"`` key,
            or if ``ATOMIC_GENERATION_PROMPT`` has no entry for the value
            returned by ``detect_main_language``.
        """
        nodes, edges = batch
        # One bullet line per node, followed by one per edge; every line ends
        # with a newline, so the context keeps a trailing "\n" when non-empty.
        node_lines = [f"- {node[0]}: {node[1]['description']}\n" for node in nodes]
        edge_lines = [
            f"- {edge[0]} - {edge[1]}: {edge[2]['description']}\n" for edge in edges
        ]
        context = "".join(node_lines + edge_lines)

        language = detect_main_language(context)
        return ATOMIC_GENERATION_PROMPT[language].format(context=context)

    @staticmethod
    def parse_response(response: str) -> dict:
        """Parse a single QA pair out of a model response.

        AtomicGenerator normally generates one QA pair per response, so only
        the first ``<question>...</question>`` and ``<answer>...</answer>``
        spans are extracted.

        :param response: raw response text expected to contain both tags.
        :return: ``{compute_content_hash(question): {"question": ...,
            "answer": ...}}`` on success, or an empty dict when either tag is
            missing (a warning is logged in that case).
        """
        # DOTALL lets the tag contents span multiple lines.
        question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        if not (question_match and answer_match):
            logger.warning("Failed to parse response: %s", response)
            return {}

        # Trim whitespace first, then surrounding double quotes, then single
        # quotes; the order is kept as-is because it affects mixed quoting.
        question = question_match.group(1).strip().strip('"').strip("'")
        answer = answer_match.group(1).strip().strip('"').strip("'")
        logger.debug("Question: %s", question)
        logger.debug("Answer: %s", answer)
        return {
            compute_content_hash(question): {
                "question": question,
                "answer": answer,
            }
        }