Spaces:
Build error
Build error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import networkx as nx
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from langchain.document_loaders import DirectoryLoader
|
| 9 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 10 |
+
from pyvis.network import Network
|
| 11 |
+
from helpers.df_helpers import documents2Dataframe, df2Graph, graph2Df
|
| 12 |
+
import gradio as gr
|
| 13 |
+
import logging
|
| 14 |
+
|
# Constants
CHUNK_SIZE = 1500          # characters per text chunk fed to the splitter
CHUNK_OVERLAP = 150        # character overlap between consecutive chunks
WEIGHT_MULTIPLIER = 4      # base edge weight; pyvis weights are count / this value
COLOR_PALETTE = "hls"      # seaborn palette name used for community colors
# NOTE(review): not referenced in the visible code — presumably a default
# output path for the rendered graph; confirm before removing.
GRAPH_OUTPUT_DIRECTORY = "./docs/index.html"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def colors2Community(communities) -> pd.DataFrame:
    """Assign one distinct color to each community of graph nodes.

    Args:
        communities: An iterable of communities, each an iterable of node
            names (e.g. the output of a networkx community detector).

    Returns:
        A DataFrame with one row per node and columns:
        ``node`` (node name), ``color`` (hex color shared by the whole
        community), ``group`` (1-based community index).

    Bug fixed: the previous version did ``zip(community, palette)``, which
    gave every node in a community a *different* color (defeating the
    per-community coloring implied by ``group``) and silently dropped any
    nodes beyond the first ``len(communities)`` of a community because
    ``zip`` stops at the shorter sequence.
    """
    # One palette entry per community; shuffled so adjacent group indices
    # don't get visually similar hues.
    palette = sns.color_palette(COLOR_PALETTE, len(communities)).as_hex()
    random.shuffle(palette)
    rows = [
        {"node": node, "color": color, "group": group + 1}
        for group, (community, color) in enumerate(zip(communities, palette))
        for node in community
    ]
    return pd.DataFrame(rows)
def contextual_proximity(df: pd.DataFrame) -> pd.DataFrame:
    """Derive extra edges between nodes that co-occur in the same chunk.

    Args:
        df: DataFrame with columns ``chunk_id``, ``node_1``, ``node_2``.

    Returns:
        DataFrame with columns ``node_1``, ``node_2``, ``chunk_id``
        (comma-joined chunk ids), ``count`` (co-occurrence count > 1) and
        a constant ``edge`` label of "contextual proximity".
    """
    # Stack node_1/node_2 into a single "node" column, keyed by chunk.
    long_form = pd.melt(
        df,
        id_vars=["chunk_id"],
        value_vars=["node_1", "node_2"],
        value_name="node",
    ).drop(columns=["variable"])
    # Self-join on chunk_id to enumerate every node pair within a chunk,
    # then drop trivial self-pairs.
    pairs = pd.merge(long_form, long_form, on="chunk_id", suffixes=("_1", "_2"))
    pairs = pairs[pairs["node_1"] != pairs["node_2"]].reset_index(drop=True)
    # Aggregate per pair: join the chunk ids and count the co-occurrences.
    grouped = (
        pairs.groupby(["node_1", "node_2"])
        .agg({"chunk_id": [",".join, "count"]})
        .reset_index()
    )
    grouped.columns = ["node_1", "node_2", "chunk_id", "count"]
    grouped.dropna(subset=["node_1", "node_2"], inplace=True)
    # Keep only pairs seen more than once — single co-occurrence is noise.
    grouped = grouped[grouped["count"] != 1]
    grouped["edge"] = "contextual proximity"
    return grouped
def load_documents(input_dir):
    """Load every document found under *input_dir* with a progress bar."""
    return DirectoryLoader(input_dir, show_progress=True).load()
def split_documents(documents):
    """Split *documents* into overlapping character chunks.

    Chunk size and overlap come from the module-level CHUNK_SIZE and
    CHUNK_OVERLAP constants.
    """
    return RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len,
        is_separator_regex=False,
    ).split_documents(documents)
def save_dataframes(df, dfg1, output_dir):
    """Persist the chunk and graph dataframes as '|'-separated CSV files.

    Creates *output_dir* (and parents) if needed. ``dfg1`` goes to
    graph.csv, ``df`` to chunks.csv.
    """
    os.makedirs(output_dir, exist_ok=True)
    for frame, filename in ((dfg1, "graph.csv"), (df, "chunks.csv")):
        frame.to_csv(output_dir / filename, sep="|", index=False)
def load_dataframes(output_dir):
    """Load the previously saved chunk and graph dataframes.

    Returns:
        Tuple ``(df, dfg1)``: chunks.csv and graph.csv read back with the
        same '|' separator used by save_dataframes.
    """
    chunks = pd.read_csv(output_dir / "chunks.csv", sep="|")
    graph = pd.read_csv(output_dir / "graph.csv", sep="|")
    return chunks, graph
