# NOTE: Hugging Face Spaces page banner ("Spaces: Sleeping") captured by the
# scrape — replaced with this comment so the file is valid Python.
# app.py

# Standard library
import os
import subprocess
import sys

# Third-party
import matplotlib.pyplot as plt
import networkx as nx
import spacy
import streamlit as st
from neo4j import GraphDatabase
from sklearn.feature_extraction.text import TfidfVectorizer
# === Ensure spaCy model is installed ===
def install_spacy_model():
    """Download the spaCy English model on first run if it is missing.

    Tries to load ``en_core_web_sm``; on OSError (model not installed),
    shells out to ``spacy download`` and loads again.

    Fixes: the original invoked bare ``"python"``, which may resolve to a
    different interpreter than the one running this app (virtualenvs,
    python3-only systems), installing the model into the wrong environment.
    ``sys.executable`` targets the active interpreter, and ``check=True``
    surfaces a failed download instead of letting the second ``load()``
    raise a confusing OSError.
    """
    try:
        spacy.load("en_core_web_sm")
    except OSError:
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )
        spacy.load("en_core_web_sm")


install_spacy_model()
# Load the spaCy English pipeline once at module import; reused by
# extract_triples() below for sentence splitting, noun chunks and deps.
nlp = spacy.load("en_core_web_sm")
# === Neo4j credentials ===
# SECURITY NOTE(review): these credentials were hard-coded in source control,
# so the password must be considered leaked — rotate it. The values below are
# kept as backward-compatible defaults, but can now be overridden through the
# environment (prefer env vars or st.secrets for real deployments).
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j+s://ff701b1c.databases.neo4j.io")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.environ.get(
    "NEO4J_PASSWORD", "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"
)
def get_neo4j_driver():
    """Build a Neo4j driver from the module-level credentials.

    Returns the driver on success; on any failure, shows a Streamlit
    error message and returns None.
    """
    auth_pair = (NEO4J_USERNAME, NEO4J_PASSWORD)
    try:
        return GraphDatabase.driver(NEO4J_URI, auth=auth_pair)
    except Exception as exc:
        st.error(f"Failed to connect to Neo4j: {exc}")
        return None
# === TF-IDF Filtering (Optional for noise reduction) ===
def compute_tfidf_keywords(text: str, top_n=100):
    """Return the top_n terms of *text*, ranked by TF-IDF weight, as a set.

    The whole text is treated as a single document; English stop words
    are excluded by the vectorizer.
    """
    tfidf = TfidfVectorizer(stop_words='english')
    matrix = tfidf.fit_transform([text])
    weights = matrix.toarray()[0]
    terms = tfidf.get_feature_names_out()
    ranked = sorted(zip(terms, weights), key=lambda pair: pair[1], reverse=True)
    return {term for term, _ in ranked[:top_n]}
# === Triple Extraction ===
def extract_triples(text, use_tfidf=False):
    """Extract (subject, verb, object) triples from *text* with spaCy.

    Per sentence: the ROOT token's lemma is the predicate, the first noun
    chunk headed by nsubj/nsubjpass is the subject, and the first chunk
    headed by dobj/pobj/attr is the object. A triple is emitted only when
    all three parts were found. When use_tfidf is True and the keyword set
    is non-empty, the triple is kept only if its lowercased subject or
    object appears among the TF-IDF keywords.
    """
    doc = nlp(text)
    keywords = compute_tfidf_keywords(text) if use_tfidf else None
    results = []
    for sentence in doc.sents:
        roots = [tok for tok in sentence if tok.dep_ == "ROOT"]
        predicate = roots[0].lemma_ if roots else ""
        subj = ""
        obj_ = ""
        for chunk in sentence.noun_chunks:
            head_dep = chunk.root.dep_
            if not subj and head_dep in ("nsubj", "nsubjpass"):
                subj = chunk.text
            elif not obj_ and head_dep in ("dobj", "pobj", "attr"):
                obj_ = chunk.text
        if not (subj and predicate and obj_):
            continue
        # An empty keyword set falls through to "keep everything",
        # matching the original truthiness check on the set.
        if not keywords or subj.lower() in keywords or obj_.lower() in keywords:
            results.append((subj.strip(), predicate.strip(), obj_.strip()))
    return results
# === Visualization Function ===
def show_graph(triples):
    """Render (subject, predicate, object) triples as a directed graph in Streamlit."""
    if not triples:
        st.warning("No triples found to visualize.")
        return
    graph = nx.DiGraph()
    for subj, pred, obj in triples:
        # add_edge would create missing nodes itself; explicit add_node
        # calls mirror the original construction exactly.
        graph.add_node(subj)
        graph.add_node(obj)
        graph.add_edge(subj, obj, label=pred)
    layout = nx.spring_layout(graph, seed=42)  # seeded: deterministic layout
    plt.figure(figsize=(10, 8))
    nx.draw(
        graph,
        layout,
        with_labels=True,
        node_color='skyblue',
        node_size=2000,
        font_size=10,
        edge_color='gray',
    )
    edge_text = {(u, v): attrs['label'] for u, v, attrs in graph.edges(data=True)}
    nx.draw_networkx_edge_labels(graph, layout, edge_labels=edge_text)
    st.pyplot(plt.gcf())
    plt.clf()  # clear the figure so reruns don't accumulate artists
# === Streamlit UI ===
# NOTE(review): several UI strings ("π§", "π", "β") look like mojibake of
# emoji/arrow characters from a bad encoding round-trip — they are preserved
# byte-for-byte here; confirm the intended characters against the deployed app.
st.title("π§ Knowledge Graph Generator")
text_input = st.text_area("Paste your text here:", height=200)
use_tfidf = st.checkbox("Use TF-IDF filtering (Optional: Recommended for large texts)")

if st.button("Generate Graph"):
    if not text_input:
        st.warning("Please enter some text.")
    else:
        all_triples = extract_triples(text_input, use_tfidf=use_tfidf)
        if not all_triples:
            st.warning("No valid triples could be extracted. Try different text.")
        else:
            st.subheader("π Extracted Triples:")
            for subj, pred, obj in all_triples:
                st.markdown(f"- **({subj} β {pred} β {obj})**")
            show_graph(all_triples)