Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import time | |
| import spacy | |
| import shutil | |
| import pickle | |
| import random | |
| import hashlib | |
| import logging | |
| import asyncio | |
| import warnings | |
| import rapidjson | |
| import gradio as gr | |
| import networkx as nx | |
| from llm_graph import LLMGraph, MODEL_LIST | |
| from pyvis.network import Network | |
| from spacy import displacy | |
| from spacy.tokens import Span | |
# Log at INFO level so response timings and cache events are reported.
logging.basicConfig(level=logging.INFO)
# Suppress every warning whose category is UserWarning.
warnings.filterwarnings("ignore", category=UserWarning)

# Constants
# Page title and tagline; both are rendered as Markdown by create_ui().
TITLE = "π Text2Graph: Extract Knowledge Graphs from Natural Language"
SUBTITLE = "β¨ Extract and visualize knowledge graphs from texts in any language!"
# Accepted input length bounds, in characters (checked in process_and_visualize()).
MIN_CHARS = 20
MAX_CHARS = 3500

# Keep track of all processed texts
# (MD5 hex digests appended by process_and_visualize() for MODEL_LIST[1] runs).
doc_ids = []

# Basic CSS for styling
CUSTOM_CSS = """
.gradio-container {
font-family: 'Segoe UI', Roboto, sans-serif;
}
"""

# Cache directory and file paths
CACHE_DIR = "./cache"
# WORKING_DIR is deleted and recreated by create_ui() and process_and_visualize().
WORKING_DIR = "./sample"
# Pickled results for the first example text (graph HTML, entity HTML, JSON, stats).
EXAMPLE_CACHE_FILE = os.path.join(CACHE_DIR, "first_example_cache.pkl")
# GraphML file read by create_graph() for models other than MODEL_LIST[0];
# presumably written into WORKING_DIR by LLMGraph -- TODO confirm.
GRAPHML_FILE = WORKING_DIR + "/graph_chunk_entity_relation.graphml"
def _read_sample(path):
    """Return the entire contents of the UTF-8 encoded text file at *path*."""
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.read()


# Load the sample texts
# Paths of the bundled sample files.
text_en_file1 = "./data/sample1_en.txt"
text_en_file2 = "./data/sample2_en.txt"
text_en_file3 = "./data/sample3_en.txt"
text_fr_file = "./data/sample_fr.txt"
text_es_file = "./data/sample_es.txt"
text_it_file = "./data/sample_it.txt"

# Contents of each sample, read eagerly at import time.
text1_en = _read_sample(text_en_file1)
text2_en = _read_sample(text_en_file2)
text3_en = _read_sample(text_en_file3)
text_fr = _read_sample(text_fr_file)
text_es = _read_sample(text_es_file)
text_it = _read_sample(text_it_file)

# Create cache directory if it doesn't exist
for _directory in (CACHE_DIR, WORKING_DIR):
    os.makedirs(_directory, exist_ok=True)
def get_random_light_color():
    """
    Return a random light color as a "#rrggbb" hex string.

    The red, green and blue channels are each drawn, in that order,
    from the inclusive range 140..255.
    """
    channels = [random.randint(140, 255) for _ in range(3)]
    return "#" + "".join(f"{value:02x}" for value in channels)
def handle_text(text=""):
    """
    Normalize whitespace in *text*.

    Every run of whitespace (spaces, tabs, newlines) collapses to a single
    space and leading/trailing whitespace is removed. Falsy input (empty
    string or None) yields an empty string.
    """
    return " ".join(text.split()) if text else ""
def extract_kg(text="", model_name=MODEL_LIST[0], model=None):
    """
    Run knowledge-graph extraction on *text* and return the result as a dict.

    Args:
        text: Input text to analyse.
        model_name: Name forwarded to ``model.extract``.
        model: Object exposing ``extract(text, model_name)``.

    Returns:
        The extraction result. A dict is returned as-is; anything else is
        parsed with ``rapidjson.loads``.

    Raises:
        gr.Error: If *text*, *model_name* or *model* is missing, or if the
            extraction or the JSON parsing fails.
    """
    # Catch empty text
    if not (text and model_name):
        raise gr.Error("β οΈ Both text and model must be provided!")
    if not model:
        raise gr.Error("β οΈ Model must be provided!")

    try:
        started = time.time()
        raw = model.extract(text, model_name)
        elapsed = time.time() - started
        logging.info(f"Response time: {elapsed:.4f} seconds")
        # convert string to dict when the model did not already return one
        return raw if isinstance(raw, dict) else rapidjson.loads(raw)
    except Exception as e:
        raise gr.Error(f"β Extraction error: {str(e)}")
