jaibadachiya committed on
Commit
60413b9
·
verified ·
1 Parent(s): 743d3f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -48
app.py CHANGED
@@ -1,12 +1,17 @@
1
- # app.py
2
  import streamlit as st
3
  import spacy
4
  import subprocess
5
  from neo4j import GraphDatabase
6
  import matplotlib.pyplot as plt
7
  import networkx as nx
 
 
 
8
 
9
- # === Ensure spaCy model is installed ===
 
 
 
10
  def install_spacy_model():
11
  try:
12
  spacy.load("en_core_web_sm")
@@ -14,55 +19,118 @@ def install_spacy_model():
14
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
15
  install_spacy_model()
16
 
17
- # Load the model after ensuring it's installed
18
  nlp = spacy.load("en_core_web_sm")
19
 
20
  # Neo4j credentials
21
  uri = "neo4j+s://ff701b1c.databases.neo4j.io"
22
  username = "neo4j"
23
- password = "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"
24
-
25
- # Connect to Neo4j
26
- driver = GraphDatabase.driver(uri, auth=(username, password))
27
-
28
- # Triple extraction function
29
- def extract_triples(text):
30
- doc = nlp(text)
31
- triples = []
32
- for sent in doc.sents:
33
- subjects = [tok for tok in sent if "subj" in tok.dep_]
34
- verbs = [tok for tok in sent if tok.pos_ == "VERB"]
35
- objects = [tok for tok in sent if "obj" in tok.dep_]
36
- for subj in subjects:
37
- for verb in verbs:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  for obj in objects:
39
- triples.append((subj.text, verb.lemma_, obj.text))
40
- return triples
41
-
42
- # Visualization function
43
- def show_graph(triples):
44
- G = nx.DiGraph()
45
- for s, p, o in triples:
46
- G.add_node(s)
47
- G.add_node(o)
48
- G.add_edge(s, o, label=p)
49
- pos = nx.spring_layout(G)
50
- plt.figure(figsize=(10, 8))
51
- nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=10)
52
- nx.draw_networkx_edge_labels(G, pos, edge_labels={(u,v):d['label'] for u,v,d in G.edges(data=True)})
53
- st.pyplot(plt)
54
-
55
- # === Streamlit UI ===
56
- st.title("🧠 Knowledge Graph Generator")
57
-
58
- text_input = st.text_area("Paste your text here", height=200)
59
-
60
- if st.button("Generate Graph"):
61
- if text_input:
62
- triples = extract_triples(text_input)
63
- st.write("### Extracted Triples")
64
- for t in triples:
65
- st.write("πŸ”—", t)
66
- show_graph(triples)
67
- else:
68
- st.warning("Please enter some text.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
import os
import re
import subprocess

import matplotlib.pyplot as plt
import networkx as nx
import spacy
import streamlit as st
from neo4j import GraphDatabase
from sklearn.feature_extraction.text import TfidfVectorizer
10
 
11
+ # Set up logging configuration
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+ # Ensure spaCy model is installed
15
  def install_spacy_model():
16
  try:
17
  spacy.load("en_core_web_sm")
 
19
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
20
  install_spacy_model()
21
 
22
+ # Load the spaCy model after ensuring it's installed
23
  nlp = spacy.load("en_core_web_sm")
24
 
25
# Neo4j credentials.
# The password must never live in version control, so it is read from the
# environment (on a hosted Streamlit/Spaces deployment, configure it as a
# secret that is exposed as the NEO4J_PASSWORD environment variable).
# NOTE(review): the password that was previously committed here is public and
# should be rotated in the Neo4j console.
uri = os.environ.get("NEO4J_URI", "neo4j+s://ff701b1c.databases.neo4j.io")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD")

if not password:
    # Fail early with a readable message instead of a driver auth error later.
    st.error("NEO4J_PASSWORD is not set; cannot connect to Neo4j.")
    st.stop()
29
+
30
+ driver = None
31
+ try:
32
+ driver = GraphDatabase.driver(uri, auth=(username, password))
33
+ logging.info("βœ… Connected to Neo4j!")
34
+
35
def close_driver():
    """Close the module-level Neo4j ``driver`` if one was created.

    Does nothing when ``driver`` is falsy (for example when it is still
    ``None`` because the connection was never established).
    """
    if not driver:
        return
    driver.close()
    logging.info("🔒 Neo4j driver closed.")
39
+
40
def create_entity(tx, name: str):
    """Ensure an ``Entity`` node with the given ``name`` exists.

    Args:
        tx: Transaction-like object exposing ``run(query, **parameters)``.
        name: Value bound to the ``$name`` query parameter.
    """
    # MERGE makes the write idempotent: an existing node is reused.
    cypher = "MERGE (e:Entity {name: $name})"
    tx.run(cypher, name=name)
42
+
43
def create_relationship(tx, subj: str, pred: str, obj: str):
    """Ensure two ``Entity`` nodes and a ``RELATION`` edge between them exist.

    Args:
        tx: Transaction-like object exposing ``run(query, **parameters)``.
        subj: Name of the source entity (bound to ``$subj``).
        pred: Predicate stored as the ``name`` property of the relationship
            (bound to ``$pred``).
        obj: Name of the target entity (bound to ``$obj``).
    """
    # Every clause is a MERGE, so re-running with the same triple is a no-op.
    cypher = """
        MERGE (a:Entity {name: $subj})
        MERGE (b:Entity {name: $obj})
        MERGE (a)-[:RELATION {name: $pred}]->(b)
    """
    params = {"subj": subj, "pred": pred, "obj": obj}
    tx.run(cypher, **params)
49
+
50
+ # Text Processing
51
def load_and_clean_text(file_path: str) -> str:
    """Read a UTF-8 text file and return it as one normalized line.

    Newline runs and then any remaining whitespace runs are collapsed to a
    single space; the result is stripped and lower-cased.

    Args:
        file_path: Path of the file to read.

    Returns:
        The cleaned, lower-cased text.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        cleaned = handle.read()

    # Apply the same two substitutions, in the same order, as a small pipeline.
    for pattern in (r'\n+', r'\s+'):
        cleaned = re.sub(pattern, ' ', cleaned)

    return cleaned.strip().lower()
56
+
57
+ # TF-IDF Filtering
58
def compute_tfidf_keywords(text: str, top_n: int = 60) -> set:
    """Return the ``top_n`` highest-scoring TF-IDF terms of ``text``.

    The text is vectorized as a single document with English stop words
    removed. Terms are ranked by score in descending order; the sort is
    stable, so equal scores keep the vectorizer's feature order.

    Args:
        text: Raw input text.
        top_n: Maximum number of keywords to keep.

    Returns:
        A set of keyword strings. The set is empty when ``text`` is blank or
        yields no vocabulary (for example, it contains only stop words).
    """
    if not text or not text.strip():
        return set()

    vectorizer = TfidfVectorizer(stop_words='english')
    try:
        matrix = vectorizer.fit_transform([text])
    except ValueError:
        # scikit-learn raises ValueError ("empty vocabulary") when nothing is
        # left after tokenization / stop-word removal; treat that as "no
        # keywords" rather than aborting the whole request.
        return set()

    scores = zip(vectorizer.get_feature_names_out(), matrix.toarray()[0])
    ranked = sorted(scores, key=lambda pair: pair[1], reverse=True)
    return {word for word, _ in ranked[:top_n]}
64
+
65
+ # Triple Extraction
66
def get_full_phrase(token) -> str:
    """Return the text of ``token``'s subtree as one space-separated phrase.

    Tokens whose dependency label is ``punct`` are left out, and the joined
    phrase is stripped of leading/trailing whitespace.
    """
    words = []
    for node in token.subtree:
        if node.dep_ == 'punct':
            continue
        words.append(node.text)
    return ' '.join(words).strip()
68
+
69
def _phrase_has_keyword(phrase: str, keywords) -> bool:
    """Return True when ``phrase``, or any word inside it, is in ``keywords``."""
    lowered = phrase.lower()
    if lowered in keywords:
        return True
    # Keywords are single words, while phrases are usually multi-word, so the
    # phrase is split with the same word shape as TfidfVectorizer's default
    # token pattern (two or more word characters) before matching.
    return any(word in keywords for word in re.findall(r"\b\w\w+\b", lowered))


def extract_rich_triples(doc, tfidf_keywords) -> list:
    """Extract ``(subject phrase, verb lemma, object phrase)`` triples.

    For every sentence, each subject/object pair is combined with every verb
    of that sentence. A pair is kept only when its subject phrase or its
    object phrase mentions at least one of ``tfidf_keywords``.

    Args:
        doc: Parsed document exposing ``sents``; tokens must provide
            ``dep_``, ``pos_``, ``lemma_`` and ``subtree``.
        tfidf_keywords: Collection of lower-case keywords used as a filter.

    Returns:
        A list of ``(subject, predicate, object)`` string tuples, ordered by
        sentence, then subject, then object, then verb.
    """
    triples = []
    for sent in doc.sents:
        verb_lemmas = [tok.lemma_ for tok in sent if tok.pos_ == "VERB"]
        if not verb_lemmas:
            continue

        # Phrases and keyword hits do not depend on the inner loops, so they
        # are computed once per token instead of once per combination.
        subjects = [get_full_phrase(tok) for tok in sent if "subj" in tok.dep_]
        objects = [get_full_phrase(tok) for tok in sent if "obj" in tok.dep_]
        object_hits = [_phrase_has_keyword(o, tfidf_keywords) for o in objects]

        for s in subjects:
            subject_hit = _phrase_has_keyword(s, tfidf_keywords)
            for o, object_hit in zip(objects, object_hits):
                if subject_hit or object_hit:
                    triples.extend((s, lemma, o) for lemma in verb_lemmas)
    return triples
83
+
84
+ # Graph Visualization
85
def visualize_knowledge_graph(triples: list, title: str = "Knowledge Graph"):
    """Draw the triples as a directed graph and render it in the Streamlit page.

    Each triple becomes an edge from subject to object, labelled with the
    predicate. Nodes are coloured by the role they were last added with
    (sky blue for subjects, light green for objects); a node that appears in
    both roles takes the colour of its most recent role.

    Args:
        triples: Iterable of ``(subject, predicate, object)`` tuples.
        title: Title drawn above the graph.
    """
    G = nx.DiGraph()
    for subj, pred, obj in triples:
        G.add_node(subj, label='Subject')
        G.add_node(obj, label='Object')
        G.add_edge(subj, obj, label=pred)

    # Fixed seed keeps the layout stable across Streamlit reruns.
    pos = nx.spring_layout(G, k=1.2, seed=42)
    node_colors = ['skyblue' if G.nodes[n]['label'] == 'Subject' else 'lightgreen' for n in G.nodes]

    fig, ax = plt.subplots(figsize=(16, 16))
    nx.draw(G, pos, ax=ax, with_labels=True, node_size=1200, node_color=node_colors,
            font_size=10, font_weight='bold', edge_color='gray', alpha=0.8)
    nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['label'] for u, v, d in G.edges(data=True)},
                                 font_size=8, font_color='red', ax=ax)
    ax.set_title(title, fontsize=20)

    # plt.show() displays nothing inside a Streamlit app; the figure has to be
    # handed to Streamlit explicitly, then closed so reruns do not pile up
    # open figures.
    st.pyplot(fig)
    plt.close(fig)
102
+
103
# Streamlit UI: read text, extract keyword-filtered triples, store them in
# Neo4j, then draw the resulting graph.
st.title("🧠 Knowledge Graph Generator")

text_input = st.text_area("Paste your text here", height=200)

if st.button("Generate Graph"):
    # Whitespace-only input is treated like empty input: it has nothing to
    # tokenize and would only fail further down the pipeline.
    if text_input and text_input.strip():
        triples = extract_rich_triples(nlp(text_input), compute_tfidf_keywords(text_input))
        logging.info("🧠 Extracted %d filtered triples.", len(triples))

        # Only a short preview is listed; the full set is stored and drawn.
        for t in triples[:10]:
            st.write("🔗", t)

        if triples:
            # Push to Neo4j
            with driver.session() as session:
                for subj, pred, obj in triples:
                    session.execute_write(create_entity, subj)
                    session.execute_write(create_entity, obj)
                    session.execute_write(create_relationship, subj, pred, obj)

            logging.info("📡 Triples successfully stored in Neo4j.")
            st.success("Triples successfully stored in Neo4j.")

            # Final Visualization
            visualize_knowledge_graph(triples, title="Filtered Knowledge Graph (TF-IDF)")
        else:
            # Do not report a successful store or draw an empty graph.
            st.info("No triples could be extracted from this text.")
    else:
        st.warning("Please enter some text.")
129
+
130
+ except Exception as e:
131
+ logging.error(f"❌ An error occurred: {e}", exc_info=True)
132
+ st.error("An error occurred. Please check the logs.")
133
+
134
+ finally:
135
+ if driver:
136
+ close_driver()