"""
Advanced RAG techniques for improved retrieval and generation
Includes: Query Expansion, Reranking, Contextual Compression, Hybrid Search
"""

from typing import List, Dict, Optional, Tuple
import numpy as np
from dataclasses import dataclass
import re


@dataclass
class RetrievedDocument:
    """Document retrieved from vector database"""
    id: str
    text: str
    confidence: float
    metadata: Dict


class AdvancedRAG:
    """Advanced RAG system with modern techniques"""

    def __init__(self, embedding_service, qdrant_service):
        self.embedding_service = embedding_service
        self.qdrant_service = qdrant_service

    def expand_query(self, query: str) -> List[str]:
        """
        Expand query with related terms and variations
        Simple rule-based expansion for Vietnamese queries
        """
        queries = [query]

        # Add query variations
        # Remove question words for alternative search
        question_words = ['ai', 'gì', 'nào', 'đâu', 'khi nào', 'như thế nào',
                         'tại sao', 'có', 'là', 'được', 'không']

        query_lower = query.lower()
        for qw in question_words:
            # Match whole words only; a plain str.replace would also strip
            # substrings (e.g. 'là' inside 'làm')
            pattern = r'(?<!\w)' + re.escape(qw) + r'(?!\w)'
            if re.search(pattern, query_lower):
                variant = ' '.join(re.sub(pattern, '', query_lower).split())
                if variant and variant != query_lower:
                    queries.append(variant)

        # Extract key nouns/phrases (simple approach)
        words = query.split()
        if len(words) > 3:
            # Drop a leading question word; otherwise keep the first three words
            key_phrases = ' '.join(words[1:]) if words[0].lower() in question_words else ' '.join(words[:3])
            if key_phrases not in queries:
                queries.append(key_phrases)

        return queries[:3]  # Return top 3 variations
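
    # Example (illustrative): expand_query("tại sao giá lại cao") yields
    # ["tại sao giá lại cao", "giá lại cao", "tại sao giá"] -- the original
    # query, the query with its question phrase stripped, and the leading
    # three-word key phrase.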

    def multi_query_retrieval(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5
    ) -> List[RetrievedDocument]:
        """
        Retrieve documents using multiple query variations
        Combines results from all query variations
        """
        expanded_queries = self.expand_query(query)

        all_results = {}  # Use dict to deduplicate by doc_id

        for q in expanded_queries:
            # Generate embedding for each query variant
            query_embedding = self.embedding_service.encode_text(q)

            # Search in Qdrant
            results = self.qdrant_service.search(
                query_embedding=query_embedding,
                limit=top_k,
                score_threshold=score_threshold
            )

            # Add to results (keep highest score for duplicates)
            for result in results:
                doc_id = result["id"]
                if doc_id not in all_results or result["confidence"] > all_results[doc_id].confidence:
                    all_results[doc_id] = RetrievedDocument(
                        id=doc_id,
                        text=result["metadata"].get("text", ""),
                        confidence=result["confidence"],
                        metadata=result["metadata"]
                    )

        # Sort by confidence and return top_k
        sorted_results = sorted(all_results.values(), key=lambda x: x.confidence, reverse=True)
        return sorted_results[:top_k]
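
    # Example (illustrative): if two query variants both return document
    # "doc-1", once with confidence 0.71 and once with 0.78, only the 0.78
    # hit survives deduplication.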

    def rerank_documents(
        self,
        query: str,
        documents: List[RetrievedDocument],
        use_cross_encoder: bool = False
    ) -> List[RetrievedDocument]:
        """
        Rerank documents based on semantic similarity
        Simple reranking using embedding similarity (can be upgraded to cross-encoder)
        """
        if not documents:
            return documents

        # Simple reranking: recalculate similarity with original query
        query_embedding = self.embedding_service.encode_text(query)

        reranked = []
        for doc in documents:
            # Get document embedding
            doc_embedding = self.embedding_service.encode_text(doc.text)

            # Calculate cosine similarity (normalize explicitly: a raw dot
            # product equals cosine only for unit-length embeddings)
            q_vec = query_embedding.flatten()
            d_vec = doc_embedding.flatten()
            norm = float(np.linalg.norm(q_vec) * np.linalg.norm(d_vec))
            similarity = float(np.dot(q_vec, d_vec)) / norm if norm > 0 else 0.0

            # Combine with original confidence (weighted average)
            new_score = 0.6 * similarity + 0.4 * doc.confidence

            reranked.append(RetrievedDocument(
                id=doc.id,
                text=doc.text,
                confidence=float(new_score),
                metadata=doc.metadata
            ))

        # Sort by new score
        reranked.sort(key=lambda x: x.confidence, reverse=True)
        return reranked
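
    # Example: a document with fresh query-document similarity 0.82 and an
    # original retrieval confidence of 0.70 is rescored as
    # 0.6 * 0.82 + 0.4 * 0.70 = 0.772.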

    def compress_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        max_tokens: int = 500
    ) -> List[RetrievedDocument]:
        """
        Compress context to most relevant parts
        Remove redundant information and keep only relevant sentences
        """
        compressed_docs = []

        for doc in documents:
            # Split into sentences
            sentences = self._split_sentences(doc.text)

            # Score each sentence based on relevance to query
            scored_sentences = []
            query_words = set(query.lower().split())

            for sent in sentences:
                sent_words = set(sent.lower().split())
                # Simple relevance: word overlap
                overlap = len(query_words & sent_words)
                if overlap > 0:
                    scored_sentences.append((sent, overlap))

            # Sort by relevance and take top sentences
            scored_sentences.sort(key=lambda x: x[1], reverse=True)

            # Reconstruct compressed text (up to max_tokens)
            compressed_text = ""
            word_count = 0
            for sent, score in scored_sentences:
                sent_words = len(sent.split())
                if word_count + sent_words <= max_tokens:
                    compressed_text += sent + " "
                    word_count += sent_words
                else:
                    break

            # Fallback: no sentence overlapped the query, so keep a prefix
            # of the original text
            if not compressed_text.strip():
                compressed_text = doc.text[:max_tokens * 5]  # ~5 characters per word

            compressed_docs.append(RetrievedDocument(
                id=doc.id,
                text=compressed_text.strip(),
                confidence=doc.confidence,
                metadata=doc.metadata
            ))

        return compressed_docs
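
    # Example (illustrative): for the query "giờ mở cửa", the sentence
    # "Cửa hàng mở cửa lúc 8 giờ" overlaps on {giờ, mở, cửa} (score 3) and
    # is kept, while a sentence with zero overlap is dropped.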

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences (Vietnamese-aware)"""
        # Simple sentence splitter
        sentences = re.split(r'[.!?]+', text)
        return [s.strip() for s in sentences if s.strip()]

    def hybrid_rag_pipeline(
        self,
        query: str,
        top_k: int = 5,
        score_threshold: float = 0.5,
        use_reranking: bool = True,
        use_compression: bool = True,
        max_context_tokens: int = 500
    ) -> Tuple[List[RetrievedDocument], Dict]:
        """
        Complete advanced RAG pipeline
        1. Multi-query retrieval
        2. Reranking
        3. Contextual compression
        """
        stats = {
            "original_query": query,
            "expanded_queries": [],
            "initial_results": 0,
            "after_rerank": 0,
            "after_compression": 0
        }

        # Step 1: Multi-query retrieval
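        # (expand_query is recomputed inside multi_query_retrieval; the extra
        # call here is cheap and only feeds the stats dict)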
        expanded_queries = self.expand_query(query)
        stats["expanded_queries"] = expanded_queries

        documents = self.multi_query_retrieval(
            query=query,
            top_k=top_k * 2,  # Get more candidates for reranking
            score_threshold=score_threshold
        )
        stats["initial_results"] = len(documents)

        # Step 2: Reranking (optional)
        if use_reranking and documents:
            documents = self.rerank_documents(query, documents)
            documents = documents[:top_k]  # Keep top_k after reranking
        stats["after_rerank"] = len(documents)

        # Step 3: Contextual compression (optional)
        if use_compression and documents:
            documents = self.compress_context(
                query=query,
                documents=documents,
                max_tokens=max_context_tokens
            )
        stats["after_compression"] = len(documents)

        return documents, stats

    def format_context_for_llm(
        self,
        documents: List[RetrievedDocument],
        include_metadata: bool = True
    ) -> str:
        """
        Format retrieved documents into context string for LLM
        Uses better structure for improved LLM understanding
        """
        if not documents:
            return ""

        context_parts = ["RELEVANT CONTEXT:\n"]

        for i, doc in enumerate(documents, 1):
            context_parts.append(f"\n--- Document {i} (Relevance: {doc.confidence:.2%}) ---")
            context_parts.append(doc.text)

            if include_metadata and doc.metadata:
                # Add useful metadata
                meta_str = []
                for key, value in doc.metadata.items():
                    if key not in ['text', 'texts'] and value:
                        meta_str.append(f"{key}: {value}")
                if meta_str:
                    context_parts.append(f"[Metadata: {', '.join(meta_str)}]")

        context_parts.append("\n--- End of Context ---\n")
        return "\n".join(context_parts)

    def build_rag_prompt(
        self,
        query: str,
        context: str,
        system_message: str = "You are a helpful AI assistant."
    ) -> str:
        """
        Build optimized RAG prompt for LLM
        Uses best practices for prompt engineering
        """
        prompt_template = f"""{system_message}

{context}

INSTRUCTIONS:
1. Answer the user's question using ONLY the information provided in the context above
2. If the context doesn't contain relevant information, say "Tôi không tìm thấy thông tin liên quan trong dữ liệu."
3. Cite relevant parts of the context when answering
4. Be concise and accurate
5. Answer in Vietnamese if the question is in Vietnamese

USER QUESTION: {query}

YOUR ANSWER:"""

        return prompt_template
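

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original API): the two
    # stub classes below are hypothetical stand-ins for the real embedding
    # and Qdrant services, implementing only the calls used above
    # (encode_text and search).
    import hashlib

    class _StubEmbeddingService:
        """Deterministic pseudo-embeddings, unit-normalized."""

        def encode_text(self, text: str) -> np.ndarray:
            seed = int.from_bytes(hashlib.md5(text.encode("utf-8")).digest()[:4], "little")
            vec = np.random.default_rng(seed).standard_normal(16)
            return vec / np.linalg.norm(vec)

    class _StubQdrantService:
        """Returns canned hits shaped like the dicts consumed by multi_query_retrieval."""

        def search(self, query_embedding, limit, score_threshold):
            hits = [
                {"id": "doc-1", "confidence": 0.82,
                 "metadata": {"text": "Cửa hàng mở cửa lúc 8 giờ sáng. Đóng cửa lúc 21 giờ.",
                              "source": "faq"}},
                {"id": "doc-2", "confidence": 0.64,
                 "metadata": {"text": "Chính sách đổi trả áp dụng trong 30 ngày.",
                              "source": "policy"}},
            ]
            return hits[:limit]

    rag = AdvancedRAG(_StubEmbeddingService(), _StubQdrantService())
    docs, stats = rag.hybrid_rag_pipeline("Cửa hàng mở cửa lúc mấy giờ?")
    print(stats)
    print(rag.format_context_for_llm(docs))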