# knowledge_graph/app.py — Streamlit knowledge-graph generator
# (Hugging Face Space; last upload note: "Update app.py", commit c641539, verified)
# app.py
import os
import subprocess
import sys

import matplotlib.pyplot as plt
import networkx as nx
import spacy
import streamlit as st
from neo4j import GraphDatabase
from sklearn.feature_extraction.text import TfidfVectorizer
# === Ensure spaCy model is installed ===
def install_spacy_model(model: str = "en_core_web_sm") -> None:
    """Download the spaCy *model* if it is not already installed.

    Uses ``sys.executable`` so the download runs in the same interpreter /
    virtualenv as this app — a bare ``"python"`` may resolve to a different
    installation on the host. ``check=True`` makes a failed download raise
    instead of being silently ignored (the original swallowed the failure
    and then crashed later on the second ``spacy.load``).
    """
    try:
        spacy.load(model)
    except OSError:
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", model],
            check=True,
        )
        # Verify the freshly downloaded model actually loads.
        spacy.load(model)


install_spacy_model()

# Load spaCy model once at import time; reused by extract_triples() below.
nlp = spacy.load("en_core_web_sm")
# === Neo4j credentials ===
# NOTE(review): hard-coded credentials in source are a security risk. They are
# now read from the environment first (or Streamlit secrets mapped to env
# vars); the literals remain only as backward-compatible fallbacks for
# existing deployments and should be rotated and removed.
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://ff701b1c.databases.neo4j.io")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get(
    "NEO4J_PASSWORD", "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"
)
def get_neo4j_driver():
    """Create a Neo4j driver, or return None (showing a Streamlit error).

    ``GraphDatabase.driver()`` is lazy and does not contact the server, so
    the original try/except could never catch a bad URI or bad credentials
    here; ``verify_connectivity()`` forces a round-trip so connection
    failures surface immediately instead of on the first query.
    """
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
        driver.verify_connectivity()
        return driver
    except Exception as e:
        st.error(f"Failed to connect to Neo4j: {e}")
        return None
# === TF-IDF Filtering (Optional for noise reduction) ===
def compute_tfidf_keywords(text: str, top_n=100):
    """Return the ``top_n`` highest-weighted TF-IDF terms of *text* as a set.

    English stop words are excluded; features are the vectorizer's
    (lowercased, single-token) vocabulary.
    """
    vectorizer = TfidfVectorizer(stop_words='english')
    matrix = vectorizer.fit_transform([text])
    terms = vectorizer.get_feature_names_out()
    weights = matrix.toarray()[0]
    ranked = sorted(zip(terms, weights), key=lambda pair: pair[1], reverse=True)
    return {term for term, _ in ranked[:top_n]}
# === Triple Extraction ===
def extract_triples(text, use_tfidf=False):
    """Extract (subject, verb, object) triples from *text*.

    Per sentence: the ROOT token's lemma is the predicate, the first
    ``nsubj``/``nsubjpass`` noun chunk the subject, and the first
    ``dobj``/``pobj``/``attr`` noun chunk the object. A triple is emitted
    only when all three parts were found.

    When ``use_tfidf`` is True, a triple is kept only if its subject or
    object contains at least one top TF-IDF keyword. TF-IDF features are
    lowercased single tokens, so the match is done word-by-word — the
    original whole-chunk comparison (``subject.lower() in keywords``)
    silently rejected almost every multi-word noun chunk.
    """
    doc = nlp(text)
    tfidf_keywords = compute_tfidf_keywords(text) if use_tfidf else None

    def _contains_keyword(phrase: str) -> bool:
        # Per-word membership: keywords are unigrams from TfidfVectorizer.
        return any(word in tfidf_keywords for word in phrase.lower().split())

    triples = []
    for sent in doc.sents:
        subject = ""
        obj = ""
        verb = ""
        root = [token for token in sent if token.dep_ == "ROOT"]
        if root:
            verb = root[0].lemma_
        for chunk in sent.noun_chunks:
            if chunk.root.dep_ in ("nsubj", "nsubjpass") and not subject:
                subject = chunk.text
            elif chunk.root.dep_ in ("dobj", "pobj", "attr") and not obj:
                obj = chunk.text
        if subject and verb and obj:
            # Empty/None keyword set disables filtering (matches the original
            # `if tfidf_keywords:` truthiness behavior).
            if not tfidf_keywords or _contains_keyword(subject) or _contains_keyword(obj):
                triples.append((subject.strip(), verb.strip(), obj.strip()))
    return triples
# === Visualization Function ===
def show_graph(triples):
    """Render (subject, predicate, object) triples as a directed graph.

    Nodes are subjects/objects; each edge carries its predicate as a label.
    Shows a Streamlit warning and returns early when there is nothing to draw.
    """
    if not triples:
        st.warning("No triples found to visualize.")
        return
    G = nx.DiGraph()
    for s, p, o in triples:
        # add_edge creates missing endpoint nodes implicitly.
        G.add_edge(s, o, label=p)
    pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout
    fig = plt.figure(figsize=(10, 8))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=10, edge_color='gray')
    nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['label'] for u, v, d in G.edges(data=True)})
    st.pyplot(fig)
    # Close (not just clf) the figure: plt.clf() leaves it registered with
    # pyplot, so repeated Streamlit reruns would accumulate open figures.
    plt.close(fig)
# === Streamlit UI ===
st.title("🧠 Knowledge Graph Generator")

text_input = st.text_area("Paste your text here:", height=200)
use_tfidf = st.checkbox("Use TF-IDF filtering (Optional: Recommended for large texts)")

if st.button("Generate Graph"):
    # Guard clauses: complain about missing input / empty extraction first,
    # render the triples and graph only on the happy path.
    if not text_input:
        st.warning("Please enter some text.")
    else:
        all_triples = extract_triples(text_input, use_tfidf=use_tfidf)
        if not all_triples:
            st.warning("No valid triples could be extracted. Try different text.")
        else:
            st.subheader("πŸ”— Extracted Triples:")
            for subj, pred, obj in all_triples:
                st.markdown(f"- **({subj} β†’ {pred} β†’ {obj})**")
            show_graph(all_triples)