Himanshu Gangwar commited on
Commit
1c40fe6
·
1 Parent(s): 51bbf4c

update app

Browse files
Files changed (1) hide show
  1. app.py +82 -29
app.py CHANGED
@@ -1,37 +1,35 @@
1
  import os
2
- import streamlit as st
3
  from neo4j import GraphDatabase
4
  from langchain_community.graphs import Neo4jGraph
5
  from dotenv import load_dotenv
6
 
7
- # Load environment variables from .env file for local development
8
  load_dotenv()
9
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class Neo4jConnection:
11
  """
12
- A class to manage the connection to a Neo4j database.
13
- It uses the credentials sourced from Streamlit secrets or a local .env file.
14
  """
15
- def __init__(self):
16
- # Prioritize Streamlit secrets, fall back to .env for local dev
17
- if hasattr(st, 'secrets') and "NEO4J_URI" in st.secrets:
18
- uri = st.secrets["NEO4J_URI"]
19
- user = st.secrets["NEO4J_USER"]
20
- password = st.secrets["NEO4J_PASSWORD"]
21
- print("Connecting to Neo4j using Streamlit secrets.")
22
- else:
23
- uri = os.getenv("NEO4J_URI")
24
- user = os.getenv("NEO4J_USER")
25
- password = os.getenv("NEO4J_PASSWORD")
26
- print("Connecting to Neo4j using local .env file.")
27
-
28
  self._driver = GraphDatabase.driver(uri, auth=(user, password))
29
  try:
30
- # Verify connection
31
  self._driver.verify_connectivity()
32
- print("Connected to Neo4j")
33
  except Exception as e:
34
- print(f"Neo4j connection failed: {e}")
35
 
36
  def close(self):
37
  if self._driver is not None:
@@ -45,23 +43,78 @@ class Neo4jConnection:
45
  try:
46
  session = self._driver.session(database=db) if db is not None else self._driver.session()
47
  response = list(session.run(query, parameters))
 
 
48
  except Exception as e:
49
- print("Query failed:", e)
50
  finally:
51
  if session is not None:
52
  session.close()
53
  return response
54
 
 
55
 
56
- graph = Neo4jGraph(
57
- url=st.secrets["NEO4J_URI"],
58
- username=st.secrets["NEO4J_USER"],
59
- password=st.secrets["NEO4J_PASSWORD"]
60
- )
61
 
62
- # Refresh schema information for the LangChain graph object
63
- # This helps the LLM generate more accurate Cypher queries
 
64
  try:
 
 
 
 
 
65
  graph.refresh_schema()
 
66
  except Exception as e:
67
- print(f"Warning: Could not refresh graph schema. The LLM might generate less accurate queries. Error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import gradio as gr
3
  from neo4j import GraphDatabase
4
  from langchain_community.graphs import Neo4jGraph
5
  from dotenv import load_dotenv
6
 
7
+ # 1. Load environment variables
8
  load_dotenv()
9
 
10
+ # Helper function to get env vars with a fallback error
11
+ def get_env_var(key):
12
+ value = os.getenv(key)
13
+ if not value:
14
+ raise ValueError(f"Environment variable '{key}' is missing. Please set it in your .env file.")
15
+ return value
16
+
17
+ NEO4J_URI = get_env_var("NEO4J_URI")
18
+ NEO4J_USER = get_env_var("NEO4J_USER")
19
+ NEO4J_PASSWORD = get_env_var("NEO4J_PASSWORD")
20
+
21
+ # 2. Define the Native Connection Class
22
  class Neo4jConnection:
23
  """
24
+ A class to manage the connection to a Neo4j database using the native driver.
 
25
  """
26
+ def __init__(self, uri, user, password):
 
 
 
 
 
 
 
 
 
 
 
 
27
  self._driver = GraphDatabase.driver(uri, auth=(user, password))
28
  try:
 
29
  self._driver.verify_connectivity()
30
+ print("Native Neo4j Driver: Connected successfully.")
31
  except Exception as e:
32
+ print(f"Native Neo4j Driver connection failed: {e}")
33
 
34
  def close(self):
35
  if self._driver is not None:
 
43
  try:
44
  session = self._driver.session(database=db) if db is not None else self._driver.session()
45
  response = list(session.run(query, parameters))
46
+ # Convert Neo4j records to simple dicts for easy reading in UI
47
+ response = [r.data() for r in response]
48
  except Exception as e:
49
+ return f"Query failed: {e}"
50
  finally:
51
  if session is not None:
52
  session.close()
53
  return response
54
 
55
+ # 3. Initialize Connections (Global State)
56
 
57
+ # Initialize Native Driver
58
+ conn = Neo4jConnection(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
 
 
 
59
 
60
+ # Initialize LangChain Graph
61
+ # Note: In a persistent web server (like Gradio), we usually init this once globally.
62
+ print("Initializing LangChain Graph...")
63
  try:
64
+ graph = Neo4jGraph(
65
+ url=NEO4J_URI,
66
+ username=NEO4J_USER,
67
+ password=NEO4J_PASSWORD
68
+ )
69
  graph.refresh_schema()
70
+ print("LangChain Graph Schema refreshed.")
71
  except Exception as e:
72
+ print(f"Warning: Could not refresh graph schema. Error: {e}")
73
+
74
+
75
+ # 4. Define Application Logic
76
+ def run_cypher(query_text):
77
+ """
78
+ Function triggered by the Gradio UI.
79
+ """
80
+ if not query_text:
81
+ return "Please enter a query."
82
+
83
+ results = conn.query(query_text)
84
+ return results
85
+
86
+ def get_schema_info():
87
+ """
88
+ Returns the LangChain graph schema to display in the UI.
89
+ """
90
+ try:
91
+ return graph.get_schema
92
+ except Exception as e:
93
+ return str(e)
94
+
95
+ # 5. Build Gradio Interface
96
+ with gr.Blocks(title="Neo4j Graph Explorer") as demo:
97
+ gr.Markdown("# Neo4j Connection Demo")
98
+
99
+ with gr.Tab("Execute Cypher"):
100
+ gr.Markdown("Enter a Cypher query to test your connection.")
101
+ cypher_input = gr.Code(language="cypher", label="Cypher Query", value="MATCH (n) RETURN n LIMIT 5")
102
+ run_btn = gr.Button("Run Query", variant="primary")
103
+ json_output = gr.JSON(label="Query Results")
104
+
105
+ run_btn.click(fn=run_cypher, inputs=cypher_input, outputs=json_output)
106
+
107
+ with gr.Tab("Graph Schema"):
108
+ gr.Markdown("View the current schema detected by LangChain.")
109
+ refresh_btn = gr.Button("View Schema")
110
+ schema_output = gr.Textbox(label="Schema Definition", lines=10)
111
+
112
+ refresh_btn.click(fn=get_schema_info, inputs=None, outputs=schema_output)
113
+
114
+ # 6. Launch
115
+ if __name__ == "__main__":
116
+ # Ensure the connection closes when the app stops (optional cleanup)
117
+ try:
118
+ demo.launch()
119
+ finally:
120
+ conn.close()