Spaces:
Running
Running
File size: 1,886 Bytes
a8c3e2a 799ac7c a8c3e2a 799ac7c a8c3e2a 0b9d8c7 799ac7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import re
from typing import Any
from graphgen.bases import BaseGenerator
from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger
class MultiHopGenerator(BaseGenerator):
    """Build prompts from ``MULTI_HOP_GENERATION_PROMPT`` and parse tagged replies.

    Both methods are static: ``build_prompt`` renders a batch of graph nodes
    and edges into a prompt string, and ``parse_response`` extracts the
    ``<question>`` / ``<answer>`` pair from a response string.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Render a ``(nodes, edges)`` batch into a generation prompt.

        Args:
            batch: A pair ``(nodes, edges)``. Each node is indexed as
                ``(name, data)`` and each edge as ``(source, target, data)``;
                every ``data`` mapping must contain a ``"description"`` key.

        Returns:
            The template selected from ``MULTI_HOP_GENERATION_PROMPT`` by the
            value ``detect_main_language`` returns, formatted with the
            numbered entity and relationship listings.

        Raises:
            KeyError: If a ``data`` mapping has no ``"description"`` entry, or
                if the detected language is not a key of the template mapping.
        """
        nodes, edges = batch

        # Listings are numbered from 1, one item per line.
        entities_str = "\n".join(
            f"{position}. {node[0]}: {node[1]['description']}"
            for position, node in enumerate(nodes, start=1)
        )
        relationships_str = "\n".join(
            f"{position}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for position, edge in enumerate(edges, start=1)
        )

        # The language is detected on the concatenation of both listings.
        language = detect_main_language(entities_str + relationships_str)
        template = MULTI_HOP_GENERATION_PROMPT[language]
        return template.format(
            entities=entities_str, relationships=relationships_str
        )

    @staticmethod
    def parse_response(response: str) -> dict:
        """Extract the question/answer pair from a tagged response.

        Args:
            response: Text expected to contain ``<question>...</question>``
                and ``<answer>...</answer>`` sections.

        Returns:
            ``{compute_content_hash(question): {"question": ..., "answer": ...}}``
            when both sections are found; otherwise an empty dict (a warning
            carrying the raw response is logged).
        """
        question_match = re.search(r"<question>(.*?)</question>", response, re.DOTALL)
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)

        # Guard clause: both tags are required for a usable result.
        if not (question_match and answer_match):
            logger.warning("Failed to parse response: %s", response)
            return {}

        # Trim surrounding whitespace first, then any wrapping double quotes,
        # then any wrapping single quotes.
        question = question_match.group(1).strip().strip('"').strip("'")
        answer = answer_match.group(1).strip().strip('"').strip("'")

        logger.debug("Question: %s", question)
        logger.debug("Answer: %s", answer)

        qa_pair = {"question": question, "answer": answer}
        return {compute_content_hash(question): qa_pair}
|