# app.py
"""Streamlit knowledge-graph generator.

Extracts (subject, verb, object) triples from pasted text with spaCy's
dependency parse, optionally filters them with TF-IDF keywords, and renders
the result as a directed graph via networkx/matplotlib.
"""

import subprocess
import sys

import matplotlib.pyplot as plt
import networkx as nx
import spacy
import streamlit as st
from neo4j import GraphDatabase
from sklearn.feature_extraction.text import TfidfVectorizer


# === Ensure spaCy model is installed ===
def install_spacy_model():
    """Download the small English spaCy model if it is not already installed.

    Raises:
        subprocess.CalledProcessError: if the download command fails.
    """
    try:
        spacy.load("en_core_web_sm")
    except OSError:
        # Use the current interpreter (sys.executable), not whatever "python"
        # happens to be on PATH, and fail loudly if the download fails
        # (check=True) instead of crashing later on an unexplained load error.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )


install_spacy_model()

# Load spaCy model once at module level; install_spacy_model() has already
# guaranteed it is available, so no second defensive load is needed.
nlp = spacy.load("en_core_web_sm")

# === Neo4j credentials ===
# SECURITY: credentials are hard-coded in source and this file appears to be
# committed/shared. Move the URI/username/password into st.secrets or
# environment variables and rotate this password.
NEO4J_URI = "neo4j+s://ff701b1c.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"


def get_neo4j_driver():
    """Return a Neo4j driver, or None (after showing a UI error) on failure."""
    try:
        return GraphDatabase.driver(
            NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )
    except Exception as e:
        st.error(f"Failed to connect to Neo4j: {e}")
        return None


# === TF-IDF Filtering (Optional for noise reduction) ===
def compute_tfidf_keywords(text: str, top_n=100):
    """Return the `top_n` highest-scoring TF-IDF terms of `text` as a set.

    Terms come from sklearn's tokenizer and are lowercase single words
    (English stop words removed).
    """
    vectorizer = TfidfVectorizer(stop_words='english')
    X = vectorizer.fit_transform([text])
    scores = zip(vectorizer.get_feature_names_out(), X.toarray()[0])
    sorted_scores = sorted(scores, key=lambda x: x[1], reverse=True)
    return {word for word, _ in sorted_scores[:top_n]}


def _chunk_matches_keywords(chunk_text, keywords):
    """True if any individual word of a (possibly multi-word) noun chunk
    appears in the keyword set.

    TF-IDF keywords are single tokens, while noun chunks are frequently
    multi-word phrases ("the quick fox"), so whole-phrase membership tests
    would almost never match — compare word by word instead.
    """
    return any(word in keywords for word in chunk_text.lower().split())


# === Triple Extraction ===
def extract_triples(text, use_tfidf=False):
    """Extract (subject, verb, object) triples from `text`.

    For each sentence, the ROOT token's lemma is taken as the verb, the
    first nsubj/nsubjpass noun chunk as the subject, and the first
    dobj/pobj/attr noun chunk as the object. When `use_tfidf` is True,
    a triple is kept only if its subject or object contains a top-TF-IDF
    keyword of the whole text.

    Returns:
        list[tuple[str, str, str]]: extracted (subject, verb, object) triples.
    """
    doc = nlp(text)
    tfidf_keywords = compute_tfidf_keywords(text) if use_tfidf else None

    triples = []
    for sent in doc.sents:
        subject = ""
        obj = ""
        verb = ""

        root = [token for token in sent if token.dep_ == "ROOT"]
        if root:
            verb = root[0].lemma_

        for chunk in sent.noun_chunks:
            if chunk.root.dep_ in ("nsubj", "nsubjpass") and not subject:
                subject = chunk.text
            elif chunk.root.dep_ in ("dobj", "pobj", "attr") and not obj:
                obj = chunk.text

        if subject and verb and obj:
            # `is None` (not truthiness): an empty keyword set means the user
            # asked for filtering and nothing qualified — drop the triple
            # rather than silently disabling the filter.
            if tfidf_keywords is None or \
                    _chunk_matches_keywords(subject, tfidf_keywords) or \
                    _chunk_matches_keywords(obj, tfidf_keywords):
                triples.append((subject.strip(), verb.strip(), obj.strip()))
    return triples


# === Visualization Function ===
def show_graph(triples):
    """Render the triples as a directed graph in the Streamlit app.

    Nodes are subjects/objects; each edge is labeled with its verb.
    Shows a warning and returns early if `triples` is empty.
    """
    if not triples:
        st.warning("No triples found to visualize.")
        return

    G = nx.DiGraph()
    for s, p, o in triples:
        G.add_node(s)
        G.add_node(o)
        G.add_edge(s, o, label=p)

    pos = nx.spring_layout(G, seed=42)  # fixed seed => stable layout across reruns
    plt.figure(figsize=(10, 8))
    nx.draw(G, pos, with_labels=True, node_color='skyblue',
            node_size=2000, font_size=10, edge_color='gray')
    nx.draw_networkx_edge_labels(
        G, pos,
        edge_labels={(u, v): d['label'] for u, v, d in G.edges(data=True)},
    )
    st.pyplot(plt.gcf())
    plt.clf()  # clear the figure so reruns don't stack drawings


# === Streamlit UI ===
st.title("🧠 Knowledge Graph Generator")

text_input = st.text_area("Paste your text here:", height=200)
use_tfidf = st.checkbox("Use TF-IDF filtering (Optional: Recommended for large texts)")

if st.button("Generate Graph"):
    if text_input:
        all_triples = extract_triples(text_input, use_tfidf=use_tfidf)
        if all_triples:
            st.subheader("🔗 Extracted Triples:")
            for triple in all_triples:
                st.markdown(f"- **({triple[0]} → {triple[1]} → {triple[2]})**")
            show_graph(all_triples)
        else:
            st.warning("No valid triples could be extracted. Try different text.")
    else:
        st.warning("Please enter some text.")