Spaces:
Runtime error
Runtime error
hari-huynh
commited on
Commit
·
be52c8f
1
Parent(s):
1b8f0b5
Update KG Search Tool
Browse files- tools/kg_search.py +11 -6
tools/kg_search.py
CHANGED
|
@@ -37,12 +37,12 @@ example_selector = SemanticSimilarityExampleSelector.from_examples(
|
|
| 37 |
# Load schema, prefix, suffix
|
| 38 |
with open("prompts/schema.txt", "r") as file:
|
| 39 |
schema = file.read()
|
| 40 |
-
|
| 41 |
with open("prompts/cypher_instruct.yaml", "r") as file:
|
| 42 |
instruct = yaml.safe_load(file)
|
| 43 |
|
| 44 |
example_prompt = PromptTemplate(
|
| 45 |
-
input_variables = ["
|
| 46 |
template = instruct["example_template"]
|
| 47 |
)
|
| 48 |
|
|
@@ -54,6 +54,7 @@ dynamic_prompt = FewShotPromptTemplate(
|
|
| 54 |
input_variables = ["question"]
|
| 55 |
)
|
| 56 |
|
|
|
|
| 57 |
def generate_cypher(question: str) -> str:
|
| 58 |
"""Make Cypher query from given question."""
|
| 59 |
load_dotenv()
|
|
@@ -69,12 +70,14 @@ def generate_cypher(question: str) -> str:
|
|
| 69 |
)
|
| 70 |
|
| 71 |
chat_messages = [
|
| 72 |
-
|
| 73 |
]
|
| 74 |
|
|
|
|
| 75 |
output_parser = StrOutputParser()
|
|
|
|
| 76 |
chain = dynamic_prompt | gemini_chat | output_parser
|
| 77 |
-
cypher_statement = chain.invoke(question)
|
| 78 |
cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
|
| 79 |
|
| 80 |
return cypher_statement
|
|
@@ -83,6 +86,7 @@ def run_cypher(question, cypher_statement: str) -> str:
|
|
| 83 |
"""Return result of Cypher query from Knowledge Graph."""
|
| 84 |
knowledge_graph = Neo4jGraph()
|
| 85 |
result = knowledge_graph.query(cypher_statement)
|
|
|
|
| 86 |
|
| 87 |
gemini_chat = ChatGoogleGenerativeAI(
|
| 88 |
model= "gemini-1.5-flash-latest"
|
|
@@ -114,11 +118,12 @@ def lookup_kg(question: str) -> str:
|
|
| 114 |
"""
|
| 115 |
cypher_statement = generate_cypher(question)
|
| 116 |
cypher_statement = cypher_statement.replace("cypher", "").replace("```", "").strip()
|
|
|
|
| 117 |
|
| 118 |
try:
|
| 119 |
answer = run_cypher(question, cypher_statement)
|
| 120 |
except:
|
| 121 |
-
answer = "Knowledge graph doesn't have enough information"
|
| 122 |
|
| 123 |
return answer
|
| 124 |
|
|
@@ -137,5 +142,5 @@ if __name__ == "__main__":
|
|
| 137 |
# print(final_result)
|
| 138 |
|
| 139 |
# Test lookup_kg tool
|
| 140 |
-
kg_info = lookup_kg
|
| 141 |
print(kg_info)
|
|
|
|
| 37 |
# Load schema, prefix, suffix
|
| 38 |
with open("prompts/schema.txt", "r") as file:
|
| 39 |
schema = file.read()
|
| 40 |
+
|
| 41 |
with open("prompts/cypher_instruct.yaml", "r") as file:
|
| 42 |
instruct = yaml.safe_load(file)
|
| 43 |
|
| 44 |
example_prompt = PromptTemplate(
|
| 45 |
+
input_variables = ["question_example", "cypher_example"],
|
| 46 |
template = instruct["example_template"]
|
| 47 |
)
|
| 48 |
|
|
|
|
| 54 |
input_variables = ["question"]
|
| 55 |
)
|
| 56 |
|
| 57 |
+
|
| 58 |
def generate_cypher(question: str) -> str:
|
| 59 |
"""Make Cypher query from given question."""
|
| 60 |
load_dotenv()
|
|
|
|
| 70 |
)
|
| 71 |
|
| 72 |
chat_messages = [
|
| 73 |
+
SystemMessage(content= dynamic_prompt.format(question=question)),
|
| 74 |
]
|
| 75 |
|
| 76 |
+
|
| 77 |
output_parser = StrOutputParser()
|
| 78 |
+
cypher_statement = []
|
| 79 |
chain = dynamic_prompt | gemini_chat | output_parser
|
| 80 |
+
cypher_statement = chain.invoke({"question": question})
|
| 81 |
cypher_statement = cypher_statement.replace("```", "").replace("cypher", "").strip()
|
| 82 |
|
| 83 |
return cypher_statement
|
|
|
|
| 86 |
"""Return result of Cypher query from Knowledge Graph."""
|
| 87 |
knowledge_graph = Neo4jGraph()
|
| 88 |
result = knowledge_graph.query(cypher_statement)
|
| 89 |
+
print(f"\nCypher Result:\n{result}")
|
| 90 |
|
| 91 |
gemini_chat = ChatGoogleGenerativeAI(
|
| 92 |
model= "gemini-1.5-flash-latest"
|
|
|
|
| 118 |
"""
|
| 119 |
cypher_statement = generate_cypher(question)
|
| 120 |
cypher_statement = cypher_statement.replace("cypher", "").replace("```", "").strip()
|
| 121 |
+
print(f"\nQuery:\n {cypher_statement}")
|
| 122 |
|
| 123 |
try:
|
| 124 |
answer = run_cypher(question, cypher_statement)
|
| 125 |
except:
|
| 126 |
+
answer = "Knowledge graph doesn't have enough information\n"
|
| 127 |
|
| 128 |
return answer
|
| 129 |
|
|
|
|
| 142 |
# print(final_result)
|
| 143 |
|
| 144 |
# Test lookup_kg tool
|
| 145 |
+
kg_info = lookup_kg(question)
|
| 146 |
print(kg_info)
|