File size: 2,564 Bytes
868cab1
 
 
 
984d51c
 
868cab1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b575cdb
868cab1
 
 
 
 
 
 
 
 
b575cdb
868cab1
 
b575cdb
868cab1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b575cdb
 
 
 
 
 
 
868cab1
 
 
 
 
 
 
 
 
 
b575cdb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
import spacy
import networkx as nx
import plotly.graph_objs as go
import spacy.cli
MODEL_NAME = "en_core_web_sm"


@st.cache_resource
def _load_nlp(model_name=MODEL_NAME):
    """Return the spaCy pipeline *model_name*, downloading it only if missing.

    Streamlit re-executes this whole script on every widget interaction, so
    loading (and especially downloading) the model at import time would be
    repeated on each rerun. ``st.cache_resource`` keeps a single loaded
    pipeline per process, and the download is attempted only when
    ``spacy.load`` cannot find the installed package.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # spacy.load raises OSError when the named model package is not installed.
        spacy.cli.download(model_name)
        return spacy.load(model_name)


# Load the spaCy model (cached across Streamlit reruns)
nlp = _load_nlp()

def create_knowledge_graph(text):
    """Build an undirected graph linking named entities mentioned next to each other.

    Every entity span found by the module-level spaCy pipeline ``nlp`` becomes
    a node keyed by its surface text, with the entity type stored in the
    node's ``label`` attribute. If the same text occurs with several types,
    the last one seen wins. Entities that are adjacent in document order are
    joined by an edge.

    Args:
        text: Raw input text to analyse.

    Returns:
        A tuple ``(G, entities)``:
        - ``G`` is the ``networkx.Graph``.
        - ``entities`` is the list of ``(text, label)`` pairs in document order,
          duplicates included.
    """
    doc = nlp(text)
    entities = [(ent.text, ent.label_) for ent in doc.ents]

    G = nx.Graph()
    for entity, label in entities:
        G.add_node(entity, label=label)

    # Link each entity to the next one mentioned. Skip back-to-back repeats of
    # the same text: they would only add a self-loop, which carries no
    # relation and inflates the node's degree.
    for (source, _), (target, _) in zip(entities, entities[1:]):
        if source != target:
            G.add_edge(source, target)

    return G, entities

def plot_graph(G):
    """Render *G* as an interactive Plotly network figure.

    Layout and traces:
    - Nodes are placed with ``networkx.spring_layout`` and labelled with their
      node keys.
    - Nodes are coloured by degree (number of incident edges), which is what
      the "Node Connections" colour bar describes.
    - The figure holds two traces: one for all edges, one for all nodes.

    Args:
        G: A ``networkx`` graph whose node keys are the labels to display.

    Returns:
        A ``plotly.graph_objs.Figure``.
    """
    pos = nx.spring_layout(G)

    # All edges go into one polyline trace; a ``None`` after each segment
    # makes Plotly lift the pen so the segments are not joined.
    edge_x = []
    edge_y = []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    nodes = list(G.nodes())
    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    # Colour by degree so the colour scale matches the "Node Connections"
    # colour bar (colouring by y-position would be meaningless here).
    node_degree = [G.degree(node) for node in nodes]

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='black'),
        hoverinfo='none',
        mode='lines'))

    fig.add_trace(go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=nodes,
        textposition="top center",
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=10,
            color=node_degree,
            # ``title=dict(side=...)`` replaces the deprecated ``titleside``
            # property, which newer Plotly releases reject.
            colorbar=dict(
                thickness=15,
                title=dict(text="Node Connections", side="right"),
                xanchor='left',
            ),
        )))

    fig.update_layout(showlegend=False, hovermode='closest',
                      margin=dict(b=0, t=0, l=0, r=0),
                      xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                      yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))

    return fig

# Streamlit app
# st.title("Knowledge Graph Generator")

# File uploader
uploaded_file = st.file_uploader("Upload a text file", type="txt")
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the stream position,
    # unlike read(). "utf-8-sig" strips a leading BOM (e.g. from Notepad) so
    # it does not end up glued to the first entity.
    try:
        text_input = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        st.error("The uploaded file is not valid UTF-8 text.")
        text_input = ""
else:
    text_input = st.text_area("Or enter your text here:")

if st.button("Generate Knowledge Graph"):
    # Whitespace-only input cannot contain entities; treat it as empty.
    if text_input.strip():
        G, entities = create_knowledge_graph(text_input)

        if entities:
            fig = plot_graph(G)
            st.plotly_chart(fig)

            st.write("Extracted Entities:")
            st.write(entities)
        else:
            st.info("No named entities were found in the text.")
    else:
        st.warning("Please enter some text or upload a file.")