import os
import pathlib
import time
import re
import json
import pickle
from typing import List, Any

from pinecone import Pinecone
from pydantic import BaseModel, ValidationError

from langchain_mistralai import ChatMistralAI
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain.schema import Document
from langchain_community.document_loaders import (
    CSVLoader, PyPDFLoader, UnstructuredWordDocumentLoader,
    UnstructuredPowerPointLoader, UnstructuredMarkdownLoader,
    UnstructuredHTMLLoader, NotebookLoader
)
from langchain_text_splitters import RecursiveCharacterTextSplitter

from llama_index.core.memory import Memory




def retrieve_RAG(
    prompt_message, pc, index, kg_index, top_k=5, info=True,
    use_query_reformulation=False, llm=None, graphRAG=False,
):
    """
    Retrieve relevant document chunks and community summaries from Pinecone for a given prompt.
    - Optionally splits and reformulates the prompt for improved search.
    - Searches both standard document chunks and, if enabled, community summaries from the knowledge graph.
    - Returns all retrieved results for further use.
    """

    if info:
        print("[Debug] Starting retrieval with prompt:", prompt_message)
        print("[Debug] Top K:", top_k)
        print("[Debug] Query Reformulation Enabled:", use_query_reformulation)

    # --- Step 0: Check whether knowledge-graph context is available ---
    def _graph_available():
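        """Return True if the index already holds a 'community-summaries' namespace."""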
        try:
            stats = index.describe_index_stats()
            namespaces = stats.get("namespaces", {}) or {}
            return "community-summaries" in namespaces
        except Exception as e:
            print(f"[Error] Failed to inspect index namespaces: {e}")
            return False

    graph_ok = bool(kg_index) or _graph_available()
    if graphRAG and not graph_ok:
        if info:
            print("[Debug] graphRAG requested but no community summaries found; skipping graph retrieval.")
        graphRAG = False

    # --- Step 1: Use LLM to split the prompt into sub-queries ---
    sub_queries = [prompt_message]  # fallback: single query
    if llm is not None:
        try:
            split_prompt = (
                "Given the following user query, identify and list all distinct sub-queries or tasks it contains. "
                "Return ONLY a numbered list of sub-queries, each as a concise phrase.\n\n"
                f"User Query: {prompt_message}"
            )
            split_response = llm.invoke(split_prompt)
            sub_queries = re.findall(r"\d+\.\s*(.+)", split_response.content)
            if not sub_queries:
                sub_queries = [prompt_message]
            if info:
                print(f"[Debug] Identified sub-queries: {sub_queries}")
        except Exception as e:
            print(f"[Error] Sub-query splitting failed: {e}")

    all_retrieved_chunks = []
    all_graph_context_blocks = []

    # --- Step 2: For each sub-query, retrieve context as decided ---
    for idx, sub_query in enumerate(sub_queries):
        task_prompt = sub_query.strip()

        # Optional Query Reformulation
        if use_query_reformulation and llm is not None:
            try:
                reformulation_prompt = (
                    "Reformulate the following query to focus only on the key concepts and remove any unnecessary details. "
                    "It should be suitable for vector search in RAG retrieval:\n\n"
                    f"Original Query: {task_prompt}"
                )
                reformulated_response = llm.invoke(reformulation_prompt)
                task_prompt = reformulated_response.content.strip()
                if info:
                    print(f"[Debug] Reformulated Query for sub-query {idx+1}: {task_prompt}")
            except Exception as e:
                print(f"[Error] Query reformulation failed for sub-query {idx+1}: {e}")

        # Embed the sub-query; skip it on failure, matching the error handling used below
        try:
            query_embedding = pc.inference.embed(
                model="llama-text-embed-v2",
                inputs=[task_prompt],
                parameters={"input_type": "query"}
            )
        except Exception as e:
            print(f"[Error] Embedding failed for sub-query {idx+1}: {e}")
            continue
        if info:
            print(f"[Debug] Query embedding generated for sub-query {idx+1}.")
        qvec = query_embedding[0].values

        # --- Retrieve standard document chunks for this sub-query ---
        try:
            retrieved_chunks_raw = index.query(
                namespace="example-namespace",
                vector=qvec,
                top_k=top_k,
                include_values=False,
                include_metadata=True
            )
            retrieved_chunks = []
            for match in retrieved_chunks_raw.matches:
                text = match.metadata.get("text", "")
                source = match.metadata.get("source", "Unknown source")
                retrieved_chunks.append({
                    "text": text,
                    "source": source,
                    "sub_query": sub_query
                })
            all_retrieved_chunks.extend(retrieved_chunks)
            if info:
                print(f"[Debug] Retrieved {len(retrieved_chunks)} chunks for sub-query {idx+1}.")
        except Exception as e:
            print(f"[Error] Standard retrieval failed for sub-query {idx+1}: {e}")

        # --- Retrieve community summaries when graphRAG is enabled ---
        if graphRAG:
            COMMUNITY_NAMESPACE = "community-summaries"
            TOP_K_SUMMARIES = 5
            try:
                comm_matches = index.query(
                    namespace=COMMUNITY_NAMESPACE,
                    vector=qvec,
                    top_k=TOP_K_SUMMARIES,
                    include_values=False,
                    include_metadata=True
                )
                blocks = []
                for m in comm_matches.matches:
                    meta = m.metadata or {}
                    txt = meta.get("text", "")
                    cid = meta.get("community_id", "NA")
                    level = meta.get("level", -1)
                    size = meta.get("size", 0)
                    block = f"[Community {cid} | level={level} | size={size}]\n{txt}"
                    blocks.append(block)
                graph_context_str = "\n\n---\n\n".join(blocks)
                all_graph_context_blocks.append((sub_query, graph_context_str))
                if info:
                    print(f"[Community] Retrieved {len(blocks)} community summaries for sub-query {idx+1}.")
            except Exception as e:
                print(f"[Error] Community summaries retrieval failed for sub-query {idx+1}: {e}")

    # --- Step 3: Aggregate results ---
    combined_graph_context = "\n\n====\n\n".join(
        f"Sub-query: {sub_query}\n{context}"
        for (sub_query, context) in all_graph_context_blocks if context
    )

    if info:
        sources = {os.path.basename(chunk['source']) for chunk in all_retrieved_chunks}
        print(f"[Debug] Final retrieval: {len(all_retrieved_chunks)} chunks from {len(sources)} sources, "
              f"graph context length {len(combined_graph_context)}.")

    # Return retrieved chunks and the aggregated graph context
    return all_retrieved_chunks, combined_graph_context
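

# Minimal usage sketch. The environment variable, index name, and chat model below are
# illustrative assumptions for this example, not values defined elsewhere in this module.
if __name__ == "__main__":
    pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])  # assumes the key is exported
    index = pc.Index("example-index")  # hypothetical index holding "example-namespace"
    llm = ChatOpenAI(model="gpt-4o-mini")  # any chat model exposing .invoke() works here

    chunks, graph_context = retrieve_RAG(
        prompt_message="Summarize the onboarding docs and list the required tools.",
        pc=pc,
        index=index,
        kg_index=None,  # no separate knowledge-graph index in this sketch
        top_k=5,
        use_query_reformulation=True,
        llm=llm,
        graphRAG=False,  # set True once community summaries have been indexed
    )
    for chunk in chunks:
        print(chunk["source"], "->", chunk["text"][:80])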