jaibadachiya committed on
Commit
e942dfb
·
verified ·
1 Parent(s): bdfed26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -5,6 +5,7 @@ import subprocess
5
  from neo4j import GraphDatabase
6
  import matplotlib.pyplot as plt
7
  import networkx as nx
 
8
 
9
  # === Ensure spaCy model is installed ===
10
  def install_spacy_model():
@@ -14,10 +15,10 @@ def install_spacy_model():
14
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
15
  install_spacy_model()
16
 
17
- # Load the model after ensuring it's installed
18
  nlp = spacy.load("en_core_web_sm")
19
 
20
- # Neo4j credentials
21
  uri = "neo4j+s://ff701b1c.databases.neo4j.io"
22
  username = "neo4j"
23
  password = "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"
@@ -25,9 +26,18 @@ password = "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"
25
  # Connect to Neo4j
26
  driver = GraphDatabase.driver(uri, auth=(username, password))
27
 
28
- # Triple extraction function
 
 
 
 
 
 
 
 
29
  def extract_triples(text):
30
  doc = nlp(text)
 
31
  triples = []
32
  for sent in doc.sents:
33
  subjects = [tok for tok in sent if "subj" in tok.dep_]
@@ -36,12 +46,15 @@ def extract_triples(text):
36
  for subj in subjects:
37
  for verb in verbs:
38
  for obj in objects:
39
- triples.append((subj.text, verb.lemma_, obj.text))
40
- if len(triples) == 10:
41
- return triples
 
 
 
42
  return triples
43
 
44
- # Visualization function
45
  def show_graph(triples):
46
  G = nx.DiGraph()
47
  for s, p, o in triples:
@@ -55,16 +68,16 @@ def show_graph(triples):
55
  st.pyplot(plt)
56
 
57
  # === Streamlit UI ===
58
- st.title("🧠 Knowledge Graph Generator")
59
 
60
  text_input = st.text_area("Paste your text here", height=200)
61
 
62
  if st.button("Generate Graph"):
63
  if text_input:
64
  triples = extract_triples(text_input)
65
- st.write("### Extracted Triples")
66
  for t in triples:
67
  st.write("🔗", t)
68
  show_graph(triples)
69
  else:
70
- st.warning("Please enter some text.")
 
5
  from neo4j import GraphDatabase
6
  import matplotlib.pyplot as plt
7
  import networkx as nx
8
+ from sklearn.feature_extraction.text import TfidfVectorizer
9
 
10
  # === Ensure spaCy model is installed ===
11
  def install_spacy_model():
 
15
  subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
16
  install_spacy_model()
17
 
18
+ # Load the spaCy model
19
  nlp = spacy.load("en_core_web_sm")
20
 
21
# === Neo4j credentials ===
# SECURITY(review): the literal password below is committed to source control
# and must be treated as compromised -- rotate it and delete the fallback.
# Each value can be overridden through an environment variable so the secret
# can be supplied by the deployment instead of living in the repository; the
# literals are kept only as fallbacks so existing deployments keep working.
import os

uri = os.environ.get("NEO4J_URI", "neo4j+s://ff701b1c.databases.neo4j.io")
username = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get(
    "NEO4J_PASSWORD", "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"
)
 
26
  # Connect to Neo4j
27
  driver = GraphDatabase.driver(uri, auth=(username, password))
28
 
29
+ # === TF-IDF Filtering ===
30
def compute_tfidf_keywords(text: str, top_n=60):
    """Return the highest-scoring TF-IDF terms of ``text`` as a set.

    The whole text is vectorized as a single document with scikit-learn's
    English stop-word list removed, and the ``top_n`` terms with the largest
    scores are kept.

    Args:
        text: Raw input text to analyse.
        top_n: Maximum number of terms to return.

    Returns:
        A set of term strings as produced by the vectorizer. The set is empty
        when the text yields no vocabulary (for example when it is blank or
        consists only of stop words).
    """
    vectorizer = TfidfVectorizer(stop_words='english')
    try:
        X = vectorizer.fit_transform([text])
    except ValueError:
        # scikit-learn raises ValueError ("empty vocabulary") when no token
        # survives tokenization and stop-word removal; treat that as "no
        # keywords" rather than crashing the caller.
        return set()
    # NOTE(review): with a single document every term gets the same IDF, so
    # this ranking is effectively by normalized term frequency.
    scores = zip(vectorizer.get_feature_names_out(), X.toarray()[0])
    sorted_scores = sorted(scores, key=lambda x: x[1], reverse=True)
    return {word for word, _ in sorted_scores[:top_n]}
36
+
37
+ # === Triple extraction with TF-IDF filtering and limit to 10 ===
38
  def extract_triples(text):
39
  doc = nlp(text)
40
+ tfidf_keywords = compute_tfidf_keywords(text)
41
  triples = []
42
  for sent in doc.sents:
43
  subjects = [tok for tok in sent if "subj" in tok.dep_]
 
46
  for subj in subjects:
47
  for verb in verbs:
48
  for obj in objects:
49
+ if (subj.text.lower() in tfidf_keywords or
50
+ verb.lemma_.lower() in tfidf_keywords or
51
+ obj.text.lower() in tfidf_keywords):
52
+ triples.append((subj.text, verb.lemma_, obj.text))
53
+ if len(triples) == 10:
54
+ return triples
55
  return triples
56
 
57
+ # === Visualization Function ===
58
  def show_graph(triples):
59
  G = nx.DiGraph()
60
  for s, p, o in triples:
 
68
  st.pyplot(plt)
69
 
70
  # === Streamlit UI ===
71
# Page header and the free-text input the graph is built from.
st.title("🧠 Knowledge Graph Generator with TF-IDF Filtering")

text_input = st.text_area("Paste your text here", height=200)

if st.button("Generate Graph"):
    # Whitespace-only input is treated like empty input: it contains nothing
    # to extract and would otherwise be passed on to the extraction step.
    if text_input and text_input.strip():
        triples = extract_triples(text_input)
        st.write("### Extracted Triples (Top 10 filtered by TF-IDF):")
        for t in triples:
            st.write("🔗", t)
        show_graph(triples)
    else:
        st.warning("Please enter some text.")