def find_token_indices(doc, substring, text):
    """
    Find token indices for every occurrence of a substring in the text,
    based on the provided spaCy doc.

    Args:
        doc: Iterable of tokens for *text*. Each token must expose ``idx``
            (character offset of its first character), ``i`` (its token
            index) and support ``len()`` (its length in characters).
        substring: The text to look for.
        text: The string that is searched.

    Returns:
        A list of ``{"start": s, "end": e}`` dicts in order of appearance,
        where ``s`` is the index of the token starting at the occurrence's
        first character and ``e`` is one past the index of the token ending
        at its last character. Occurrences that begin or end in the middle
        of a token are skipped. An empty *substring* yields an empty list.
    """
    # An empty needle matches at every position without ever advancing the
    # search offset, so the loop below would never terminate.
    if not substring:
        return []

    start_idx = text.find(substring)
    if start_idx == -1:
        return []

    # Build the character-offset -> token-index lookups once, instead of
    # rescanning the whole doc for each occurrence.
    token_starts = {}
    token_ends = {}
    for token in doc:
        token_starts[token.idx] = token.i
        token_ends[token.idx + len(token)] = token.i + 1

    result = []
    while start_idx != -1:
        end_idx = start_idx + len(substring)
        start_token = token_starts.get(start_idx)
        end_token = token_ends.get(end_idx)
        if start_token is not None and end_token is not None:
            result.append({
                "start": start_token,
                "end": end_token
            })
        # Search for next occurrence
        start_idx = text.find(substring, end_idx)
    return result
def create_custom_entity_viz(data, full_text, type_col="type"):
    """
    Render the extracted entities as highlighted spans with spaCy's displacy.

    Args:
        data: Extraction result; ``data["nodes"]`` is iterated and each node
            must provide ``"id"`` (the entity text searched for in *full_text*).
        full_text: The text the entities were extracted from.
        type_col: Node key holding the entity label; nodes without it are
            labelled ``"Entity"``.

    Returns:
        An HTML string: the displacy markup wrapped in a styled ``<div>``.
    """
    tokenizer = spacy.blank("xx")
    doc = tokenizer(full_text)

    accepted = []   # non-overlapping spans, in the order they were found
    palette = {}    # label -> color, assigned the first time a label is used
    for node in data["nodes"]:
        for match in find_token_indices(doc, node["id"], full_text):
            start, end = match["start"], match["end"]
            if not (start < len(doc) and end <= len(doc)):
                continue
            # Check for overlapping spans: earlier matches win
            if any(prev.start < end and start < prev.end for prev in accepted):
                continue
            node_type = node.get(type_col, "Entity")
            accepted.append(Span(doc, start, end, label=node_type))
            if node_type not in palette:
                palette[node_type] = get_random_light_color()

    doc.set_ents(accepted, default="unmodified")
    doc.spans["sc"] = accepted

    options = {
        "colors": palette,
        "ents": list(palette.keys()),
        "style": "ent",
        "manual": True
    }
    html = displacy.render(doc, style="span", options=options)

    # Add custom styling to the entity visualization
    styled_html = f"""
<div style="padding: 20px; border-radius: 12px; background-color: gray; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
{html}
</div>
"""
    return styled_html
