File size: 3,740 Bytes
bdfed26
c641539
7fc01f2
 
743d3f5
7fc01f2
 
 
e942dfb
7fc01f2
bdfed26
743d3f5
 
 
 
 
c641539
 
743d3f5
 
0de17ce
743d3f5
 
e942dfb
c641539
 
 
a5dfac5
c641539
 
 
 
 
 
 
a5dfac5
c641539
0de17ce
e942dfb
 
 
 
 
 
c641539
 
a5dfac5
c641539
bdfed26
0de17ce
bdfed26
0de17ce
 
 
 
 
 
 
 
 
 
c641539
0de17ce
c641539
0de17ce
 
 
c641539
 
 
 
 
0de17ce
bdfed26
 
e942dfb
bdfed26
c641539
 
 
 
bdfed26
 
 
 
 
c641539
 
bdfed26
c641539
0de17ce
c641539
 
bdfed26
 
c641539
bdfed26
c641539
 
bdfed26
 
 
c641539
0de17ce
c641539
 
 
 
0de17ce
c641539
 
 
bdfed26
e942dfb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# app.py

import subprocess
import sys

import matplotlib.pyplot as plt
import networkx as nx
import spacy
import streamlit as st
from neo4j import GraphDatabase
from sklearn.feature_extraction.text import TfidfVectorizer

# === Ensure spaCy model is installed ===
def install_spacy_model():
    """Ensure the spaCy 'en_core_web_sm' model is installed, downloading it if missing.

    Raises:
        subprocess.CalledProcessError: if the model download fails.
        OSError: if the model still cannot be loaded after download.
    """
    try:
        spacy.load("en_core_web_sm")
    except OSError:
        # Use the *current* interpreter so the model lands in the active
        # environment — a bare "python" may resolve to a different install.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,  # fail loudly instead of silently continuing on error
        )
        spacy.load("en_core_web_sm")  # sanity check: confirm it is now loadable

# Runs at import time: downloads the model on first deploy if absent.
install_spacy_model()

# Load spaCy model (module-level singleton used by extract_triples below)
nlp = spacy.load("en_core_web_sm")

# === Neo4j credentials ===
# NOTE(review): SECURITY — credentials are hard-coded in source. Move these to
# st.secrets or environment variables and rotate this password; anyone with
# repo access can read it.
NEO4J_URI = "neo4j+s://ff701b1c.databases.neo4j.io"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "BfZM7YRKpFz1b_V7acAmOtaSQHPU9xK03rJlfPep88g"  

def get_neo4j_driver():
    """Create and verify a Neo4j driver; show a Streamlit error and return None on failure.

    GraphDatabase.driver() is lazy — it does not open a connection, so on its
    own the try/except would almost never catch a bad URI or bad credentials.
    verify_connectivity() forces a round-trip so failures surface here.
    """
    try:
        driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
        driver.verify_connectivity()  # fail fast instead of on first query
        return driver
    except Exception as e:
        st.error(f"Failed to connect to Neo4j: {e}")
        return None

# === TF-IDF Filtering (Optional for noise reduction) ===
def compute_tfidf_keywords(text: str, top_n: int = 100) -> set:
    """Return the top_n words of *text* by TF-IDF score, as a set.

    Args:
        text: input document (scored as a single-document corpus).
        top_n: maximum number of keywords to return.

    Returns:
        A set of lowercase feature words; empty set when the text yields no
        vocabulary (empty or stop-words-only input), instead of letting
        TfidfVectorizer raise ValueError.
    """
    vectorizer = TfidfVectorizer(stop_words='english')
    try:
        X = vectorizer.fit_transform([text])
    except ValueError:
        # "empty vocabulary" — nothing to score
        return set()
    scores = zip(vectorizer.get_feature_names_out(), X.toarray()[0])
    sorted_scores = sorted(scores, key=lambda x: x[1], reverse=True)
    return {word for word, _ in sorted_scores[:top_n]}

# === Triple Extraction ===
def extract_triples(text, use_tfidf=False):
    """Extract (subject, verb, object) triples from *text* using spaCy.

    For each sentence, the ROOT token's lemma is the verb; the first noun
    chunk whose root is a (passive) subject becomes the subject, and the
    first direct/prepositional object or attribute becomes the object.

    Args:
        text: raw input text.
        use_tfidf: when True, keep only triples whose subject or object
            contains at least one top-TF-IDF keyword (noise reduction).

    Returns:
        List of (subject, verb, object) string tuples.
    """
    doc = nlp(text)
    tfidf_keywords = compute_tfidf_keywords(text) if use_tfidf else None
    triples = []

    for sent in doc.sents:
        subject = ""
        obj = ""
        verb = ""

        root = [token for token in sent if token.dep_ == "ROOT"]
        if root:
            verb = root[0].lemma_

        for chunk in sent.noun_chunks:
            if chunk.root.dep_ in ("nsubj", "nsubjpass") and not subject:
                subject = chunk.text
            elif chunk.root.dep_ in ("dobj", "pobj", "attr") and not obj:
                obj = chunk.text

        if subject and verb and obj:
            if tfidf_keywords:
                # BUG FIX: the old check compared the *whole* chunk text
                # (e.g. "the red car") against single-word TF-IDF features,
                # so multi-word phrases never matched and were dropped.
                # Match on the individual words of the phrase instead.
                words = subject.lower().split() + obj.lower().split()
                if any(w in tfidf_keywords for w in words):
                    triples.append((subject.strip(), verb.strip(), obj.strip()))
            else:
                triples.append((subject.strip(), verb.strip(), obj.strip()))

    return triples

# === Visualization Function ===
def show_graph(triples):
    """Render (subject, predicate, object) triples as a directed graph in Streamlit.

    Args:
        triples: iterable of (subject, predicate, object) string tuples.
            Shows a warning and returns early when empty.
    """
    if not triples:
        st.warning("No triples found to visualize.")
        return

    G = nx.DiGraph()
    for s, p, o in triples:
        # add_edge creates missing endpoint nodes implicitly
        G.add_edge(s, o, label=p)

    pos = nx.spring_layout(G, seed=42)  # fixed seed => reproducible layout
    fig = plt.figure(figsize=(10, 8))
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=2000, font_size=10, edge_color='gray')
    nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['label'] for u, v, d in G.edges(data=True)})
    st.pyplot(fig)
    # BUG FIX: plt.clf() clears but never *closes* the figure, so figures
    # accumulate across Streamlit reruns and leak memory. Close it instead.
    plt.close(fig)

# === Streamlit UI ===
st.title("๐Ÿง  Knowledge Graph Generator")

text_input = st.text_area("Paste your text here:", height=200)
use_tfidf = st.checkbox("Use TF-IDF filtering (Optional: Recommended for large texts)")

if st.button("Generate Graph"):
    if text_input:
        all_triples = extract_triples(text_input, use_tfidf=use_tfidf)

        if all_triples:
            st.subheader("๐Ÿ”— Extracted Triples:")
            for triple in all_triples:
                st.markdown(f"- **({triple[0]} โ†’ {triple[1]} โ†’ {triple[2]})**")

            show_graph(all_triples)
        else:
            st.warning("No valid triples could be extracted. Try different text.")
    else:
        st.warning("Please enter some text.")