| import os |
| import pathlib |
| import time |
| import re |
| from pinecone import Pinecone |
|
|
| from langchain_mistralai import ChatMistralAI |
| from langchain_openai import ChatOpenAI |
| from langchain_core.messages import HumanMessage, SystemMessage |
| from langchain.schema import Document |
| from langchain_community.document_loaders import ( |
| CSVLoader, PyPDFLoader, UnstructuredWordDocumentLoader, |
| UnstructuredPowerPointLoader, UnstructuredMarkdownLoader, |
| UnstructuredHTMLLoader, NotebookLoader |
| ) |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
| from llama_index.core.memory import Memory |
|
|
| import pickle |
|
|
| import json |
| from typing import List, Any |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage |
|
|
| from typing import List, Any |
| from pydantic import BaseModel, ValidationError |
|
|
|
|
|
|
|
|
|
|
def retrieve_RAG(
    prompt_message, pc, index, kg_index, top_k=5, info=True,
    use_query_reformulation=False, llm=None, graphRAG=False,
):
    """
    Retrieve relevant document chunks and community summaries from Pinecone
    for a given prompt.

    - Optionally splits and reformulates the prompt for improved search.
    - Searches both standard document chunks and, if enabled, community
      summaries from the knowledge graph.
    - Returns all retrieved results for further use.

    Parameters
    ----------
    prompt_message : str
        The user query to retrieve context for.
    pc :
        Pinecone client; used via ``pc.inference.embed`` to embed queries.
    index :
        Pinecone index handle; queried in the "example-namespace" namespace
        for document chunks and "community-summaries" for graph summaries.
    kg_index :
        Knowledge-graph index handle; its truthiness marks graph data as
        available (otherwise the index's namespaces are inspected).
    top_k : int, default 5
        Number of document chunks to retrieve per sub-query.
    info : bool, default True
        When True, print debug information along the way.
    use_query_reformulation : bool, default False
        When True (and ``llm`` is given), reformulate each sub-query before
        embedding it for vector search.
    llm : optional
        Chat model exposing ``.invoke`` (LangChain style); used for
        sub-query splitting and reformulation. When None, the prompt is
        used verbatim as the single sub-query.
    graphRAG : bool, default False
        When True, also retrieve community summaries from the knowledge
        graph (skipped if no graph data is available).

    Returns
    -------
    tuple[list[dict], str]
        ``(chunks, graph_context)``: each chunk is a dict with keys
        "text", "source" and "sub_query"; ``graph_context`` is a joined
        string of community summaries ("" when none were retrieved).
    """
    # Hoisted constants (were inlined mid-loop in the original).
    CHUNK_NAMESPACE = "example-namespace"
    COMMUNITY_NAMESPACE = "community-summaries"
    TOP_K_SUMMARIES = 5

    if info:
        print("[Debug] Starting retrieval with prompt:", prompt_message)
        print("[Debug] Top K:", top_k)
        print("[Debug] Query Reformulation Enabled:", use_query_reformulation)

    def _graph_available():
        # Best-effort check that the community-summaries namespace exists
        # on the index; treat any inspection failure as "not available".
        try:
            stats = index.describe_index_stats()
            namespaces = stats.get("namespaces", {}) or {}
            return COMMUNITY_NAMESPACE in namespaces
        except Exception as e:
            print(f"[Error] Failed to inspect index namespaces: {e}")
            return False

    # Graph data is usable when a KG index handle was passed, or the
    # community-summaries namespace exists on the index. (This flag now
    # actually gates the graphRAG branch below; previously it was computed
    # but never used.)
    graph_ok = bool(kg_index) or _graph_available()

    # --- Step 1: optionally split the prompt into distinct sub-queries. ---
    sub_queries = [prompt_message]
    if llm is not None:
        try:
            split_prompt = (
                "Given the following user query, identify and list all distinct sub-queries or tasks it contains. "
                "Return ONLY a numbered list of sub-queries, each as a concise phrase.\n\n"
                f"User Query: {prompt_message}"
            )
            split_response = llm.invoke(split_prompt)
            sub_queries = re.findall(r"\d+\.\s*(.+)", split_response.content)
            if not sub_queries:
                # LLM did not return a parseable numbered list; fall back.
                sub_queries = [prompt_message]
            if info:
                print(f"[Debug] Identified sub-queries: {sub_queries}")
        except Exception as e:
            print(f"[Error] Sub-query splitting failed: {e}")

    all_retrieved_chunks = []
    all_graph_context_blocks = []

    for idx, sub_query in enumerate(sub_queries):
        task_prompt = sub_query.strip()

        # --- Step 2: optionally reformulate the sub-query for search. ---
        if use_query_reformulation and llm is not None:
            try:
                reformulation_prompt = (
                    "Reformulate the following query to focus only on the key concepts and remove any unnecessary details. "
                    "It should be suitable for vector search in RAG retrieval:\n\n"
                    f"Original Query: {task_prompt}"
                )
                reformulated_response = llm.invoke(reformulation_prompt)
                task_prompt = reformulated_response.content.strip()
                if info:
                    print(f"[Debug] Reformulated Query for sub-query {idx+1}: {task_prompt}")
            except Exception as e:
                # Fall through with the unreformulated prompt.
                print(f"[Error] Query reformulation failed for sub-query {idx+1}: {e}")

        # --- Step 3: embed the (possibly reformulated) sub-query. ---
        # Wrapped in try/except so one failed embedding skips only this
        # sub-query instead of crashing the whole retrieval loop.
        try:
            query_embedding = pc.inference.embed(
                model="llama-text-embed-v2",
                inputs=[task_prompt],
                parameters={"input_type": "query"}
            )
            qvec = query_embedding[0].values
        except Exception as e:
            print(f"[Error] Embedding failed for sub-query {idx+1}: {e}")
            continue
        if info:
            print(f"[Debug] Query embedding generated for sub-query {idx+1}.")

        # --- Step 4: retrieve standard document chunks. ---
        try:
            retrieved_chunks_raw = index.query(
                namespace=CHUNK_NAMESPACE,
                vector=qvec,
                top_k=top_k,
                include_values=False,
                include_metadata=True
            )
            retrieved_chunks = []
            for match in retrieved_chunks_raw.matches:
                text = match.metadata.get("text", "")
                source = match.metadata.get("source", "Unknown source")
                retrieved_chunks.append({
                    "text": text,
                    "source": source,
                    "sub_query": sub_query
                })
            all_retrieved_chunks.extend(retrieved_chunks)
            if info:
                print(f"[Debug] Match processed for sub-query {idx+1}: text='{text[:50]}...', source='{source}'")
        except Exception as e:
            print(f"[Error] Standard retrieval failed for sub-query {idx+1}: {e}")

        # --- Step 5: optionally retrieve community summaries (graphRAG). ---
        # Gated on graph_ok so we do not query a namespace known to be absent.
        if graphRAG and graph_ok:
            try:
                comm_matches = index.query(
                    namespace=COMMUNITY_NAMESPACE,
                    vector=qvec,
                    top_k=TOP_K_SUMMARIES,
                    include_values=False,
                    include_metadata=True
                )
                blocks = []
                for m in comm_matches.matches:
                    meta = m.metadata or {}
                    txt = meta.get("text", "")
                    cid = meta.get("community_id", "NA")
                    level = meta.get("level", -1)
                    size = meta.get("size", 0)
                    block = f"[Community {cid} \n level={level} \n size={size}]\n{txt}"
                    blocks.append(block)
                graph_context_str = ("\n\n---\n\n").join(blocks)
                all_graph_context_blocks.append((sub_query, graph_context_str))
                if info:
                    print(f"[Community] Retrieved {len(blocks)} community summaries for sub-query {idx+1}.")
            except Exception as e:
                print(f"[Error] Community summaries retrieval failed for sub-query {idx+1}: {e}")

    # Merge per-sub-query graph contexts, dropping empty ones.
    combined_graph_context = "\n\n====\n\n".join(
        f"Sub-query: {sub_query}\n{context}"
        for (sub_query, context) in all_graph_context_blocks if context
    )

    if info:
        sources = {os.path.basename(chunk['source']) for chunk in all_retrieved_chunks}
        print(f"[Debug] Final retrieval: {len(all_retrieved_chunks)} chunks from {len(sources)} sources, "
              f"graph context length {len(combined_graph_context)}.")

    return all_retrieved_chunks, combined_graph_context
|
|
|
|
|
|