import os
import pathlib
import time
import re
from pinecone import Pinecone
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain.schema import Document
from langchain_community.document_loaders import (
CSVLoader, PyPDFLoader, UnstructuredWordDocumentLoader,
UnstructuredPowerPointLoader, UnstructuredMarkdownLoader,
UnstructuredHTMLLoader, NotebookLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.memory import Memory
import pickle
import json
from typing import List, Any
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from typing import List, Any
from pydantic import BaseModel, ValidationError
def retrieve_RAG(
    prompt_message, pc, index, kg_index, top_k=5, info=True,
    use_query_reformulation=False, llm=None, graphRAG=False,
):
    """
    Retrieve relevant document chunks and community summaries from Pinecone
    for a given prompt.

    The prompt is optionally split into sub-queries by the LLM, and each
    sub-query may be reformulated for better vector search. Standard document
    chunks are always retrieved; community summaries are retrieved only when
    ``graphRAG`` is set AND graph data is actually available (either a
    ``kg_index`` was supplied or the ``community-summaries`` namespace exists).

    Parameters:
        prompt_message: The user's query string.
        pc: Pinecone client exposing ``inference.embed``.
        index: Pinecone index exposing ``query`` and ``describe_index_stats``.
        kg_index: Knowledge-graph index; truthiness marks graph availability.
        top_k: Number of standard chunks to retrieve per sub-query.
        info: When True, emit debug prints describing each step.
        use_query_reformulation: When True (and ``llm`` given), rewrite each
            sub-query before embedding it.
        llm: Chat model with an ``invoke`` method; None disables both
            sub-query splitting and reformulation.
        graphRAG: When True, also retrieve community summaries.

    Returns:
        tuple: ``(all_retrieved_chunks, combined_graph_context)`` where the
        first item is a list of dicts with keys ``text``/``source``/``sub_query``
        and the second is a single combined community-context string.
    """
    if info:
        print("[Debug] Starting retrieval with prompt:", prompt_message)
        print("[Debug] Top K:", top_k)
        print("[Debug] Query Reformulation Enabled:", use_query_reformulation)

    # --- Step 0: Decide whether graph context can be used at all ---
    def _graph_available():
        # Fallback probe: check the index for a community-summaries namespace
        # when no kg_index object was handed in.
        try:
            stats = index.describe_index_stats()
            namespaces = stats.get("namespaces", {}) or {}
            return "community-summaries" in namespaces
        except Exception as e:
            print(f"[Error] Failed to inspect index namespaces: {e}")
            return False

    graph_ok = bool(kg_index) or _graph_available()

    # --- Step 1: Use LLM to split the prompt into sub-queries ---
    sub_queries = [prompt_message]  # fallback: treat the prompt as one query
    if llm is not None:
        try:
            split_prompt = (
                "Given the following user query, identify and list all distinct sub-queries or tasks it contains. "
                "Return ONLY a numbered list of sub-queries, each as a concise phrase.\n\n"
                f"User Query: {prompt_message}"
            )
            split_response = llm.invoke(split_prompt)
            # Pull the text of each "1. ..." item; an empty result means the
            # model did not produce a numbered list, so keep the fallback.
            sub_queries = re.findall(r"\d+\.\s*(.+)", split_response.content)
            if not sub_queries:
                sub_queries = [prompt_message]
            if info:
                print(f"[Debug] Identified sub-queries: {sub_queries}")
        except Exception as e:
            print(f"[Error] Sub-query splitting failed: {e}")

    all_retrieved_chunks = []
    all_graph_context_blocks = []

    # Loop-invariant constants hoisted out of the per-sub-query loop.
    COMMUNITY_NAMESPACE = "community-summaries"
    TOP_K_SUMMARIES = 5

    # --- Step 2: For each sub-query, retrieve context as decided ---
    for idx, sub_query in enumerate(sub_queries):
        task_prompt = sub_query.strip()

        # Optional query reformulation to sharpen the vector-search query.
        if use_query_reformulation and llm is not None:
            try:
                reformulation_prompt = (
                    "Reformulate the following query to focus only on the key concepts and remove any unnecessary details. "
                    "It should be suitable for vector search in RAG retrieval:\n\n"
                    f"Original Query: {task_prompt}"
                )
                reformulated_response = llm.invoke(reformulation_prompt)
                task_prompt = reformulated_response.content.strip()
                if info:
                    print(f"[Debug] Reformulated Query for sub-query {idx+1}: {task_prompt}")
            except Exception as e:
                print(f"[Error] Query reformulation failed for sub-query {idx+1}: {e}")

        # Embed the sub-query.
        query_embedding = pc.inference.embed(
            model="llama-text-embed-v2",
            inputs=[task_prompt],
            parameters={"input_type": "query"}
        )
        if info:
            print(f"[Debug] Query embedding generated for sub-query {idx+1}.")
        qvec = query_embedding[0].values

        # --- Retrieve standard document chunks ---
        try:
            retrieved_chunks_raw = index.query(
                namespace="example-namespace",
                vector=qvec,
                top_k=top_k,
                include_values=False,
                include_metadata=True
            )
            retrieved_chunks = []
            for match in retrieved_chunks_raw.matches:
                text = match.metadata.get("text", "")
                source = match.metadata.get("source", "Unknown source")
                retrieved_chunks.append({
                    "text": text,
                    "source": source,
                    "sub_query": sub_query
                })
                # Per-match debug print kept inside the loop so every match is
                # reported and no NameError occurs on an empty match list.
                if info:
                    print(f"[Debug] Match processed for sub-query {idx+1}: text='{text[:50]}...', source='{source}'")
            all_retrieved_chunks.extend(retrieved_chunks)
        except Exception as e:
            print(f"[Error] Standard retrieval failed for sub-query {idx+1}: {e}")

        # --- Retrieve community summaries (GraphRAG) ---
        # BUG FIX: honour the Step-0 availability check. Previously graph_ok
        # was computed but never consulted, so a missing community namespace
        # raised (and logged) an error for every sub-query.
        if graphRAG and graph_ok:
            try:
                comm_matches = index.query(
                    namespace=COMMUNITY_NAMESPACE,
                    vector=qvec,
                    top_k=TOP_K_SUMMARIES,
                    include_values=False,
                    include_metadata=True
                )
                blocks = []
                for m in comm_matches.matches:
                    meta = m.metadata or {}
                    txt = meta.get("text", "")
                    cid = meta.get("community_id", "NA")
                    level = meta.get("level", -1)
                    size = meta.get("size", 0)
                    block = f"[Community {cid} \n level={level} \n size={size}]\n{txt}"
                    blocks.append(block)
                graph_context_str = ("\n\n---\n\n").join(blocks)
                all_graph_context_blocks.append((sub_query, graph_context_str))
                if info:
                    print(f"[Community] Retrieved {len(blocks)} community summaries for sub-query {idx+1}.")
            except Exception as e:
                print(f"[Error] Community summaries retrieval failed for sub-query {idx+1}: {e}")

    # --- Step 3: Aggregate results ---
    combined_graph_context = "\n\n====\n\n".join(
        f"Sub-query: {sub_query}\n{context}"
        for (sub_query, context) in all_graph_context_blocks if context
    )
    if info:
        sources = {os.path.basename(chunk['source']) for chunk in all_retrieved_chunks}
        print(f"[Debug] Final retrieval: {len(all_retrieved_chunks)} chunks from {len(sources)} sources, "
              f"graph context length {len(combined_graph_context)}.")
    # --- Return as before ---
    return all_retrieved_chunks, combined_graph_context