import os
import pathlib
import time
import re
import pickle
import json
from typing import Any, List

from pinecone import Pinecone
from pydantic import BaseModel, ValidationError
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain.schema import Document
from langchain_community.document_loaders import (
    CSVLoader,
    PyPDFLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredPowerPointLoader,
    UnstructuredMarkdownLoader,
    UnstructuredHTMLLoader,
    NotebookLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.memory import Memory


def retrieve_RAG(
    prompt_message,
    pc,
    index,
    kg_index,
    top_k=5,
    info=True,
    use_query_reformulation=False,
    llm=None,
    graphRAG=False,
):
    """
    Retrieve relevant document chunks and community summaries from Pinecone for a given prompt.

    - Optionally splits the prompt into sub-queries and reformulates each one for improved search.
    - Searches standard document chunks and, if graph RAG is enabled and available,
      community summaries from the knowledge graph.
    - Returns all retrieved results for further use.
    """
    if info:
        print("[Debug] Starting retrieval with prompt:", prompt_message)
        print("[Debug] Top K:", top_k)
        print("[Debug] Query Reformulation Enabled:", use_query_reformulation)

    # --- Step 0: Check whether graph context (community summaries) is available ---
    def _graph_available():
        try:
            stats = index.describe_index_stats()
            namespaces = stats.get("namespaces", {}) or {}
            return "community-summaries" in namespaces
        except Exception as e:
            print(f"[Error] Failed to inspect index namespaces: {e}")
            return False

    graph_ok = bool(kg_index) or _graph_available()

    # --- Step 1: Use the LLM to split the prompt into sub-queries ---
    sub_queries = [prompt_message]  # fallback: treat the prompt as a single query
    if llm is not None:
        try:
            split_prompt = (
                "Given the following user query, identify and list all distinct sub-queries or tasks it contains. "
                "Return ONLY a numbered list of sub-queries, each as a concise phrase.\n\n"
                f"User Query: {prompt_message}"
            )
            split_response = llm.invoke(split_prompt)
            sub_queries = re.findall(r"\d+\.\s*(.+)", split_response.content)
            if not sub_queries:
                sub_queries = [prompt_message]
            if info:
                print(f"[Debug] Identified sub-queries: {sub_queries}")
        except Exception as e:
            print(f"[Error] Sub-query splitting failed: {e}")

    all_retrieved_chunks = []
    all_graph_context_blocks = []

    # --- Step 2: Retrieve context for each sub-query ---
    for idx, sub_query in enumerate(sub_queries):
        task_prompt = sub_query.strip()

        # Optional query reformulation
        if use_query_reformulation and llm is not None:
            try:
                reformulation_prompt = (
                    "Reformulate the following query to focus only on the key concepts "
                    "and remove any unnecessary details. "
                    "It should be suitable for vector search in RAG retrieval:\n\n"
                    f"Original Query: {task_prompt}"
                )
                reformulated_response = llm.invoke(reformulation_prompt)
                task_prompt = reformulated_response.content.strip()
                if info:
                    print(f"[Debug] Reformulated Query for sub-query {idx+1}: {task_prompt}")
            except Exception as e:
                print(f"[Error] Query reformulation failed for sub-query {idx+1}: {e}")

        # Embed the sub-query
        query_embedding = pc.inference.embed(
            model="llama-text-embed-v2",
            inputs=[task_prompt],
            parameters={"input_type": "query"},
        )
        if info:
            print(f"[Debug] Query embedding generated for sub-query {idx+1}.")
        qvec = query_embedding[0].values

        # --- Retrieve standard document chunks ---
        try:
            retrieved_chunks_raw = index.query(
                namespace="example-namespace",
                vector=qvec,
                top_k=top_k,
                include_values=False,
                include_metadata=True,
            )
            retrieved_chunks = []
            for match in retrieved_chunks_raw.matches:
                text = match.metadata.get("text", "")
                source = match.metadata.get("source", "Unknown source")
                retrieved_chunks.append({
                    "text": text,
                    "source": source,
                    "sub_query": sub_query,
                })
                if info:
                    print(f"[Debug] Match processed for sub-query {idx+1}: "
                          f"text='{text[:50]}...', source='{source}'")
            all_retrieved_chunks.extend(retrieved_chunks)
        except Exception as e:
            print(f"[Error] Standard retrieval failed for sub-query {idx+1}: {e}")

        # --- Retrieve community summaries when graph RAG is requested and available ---
        if graphRAG and graph_ok:
            COMMUNITY_NAMESPACE = "community-summaries"
            TOP_K_SUMMARIES = 5
            try:
                comm_matches = index.query(
                    namespace=COMMUNITY_NAMESPACE,
                    vector=qvec,
                    top_k=TOP_K_SUMMARIES,
                    include_values=False,
                    include_metadata=True,
                )
                blocks = []
                for m in comm_matches.matches:
                    meta = m.metadata or {}
                    txt = meta.get("text", "")
                    cid = meta.get("community_id", "NA")
                    level = meta.get("level", -1)
                    size = meta.get("size", 0)
                    blocks.append(f"[Community {cid} | level={level} | size={size}]\n{txt}")
                graph_context_str = "\n\n---\n\n".join(blocks)
                all_graph_context_blocks.append((sub_query, graph_context_str))
                if info:
                    print(f"[Community] Retrieved {len(blocks)} community summaries for sub-query {idx+1}.")
            except Exception as e:
                print(f"[Error] Community summaries retrieval failed for sub-query {idx+1}: {e}")

    # --- Step 3: Aggregate results across sub-queries ---
    combined_graph_context = "\n\n====\n\n".join(
        f"Sub-query: {sub_query}\n{context}"
        for (sub_query, context) in all_graph_context_blocks
        if context
    )

    if info:
        sources = {os.path.basename(chunk["source"]) for chunk in all_retrieved_chunks}
        print(f"[Debug] Final retrieval: {len(all_retrieved_chunks)} chunks from {len(sources)} sources, "
              f"graph context length {len(combined_graph_context)}.")

    return all_retrieved_chunks, combined_graph_context