def build_graph(dfg):
    """Build an undirected networkx graph from the edge dataframe.

    Args:
        dfg: DataFrame with columns ``node_1``, ``node_2``, ``edge``
            (label) and ``count`` (aggregated weight).

    Returns:
        nx.Graph whose edges carry ``title`` (the edge label, shown by
        pyvis on hover) and ``weight`` (count scaled down by
        WEIGHT_MULTIPLIER).
    """
    graph = nx.Graph()
    graph.add_nodes_from(pd.concat([dfg["node_1"], dfg["node_2"]], axis=0).unique())
    edge_triples = (
        (
            row["node_1"],
            row["node_2"],
            {"title": row["edge"], "weight": row["count"] / WEIGHT_MULTIPLIER},
        )
        for _, row in dfg.iterrows()
    )
    graph.add_edges_from(edge_triples)
    return graph
def visualize_graph(G, communities):
    """Color the graph by community, render it with pyvis, and return it
    embedded in an iframe HTML snippet suitable for a Gradio HTML output.

    Args:
        G: networkx graph to render (mutated in place: node group/color/size
           attributes are set here).
        communities: iterable of node communities used for coloring.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` holds the full pyvis page.
    """
    colors = colors2Community(communities)
    # Tag every node with its community group, color, and a size
    # proportional to its degree so hubs render larger.
    for _, row in colors.iterrows():
        G.nodes[row['node']].update(group=row['group'], color=row['color'], size=G.degree[row['node']])
    nt = Network(notebook=False, cdn_resources="remote", height="900px", width="100%", select_menu=True)
    nt.from_nx(G)
    # Force-directed layout; negative gravity spreads the nodes apart.
    nt.force_atlas_2based(central_gravity=0.015, gravity=-31)
    nt.show_buttons(filter_=["physics"])
    # srcdoc below is single-quoted, so the page's own quotes must all be
    # double quotes — hence the replace.
    html = nt.generate_html().replace("'", "\"")
    return f"""<iframe style="width: 100%; height: 600px; margin:0 auto"
name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen
allowpaymentrequest frameborder="0" srcdoc='{html}'></iframe>"""
def process_pdfs(input_dir, output_dir, regenerate=False):
    """Build and render the knowledge graph for the documents in *input_dir*.

    Args:
        input_dir: Directory of source documents (used when regenerating).
        output_dir: Directory holding (or receiving) the cached CSVs.
        regenerate: When True, re-run loading, splitting, and LLM concept
            extraction and overwrite the cache; otherwise load the cached
            chunk/graph dataframes from *output_dir*.

    Returns:
        An iframe HTML string with the rendered interactive graph.
    """
    if regenerate:
        raw_documents = load_documents(input_dir)
        chunks = split_documents(raw_documents)
        df = documents2Dataframe(chunks)
        # LLM-driven concept/relation extraction per chunk.
        concepts_list = df2Graph(df, model='zephyr:latest')
        dfg1 = graph2Df(concepts_list)
        save_dataframes(df, dfg1, output_dir)
    else:
        df, dfg1 = load_dataframes(output_dir)

    # Clean the edge table and give every extracted edge the base weight.
    dfg1.replace("", np.nan, inplace=True)
    dfg1.dropna(subset=["node_1", "node_2", 'edge'], inplace=True)
    dfg1['count'] = WEIGHT_MULTIPLIER
    # Add co-occurrence edges, then merge duplicates, summing their weights.
    dfg2 = contextual_proximity(dfg1)
    merged_edges = (
        pd.concat([dfg1, dfg2], axis=0)
        .groupby(["node_1", "node_2"])
        .agg({"chunk_id": ",".join, "edge": ','.join, 'count': 'sum'})
        .reset_index()
    )
    G = build_graph(merged_edges)

    # Girvan-Newman yields progressively finer partitions; take the second.
    communities_generator = nx.community.girvan_newman(G)
    next(communities_generator)
    second_level = next(communities_generator)  # Two levels of communities
    communities = sorted(map(sorted, second_level))
    logger.info(f"Number of Communities = {len(communities)}")
    logger.info(communities)

    return visualize_graph(G, communities)
def main():
    """Entry point: build the knowledge graph once and serve it via Gradio.

    Uses the cached dataframes (regenerate=False); flip the flag to re-run
    document loading and LLM extraction.
    """
    data_dir = "cureus"
    input_dir = Path(f"./data_input/{data_dir}")
    output_dir = Path(f"./data_output/{data_dir}")
    html = process_pdfs(input_dir, output_dir, regenerate=False)

    # The graph is precomputed, so the interface just echoes the HTML.
    gr.Interface(
        fn=lambda: html,
        inputs=None,
        outputs=gr.HTML(),
        title="Text to knowledge graph",
        allow_flagging='never',
    ).launch()

if __name__ == "__main__":
    main()