Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import spacy | |
| import networkx as nx | |
| import plotly.graph_objs as go | |
| import spacy.cli | |
@st.cache_resource
def _load_spacy_model(model_name="en_core_web_sm"):
    """Return the spaCy pipeline *model_name*, downloading it only if missing.

    Streamlit re-executes this script on every widget interaction; caching
    with ``st.cache_resource`` means the pipeline is loaded (and, if needed,
    downloaded) once per server process instead of on every rerun.
    """
    import importlib

    try:
        return spacy.load(model_name)
    except OSError:
        # spaCy raises OSError (error E050) when the model package is not
        # installed; only then pay for a pip download.
        spacy.cli.download(model_name)
        # The package was just installed into site-packages by a subprocess;
        # refresh the import system's finder caches so this process can see it.
        importlib.invalidate_caches()
        return spacy.load(model_name)


# Load the spaCy model (cached across Streamlit reruns).
nlp = _load_spacy_model()
def create_knowledge_graph(text):
    """Build a co-occurrence graph of the named entities found in *text*.

    Each distinct entity string becomes a node whose ``label`` attribute is
    its spaCy entity label (if the same string is tagged more than once, the
    label of its last mention wins, because ``add_node`` updates attributes).
    Entities that are adjacent in document order are joined by an edge.

    Args:
        text: Raw text to run through the module-level spaCy pipeline ``nlp``.

    Returns:
        A ``(G, entities)`` tuple: the ``networkx.Graph`` and the list of
        ``(entity_text, entity_label)`` pairs in document order.
    """
    doc = nlp(text)
    entities = [(ent.text, ent.label_) for ent in doc.ents]

    G = nx.Graph()
    for entity, label in entities:
        G.add_node(entity, label=label)

    # Link each entity to the next one. Skip pairs that are the same string:
    # they would create a zero-length self-loop that is invisible in the plot
    # but still adds 2 to that node's degree.
    for (left, _), (right, _) in zip(entities, entities[1:]):
        if left != right:
            G.add_edge(left, right)

    return G, entities
def plot_graph(G, seed=None):
    """Render *G* as an interactive Plotly figure.

    Nodes are positioned with ``networkx.spring_layout``, drawn as markers
    labelled with the node name and coloured by degree (number of
    connections); edges are drawn as thin black line segments.

    Args:
        G: A ``networkx`` graph. A node attribute ``label``, when present,
            is included in that node's hover text.
        seed: Optional random seed for the spring layout. ``None`` (the
            default) yields a different layout on each call.

    Returns:
        A ``plotly.graph_objs.Figure``.
    """
    pos = nx.spring_layout(G, seed=seed)

    # All edges go into a single "lines" trace; a None between segments
    # breaks the polyline so unrelated edges are not joined.
    edge_x = []
    edge_y = []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))

    nodes = list(G.nodes())
    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    # Colour by degree so the "Node Connections" colorbar means what it says.
    degrees = [G.degree(node) for node in nodes]

    hover_text = []
    for node, degree in zip(nodes, degrees):
        label = G.nodes[node].get("label")
        name = f"{node} ({label})" if label else str(node)
        hover_text.append(f"{name}<br>Connections: {degree}")

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='black'),
        hoverinfo='none',
        mode='lines'))
    fig.add_trace(go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        text=nodes,
        hovertext=hover_text,
        textposition="top center",
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            size=10,
            color=degrees,
            # `titleside` was deprecated and then removed in Plotly 6 (it now
            # raises ValueError); the side is set on the title object instead.
            colorbar=dict(
                thickness=15,
                title=dict(text="Node Connections", side="right"),
                xanchor='left',
            ),
        )))
    fig.update_layout(showlegend=False, hovermode='closest',
                      margin=dict(b=0, t=0, l=0, r=0),
                      xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                      yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
    return fig
# Streamlit app
# st.title("Knowledge Graph Generator")
# File uploader
uploaded_file = st.file_uploader("Upload a text file", type="txt")
if uploaded_file is not None:
    # getvalue() returns the whole upload regardless of the buffer's read
    # position (read() would return b"" once the buffer has been consumed).
    try:
        text_input = uploaded_file.getvalue().decode("utf-8")
    except UnicodeDecodeError:
        # Report the problem instead of crashing the whole app.
        st.error("The uploaded file is not valid UTF-8 text.")
        text_input = ""
else:
    text_input = st.text_area("Or enter your text here:")

if st.button("Generate Knowledge Graph"):
    # Whitespace-only input contains no entities; treat it like empty input.
    if text_input.strip():
        G, entities = create_knowledge_graph(text_input)
        fig = plot_graph(G)
        st.plotly_chart(fig)
        st.write("Extracted Entities:")
        st.write(entities)
    else:
        st.warning("Please enter some text or upload a file.")