import os
import pathlib
import time
import re
from pinecone import Pinecone
from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain.schema import Document
from langchain_community.document_loaders import (
CSVLoader, PyPDFLoader, UnstructuredWordDocumentLoader,
UnstructuredPowerPointLoader, UnstructuredMarkdownLoader,
UnstructuredHTMLLoader, NotebookLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from llama_index.core.memory import Memory
import pickle
import json
from typing import List, Any
from pydantic import BaseModel, ValidationError
def retrieve_RAG(
prompt_message, pc, index, kg_index, top_k=5, info=True,
use_query_reformulation=False, llm=None, graphRAG=False,
):
"""
Retrieve relevant document chunks and community summaries from Pinecone for a given prompt.
- Optionally splits and reformulates the prompt for improved search.
- Searches both standard document chunks and, if enabled, community summaries from the knowledge graph.
- Returns all retrieved results for further use.
"""
if info:
print("[Debug] Starting retrieval with prompt:", prompt_message)
print("[Debug] Top K:", top_k)
print("[Debug] Query Reformulation Enabled:", use_query_reformulation)
    # --- Step 0: Check whether graph (community-summary) context is available ---
def _graph_available():
try:
stats = index.describe_index_stats()
namespaces = stats.get("namespaces", {}) or {}
return "community-summaries" in namespaces
except Exception as e:
print(f"[Error] Failed to inspect index namespaces: {e}")
return False
graph_ok = bool(kg_index) or _graph_available()
# --- Step 1: Use LLM to split the prompt into sub-queries ---
sub_queries = [prompt_message] # fallback: single query
if llm is not None:
try:
split_prompt = (
"Given the following user query, identify and list all distinct sub-queries or tasks it contains. "
"Return ONLY a numbered list of sub-queries, each as a concise phrase.\n\n"
f"User Query: {prompt_message}"
)
split_response = llm.invoke(split_prompt)
sub_queries = re.findall(r"\d+\.\s*(.+)", split_response.content)
if not sub_queries:
sub_queries = [prompt_message]
if info:
print(f"[Debug] Identified sub-queries: {sub_queries}")
except Exception as e:
print(f"[Error] Sub-query splitting failed: {e}")
all_retrieved_chunks = []
all_graph_context_blocks = []
# --- Step 2: For each sub-query, retrieve context as decided ---
for idx, sub_query in enumerate(sub_queries):
task_prompt = sub_query.strip()
# Optional Query Reformulation
if use_query_reformulation and llm is not None:
try:
reformulation_prompt = (
"Reformulate the following query to focus only on the key concepts and remove any unnecessary details. "
"It should be suitable for vector search in RAG retrieval:\n\n"
f"Original Query: {task_prompt}"
)
reformulated_response = llm.invoke(reformulation_prompt)
task_prompt = reformulated_response.content.strip()
if info:
print(f"[Debug] Reformulated Query for sub-query {idx+1}: {task_prompt}")
except Exception as e:
print(f"[Error] Query reformulation failed for sub-query {idx+1}: {e}")
# Embed the sub-query
query_embedding = pc.inference.embed(
model="llama-text-embed-v2",
inputs=[task_prompt],
parameters={"input_type": "query"}
)
if info:
print(f"[Debug] Query embedding generated for sub-query {idx+1}.")
qvec = query_embedding[0].values
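        # Note: the query vector's dimension must match the Pinecone index's
        # dimension (llama-text-embed-v2 typically produces 1024-dim vectors;
        # check your index configuration).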
        # --- Retrieve standard document chunks for this sub-query ---
try:
retrieved_chunks_raw = index.query(
namespace="example-namespace",
vector=qvec,
top_k=top_k,
include_values=False,
include_metadata=True
)
            retrieved_chunks = []
            for match in retrieved_chunks_raw.matches:
                text = match.metadata.get("text", "")
                source = match.metadata.get("source", "Unknown source")
                retrieved_chunks.append({
                    "text": text,
                    "source": source,
                    "sub_query": sub_query
                })
                if info:
                    print(f"[Debug] Match processed for sub-query {idx+1}: text='{text[:50]}...', source='{source}'")
            all_retrieved_chunks.extend(retrieved_chunks)
except Exception as e:
print(f"[Error] Standard retrieval failed for sub-query {idx+1}: {e}")
        # --- Retrieve community summaries when graph retrieval is enabled and available ---
        if graphRAG and graph_ok:
COMMUNITY_NAMESPACE = "community-summaries"
TOP_K_SUMMARIES = 5
try:
comm_matches = index.query(
namespace=COMMUNITY_NAMESPACE,
vector=qvec,
top_k=TOP_K_SUMMARIES,
include_values=False,
include_metadata=True
)
blocks = []
for m in comm_matches.matches:
meta = m.metadata or {}
txt = meta.get("text", "")
cid = meta.get("community_id", "NA")
level = meta.get("level", -1)
size = meta.get("size", 0)
                    block = f"[Community {cid} | level={level} | size={size}]\n{txt}"
                    blocks.append(block)
                graph_context_str = "\n\n---\n\n".join(blocks)
all_graph_context_blocks.append((sub_query, graph_context_str))
if info:
print(f"[Community] Retrieved {len(blocks)} community summaries for sub-query {idx+1}.")
except Exception as e:
print(f"[Error] Community summaries retrieval failed for sub-query {idx+1}: {e}")
# --- Step 3: Aggregate results ---
combined_graph_context = "\n\n====\n\n".join(
f"Sub-query: {sub_query}\n{context}"
for (sub_query, context) in all_graph_context_blocks if context
)
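    # Illustrative shape of combined_graph_context, derived from the
    # separators used above (IDs and text are placeholders, not real output):
    #   Sub-query: <sub-query 1>
    #   [Community 3 | level=1 | size=42]
    #   <summary text>
    #   ---
    #   [Community 7 | level=0 | size=17]
    #   <summary text>
    #   ====
    #   Sub-query: <sub-query 2>
    #   ...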
if info:
sources = {os.path.basename(chunk['source']) for chunk in all_retrieved_chunks}
print(f"[Debug] Final retrieval: {len(all_retrieved_chunks)} chunks from {len(sources)} sources, "
f"graph context length {len(combined_graph_context)}.")
    # --- Step 4: Return retrieved chunks and combined graph context ---
return all_retrieved_chunks, combined_graph_context
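
# Minimal usage sketch, not part of the original pipeline: the index name,
# model choice, and environment variables below are illustrative assumptions.
if __name__ == "__main__":
    pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])  # assumes the key is set in the environment
    index = pc.Index("example-index")                      # hypothetical index name
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)   # any LangChain chat model should work here

    chunks, graph_context = retrieve_RAG(
        "Compare RAG and GraphRAG and list their trade-offs",
        pc=pc,
        index=index,
        kg_index=None,             # no separate knowledge-graph index in this sketch
        top_k=5,
        use_query_reformulation=True,
        llm=llm,
        graphRAG=False,            # set True once community summaries are indexed
    )
    print(f"Retrieved {len(chunks)} chunks; graph context length: {len(graph_context)}")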