# NOTE: the three lines that were here ("Spaces:" / "Sleeping" / "Sleeping")
# were page chrome from the hosting UI captured during extraction, not source.
import ast
import html
import json
import os
from datetime import datetime
from typing import List

import networkx as nx
from dotenv import load_dotenv
from llama_index.core import (
    Document,
    KnowledgeGraphIndex,
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.graph_stores import SimpleGraphStore
from llama_index.llms.openai import OpenAI
from pyvis.network import Network

from retrieve import get_latest_dir
| load_dotenv() | |
| llm = OpenAI( | |
| temperature=0.0, model="gpt-3.5-turbo", api_key=os.getenv("OPENAI_API_KEY") | |
| ) | |
| graph_store = SimpleGraphStore() | |
| storage_context = StorageContext.from_defaults(graph_store=graph_store) | |
| service_context = ServiceContext.from_defaults( | |
| llm=llm, chunk_size=2048, chunk_overlap=24 | |
| ) | |
def create_document(input_dir: str) -> List[Document]:
    """
    Read every JSON file in *input_dir* into llama-index Documents.

    Args:
        input_dir (str): The input directory to read the documents from.

    Returns:
        List[Document]: The documents loaded from the directory.
    """
    json_reader = SimpleDirectoryReader(
        input_dir, exclude_hidden=True, required_exts=[".json"]
    )
    documents: List[Document] = []
    # iter_data() yields one batch of Documents per file on disk.
    for batch in json_reader.iter_data():
        documents += batch
    return documents
def kg_triplet_extract_fn(text) -> List[tuple]:
    """
    Extract (subject, predicate, object) triplets from a product-spec chunk.

    The chunk is expected to end with a JSON object (its last
    blank-line-separated segment) whose "name" key identifies the product;
    every remaining key/value pair becomes one (name, key, value) triplet.

    Args:
        text (str): Chunk text whose final "\\n\\n"-separated segment is JSON.

    Returns:
        List[tuple]: The (subject, predicate, object) triplets extracted.
        (The original annotation said List[str], but tuples were always
        returned — the annotation was wrong, not the code.)

    Raises:
        json.JSONDecodeError: If the trailing segment is not valid JSON.
        KeyError: If the JSON object has no "name" key.
    """
    # The JSON payload is the last paragraph of the chunk.
    json_text = text.split("\n\n")[-1]
    product_spec = json.loads(json_text)
    # pop() reads and removes "name" in one step (original used lookup + del)
    # so the comprehension below only sees attribute keys.
    product_name = product_spec.pop("name")
    return [(product_name, key, value) for key, value in product_spec.items()]
def generate_graph_visualization(kg_index):
    """
    Render the KG index as an interactive pyvis HTML page on disk.

    Args:
        kg_index (KnowledgeGraphIndex): The Knowledge Graph index to render.

    Returns:
        str: Path of the generated file (<GRAPH_VIS_DIR>/<timestamp>.html).
    """
    output_directory = os.getenv("GRAPH_VIS_DIR", "graph_vis")
    # Fix: save_graph() raises if the target directory does not exist yet.
    os.makedirs(output_directory, exist_ok=True)
    # Timestamped filename so successive runs never overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    graph_output_path = os.path.join(output_directory, f"{timestamp}.html")
    g = kg_index.get_networkx_graph(limit=20000)
    net = Network(
        notebook=False,
        cdn_resources="remote",
        height="800px",
        width="100%",
        select_menu=True,
        filter_menu=False,
    )
    net.from_nx(g)
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    net.save_graph(graph_output_path)
    print(f"Graph visualization saved to: {graph_output_path}")
    return graph_output_path
def plot_subgraph(triplets):
    """
    Build a directed graph from serialized triplets and return its
    HTML-escaped pyvis rendering (suitable for embedding via srcdoc/iframe).

    Args:
        triplets (str): String repr of a list of strings, each of which is
            itself the repr of a (source, action, target) tuple.

    Returns:
        str: The escaped HTML content to display the subgraph.
    """
    G = nx.DiGraph()
    # SECURITY fix: the original used eval() here, which executes arbitrary
    # code if `triplets` is ever attacker-controlled. ast.literal_eval parses
    # only Python literals — which is all this data ever is.
    for edge_str in ast.literal_eval(triplets):
        source, action, target = ast.literal_eval(edge_str)
        G.add_edge(source, target, label=action)
    net = Network(notebook=True, cdn_resources="remote", height="400px", width="100%")
    net.from_nx(G)
    net.force_atlas_2based(central_gravity=0.015, gravity=-31)
    html_content = net.generate_html()
    escaped_html = html.escape(html_content)
    return escaped_html
def create_kg(max_features: int = 60):
    """
    Build a Knowledge Graph index from the product-spec directory and write
    an HTML visualization of it.

    Args:
        max_features (int): Maximum number of triplets extracted per chunk.

    Returns:
        tuple: (KnowledgeGraphIndex, str path of the generated visualization).
        (The original docstring claimed a bare KnowledgeGraphIndex was
        returned; the function has always returned this 2-tuple.)
    """
    spec_dir = os.getenv("PROD_SPEC_DIR", "prod_spec")
    kg_index = KnowledgeGraphIndex.from_documents(
        documents=create_document(spec_dir),
        max_triplets_per_chunk=max_features,
        storage_context=storage_context,
        service_context=service_context,
        show_progress=True,
        include_embeddings=True,
        # Custom extractor: specs are JSON, so parse them directly rather
        # than asking the LLM to guess triplets.
        kg_triplet_extract_fn=kg_triplet_extract_fn,
    )
    graphvis_path = generate_graph_visualization(kg_index)
    return kg_index, graphvis_path
def persist_kg(kg_index: KnowledgeGraphIndex) -> str:
    """
    Persist the KG index under a fresh timestamped subdirectory.

    Args:
        kg_index (KnowledgeGraphIndex): The Knowledge Graph index to persist.

    Returns:
        str: The path the index was persisted to (<GRAPH_DIR>/<timestamp>).
    """
    output_dir = os.getenv("GRAPH_DIR", "graphs")
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    # os.path.join instead of a hand-built "a/b" f-string keeps the path
    # portable across OS separators (the original hard-coded "/").
    kg_path = os.path.join(output_dir, timestamp)
    kg_index.storage_context.persist(kg_path)
    return kg_path
def load_kg(kg_dir: str) -> KnowledgeGraphIndex:
    """
    Load the most recently persisted KG index found under *kg_dir*.

    Args:
        kg_dir (str): The parent directory holding timestamped KG snapshots.

    Returns:
        KnowledgeGraphIndex: The index rebuilt from the newest snapshot.
    """
    # get_latest_dir picks the newest timestamped subdirectory.
    latest_snapshot = get_latest_dir(kg_dir)
    ctx = StorageContext.from_defaults(persist_dir=latest_snapshot)
    return load_index_from_storage(ctx)
def query(kg_dir: str, query: str):
    """
    Answer a natural-language question against the latest persisted KG.

    Args:
        kg_dir (str): The directory to load the KG index from.
        query (str): The question to ask the KG index.

    Returns:
        Response: The response produced by the KG query engine.
    """
    engine = load_kg(kg_dir).as_query_engine(
        include_text=True,
        response_mode="refine",
        graph_store_query_depth=6,
        similarity_top_k=5,
    )
    return engine.query(query)
def query_graph_qa(graph_rag_index, query, search_level):
    """
    Run a Graph-RAG query and collect its supporting evidence.

    Args:
        graph_rag_index (KnowledgeGraphIndex): The Graph-RAG model index.
        query (str): The query to ask the Graph-RAG model.
        search_level (int): Top-k similarity cutoff for retrieval.

    Returns:
        tuple: (response, kg_rel_texts reference list, retrieved node texts).
    """
    retriever = graph_rag_index.as_retriever(
        include_text=True,
        similarity_top_k=search_level,
    )
    engine = graph_rag_index.as_query_engine(
        sub_retrievers=[retriever],
        graph_store_query_depth=6,
        include_text=True,
        similarity_top_k=search_level,
    )
    response = engine.query(query)
    nodes = retriever.retrieve(query)
    # The first metadata entry carrying KG relation texts supplies the
    # human-readable reference triplets; default to [] when none does.
    reference = next(
        (
            meta["kg_rel_texts"]
            for meta in response.metadata.values()
            if isinstance(meta, dict) and "kg_rel_texts" in meta
        ),
        [],
    )
    reference_text = [node.text for node in nodes]
    return response, reference, reference_text
if __name__ == "__main__":
    # Smoke test: build, persist, reload and visualize the KG, then run a
    # sample query against it.
    graph_dir = os.getenv("GRAPH_DIR", "graphs")
    kg_index, graphvis_path = create_kg()
    persist_kg(kg_index)
    kg_index = load_kg(graph_dir)
    generate_graph_visualization(kg_index)
    response = query(
        graph_dir,
        "Tell me the Built-in memory in Apple iPhone 15 Pro Max 256Gb Blue Titanium?",
    )
    print(response)
    # Print the metadata of the last response node as a retrieval sanity check.
    key = list(response.metadata)[-1]
    print(response.metadata[key])