def create_graph(json_data, model_name=MODEL_LIST[0]):
    """
    Create interactive knowledge graph using pyvis

    For ``MODEL_LIST[0]`` the graph is assembled from ``json_data``
    (``"nodes"`` with ``"id"`` and optional ``"type"``/``"detailed_type"``;
    ``"edges"`` with ``"from"``, ``"to"`` and optional ``"label"``). For any
    other model the graph is read from ``GRAPHML_FILE`` and ``json_data`` is
    not used.

    Args:
        json_data: Extraction result with ``"nodes"`` and ``"edges"`` lists.
        model_name: Model that produced the data; selects the graph source
            and whether the Barnes-Hut physics settings are applied.

    Returns:
        An HTML string containing an ``<iframe>`` whose ``srcdoc`` attribute
        holds the complete pyvis page.
    """
    # Imported locally because the module does not import ``html`` itself.
    from html import escape

    built_from_json = model_name == MODEL_LIST[0]

    if built_from_json:
        G = nx.Graph()
        # Add nodes with tooltips and error handling for missing keys
        for node in json_data['nodes']:
            # Get node type with fallback
            node_type = node.get("type", "Entity")
            # Get detailed type with fallback
            detailed_type = node.get("detailed_type", node_type)
            # Use node ID and type info for the tooltip
            G.add_node(node['id'], title=f"{node_type}: {detailed_type}")
        # Add edges with labels
        for edge in json_data['edges']:
            # Edges missing either endpoint are skipped
            if 'from' in edge and 'to' in edge:
                label = edge.get('label', 'related')
                G.add_edge(edge['from'], edge['to'], title=label, label=label)
    else:
        G = nx.read_graphml(GRAPHML_FILE)

    # Create network visualization
    network = Network(
        width="100%",
        height="100vh",
        notebook=False,
        bgcolor="#f8fafc",
        font_color="#1e293b"
    )

    # Configure network display
    network.from_nx(G)
    if built_from_json:
        network.barnes_hut(
            gravity=-3000,
            central_gravity=0.3,
            spring_length=50,
            spring_strength=0.001,
            damping=0.09,
            overlap=0,
        )

    # Customize node appearance
    for node in network.nodes:
        # Prefer a node's own description as its tooltip when it has one
        if "description" in node:
            node["title"] = node["description"]
        node['color'] = {'background': '#e0e7ff', 'border': '#6366f1', 'highlight': {'background': '#c7d2fe', 'border': '#4f46e5'}}
        node['font'] = {'size': 14, 'color': '#1e293b'}
        node['shape'] = 'dot'
        node['size'] = 20

    # Customize edge appearance
    for edge in network.edges:
        if "description" in edge:
            edge["title"] = edge["description"]
        edge['width'] = 4
        edge['color'] = {'color': '#6366f1', 'highlight': '#4f46e5'}
        edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Arial'}

    # Generate HTML with iframe to isolate styles.
    # The page is HTML-escaped before being placed in the srcdoc attribute.
    # Swapping every ' for " (the previous approach) corrupted the embedded
    # JSON/JavaScript whenever a node id, label or tooltip contained an
    # apostrophe; escaping keeps the document intact for any content.
    page_html = network.generate_html()
    srcdoc = escape(page_html, quote=True)
    return f"""<iframe style="width: 100%; height: 700px; margin: 0 auto; border-radius: 12px; box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -4px rgba(0, 0, 0, 0.1);"
name="result" allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
def process_and_visualize(text, model_name, progress=gr.Progress()):
    """
    Process text and visualize knowledge graph and entities

    Args:
        text: Input text to analyse.
        model_name: Entry of MODEL_LIST selecting the extraction model.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        A 4-tuple ``(graph_html, entities_viz, json_data, stats)``: the graph
        iframe HTML, the entity-highlight HTML, the raw extraction result and
        a one-line summary string.

    Raises:
        gr.Error: If text or model is missing, if the text length is outside
            MIN_CHARS..MAX_CHARS, or if extraction fails (raised by extract_kg).
    """
    if not text or not model_name:
        raise gr.Error("β οΈ Both text and model must be provided!")
    # Check if we're processing the first example for caching
    is_first_example = text == EXAMPLES[0][0]
    # Try to load from cache if it's the first example
    if is_first_example and model_name == MODEL_LIST[0] and os.path.exists(EXAMPLE_CACHE_FILE):
        try:
            progress(0.3, desc="Loading from cache...")
            # NOTE: the pickle is a local file written by this app; never
            # point EXAMPLE_CACHE_FILE at untrusted data.
            with open(EXAMPLE_CACHE_FILE, 'rb') as f:
                cached_data = pickle.load(f)
            progress(1.0, desc="Loaded from cache!")
            return cached_data["graph_html"], cached_data["entities_viz"], cached_data["json_data"], cached_data["stats"]
        except Exception as e:
            # A broken cache is not fatal: log it and fall through to a fresh run.
            logging.error(f"Cache loading error: {str(e)}")
    # Catch too long or too short text
    if len(text) < MIN_CHARS:
        raise gr.Error(f"β οΈ Text is too short! Please provide at least {MIN_CHARS} characters.")
    if len(text) > MAX_CHARS:
        raise gr.Error(f"β οΈ Text is too long! Please provide no more than {MAX_CHARS} characters.")
    if model_name == MODEL_LIST[1]:
        # Compute the unique hash for the document
        doc_id = hashlib.md5(text.strip().encode()).hexdigest()
        if doc_id not in doc_ids:
            doc_ids.append(doc_id)
            # Clear the working directory if it exists
            if os.path.exists(WORKING_DIR):
                shutil.rmtree(WORKING_DIR)
            os.makedirs(WORKING_DIR, exist_ok=True)
    # Initialize the LLMGraph model
    model = LLMGraph()
    asyncio.run(model.initialize_rag())
    # Continue with normal processing if cache fails
    progress(0, desc="Starting extraction...")
    json_data = extract_kg(text, model_name, model)
    progress(0.5, desc="Creating entity visualization...")
    # The two models store the entity label under different node keys.
    if model_name == MODEL_LIST[0]:
        entities_viz = create_custom_entity_viz(json_data, text, type_col="type")
    else:
        entities_viz = create_custom_entity_viz(json_data, text, type_col="entity_type")
    progress(0.8, desc="Building knowledge graph...")
    graph_html = create_graph(json_data, model_name)
    node_count = len(json_data["nodes"])
    edge_count = len(json_data["edges"])
    stats = f"π Extracted {node_count} entities and {edge_count} relationships"
    # Save to cache if it's the first example
    if is_first_example and model_name == MODEL_LIST[0]:
        try:
            cached_data = {
                "graph_html": graph_html,
                "entities_viz": entities_viz,
                "json_data": json_data,
                "stats": stats
            }
            with open(EXAMPLE_CACHE_FILE, 'wb') as f:
                pickle.dump(cached_data, f)
        except Exception as e:
            # Failing to write the cache must not fail the request.
            logging.error(f"Cache saving error: {str(e)}")
    progress(1.0, desc="Complete!")
    return graph_html, entities_viz, json_data, stats
# Example texts
# One single-element row per sample, each whitespace-normalized by handle_text().
EXAMPLES = [
    [handle_text(sample)]
    for sample in (text1_en, text_fr, text2_en, text_es, text3_en, text_it)
]
def generate_first_example():
    """
    Return the cached results for the first example, creating the cache if needed.

    When EXAMPLE_CACHE_FILE exists it is unpickled and returned. Otherwise the
    first example is run through extraction, entity visualization and graph
    rendering with MODEL_LIST[0], and the results are pickled to
    EXAMPLE_CACHE_FILE.

    Returns:
        A dict with keys "graph_html", "entities_viz", "json_data" and
        "stats", or None when loading or generating the cache failed (the
        error is logged).
    """
    if os.path.exists(EXAMPLE_CACHE_FILE):
        logging.info("First example cache already exists")
        # Load existing cache
        try:
            with open(EXAMPLE_CACHE_FILE, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            logging.error(f"Error loading existing cache: {str(e)}")
        return None

    logging.info("Generating cache for first example...")
    try:
        text = EXAMPLES[0][0]
        chosen_model = MODEL_LIST[0] if MODEL_LIST else None

        # Initialize the LLMGraph model
        model = LLMGraph()
        asyncio.run(model.initialize_rag())

        # Extract data and build both visualizations
        json_data = extract_kg(text, chosen_model, model)
        entities_viz = create_custom_entity_viz(json_data, text)
        graph_html = create_graph(json_data)

        node_count = len(json_data["nodes"])
        edge_count = len(json_data["edges"])
        payload = {
            "graph_html": graph_html,
            "entities_viz": entities_viz,
            "json_data": json_data,
            "stats": f"π Extracted {node_count} entities and {edge_count} relationships",
        }

        # Save to cache
        with open(EXAMPLE_CACHE_FILE, 'wb') as f:
            pickle.dump(payload, f)
        logging.info("First example cache generated successfully")
        return payload
    except Exception as e:
        logging.error(f"Error generating first example cache: {str(e)}")
    return None
def create_ui():
    """
    Create the Gradio UI

    Resets WORKING_DIR, loads (or generates) the cached results for the first
    example, then builds the Blocks layout: model selector, text input and
    buttons on the left; example texts and raw JSON on the right; graph and
    entity tabs below.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    # Clear the working directory if it exists
    if os.path.exists(WORKING_DIR):
        shutil.rmtree(WORKING_DIR)
    os.makedirs(WORKING_DIR, exist_ok=True)
    # Try to generate/load the first example cache
    first_example = generate_first_example()
    with gr.Blocks(css=CUSTOM_CSS, title=TITLE) as demo:
        # Header
        gr.Markdown(f"# {TITLE}")
        gr.Markdown(f"{SUBTITLE}")
        # Main content area
        with gr.Row():
            # Left panel - Input controls
            with gr.Column(scale=1):
                input_model = gr.Radio(
                    MODEL_LIST,
                    label="π€ Select Model",
                    info="Choose a model to process your text",
                    value=MODEL_LIST[1] if MODEL_LIST else None,
                )
                input_text = gr.TextArea(
                    label="π Input Text",
                    info="Enter text in any language to extract a knowledge graph",
                    placeholder="Enter text here...",
                    lines=8,
                    value=EXAMPLES[0][0]  # Pre-fill with first example
                )
                with gr.Row():
                    submit_button = gr.Button("π Extract & Visualize", variant="primary", scale=2)
                    clear_button = gr.Button("π Clear", variant="secondary", scale=1)
                # Statistics will appear here
                stats_output = gr.Markdown("", label="π Analysis Results")
            # Right panel - Examples moved to right side
            with gr.Column(scale=1):
                gr.Markdown("## π Example Texts")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=input_text,
                    label=""
                )
                # JSON output moved to right side as well
                with gr.Accordion("π JSON Data", open=False):
                    output_json = gr.JSON(label="")
        # Full width visualization area at the bottom
        with gr.Row():
            # Full width visualization area
            with gr.Tabs():
                with gr.TabItem("π§© Knowledge Graph"):
                    output_graph = gr.HTML(label="")
                with gr.TabItem("π·οΈ Entity Recognition"):
                    output_entity_viz = gr.HTML(label="")
        # Functionality
        # Output order must match the tuple returned by process_and_visualize.
        submit_button.click(
            fn=process_and_visualize,
            inputs=[input_text, input_model],
            outputs=[output_graph, output_entity_viz, output_json, stats_output]
        )
        # Clearing resets the four outputs only; the input text is left as-is.
        clear_button.click(
            fn=lambda: [None, None, None, ""],
            inputs=[],
            outputs=[output_graph, output_entity_viz, output_json, stats_output]
        )
        # Set initial values from cache if available
        if first_example:
            # Use this to set initial values when the app loads
            demo.load(
                lambda: [
                    first_example["graph_html"],
                    first_example["entities_viz"],
                    first_example["json_data"],
                    first_example["stats"]
                ],
                inputs=None,
                outputs=[output_graph, output_entity_viz, output_json, stats_output]
            )
        # Footer
        gr.Markdown("---")
        gr.Markdown("π **Instructions:** Enter text in any language, select a model and click `Extract & Visualize` to generate a knowledge graph.")
        gr.Markdown("π οΈ Powered by [GPT-4.1-mini](https://platform.openai.com/docs/models/gpt-4.1-mini) and [Phi-3-mini-128k-instruct-graph](https://huggingface.co/EmergentMethods/Phi-3-mini-128k-instruct-graph)")
    return demo
def main():
    """
    Build the Gradio interface and launch it with ``share=False``.
    """
    create_ui().launch(share=False)


if __name__ == "__main__":
    main()