File size: 7,031 Bytes
14f13a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
Main RAG pipeline orchestration.

Coordinates retrieval and generation for question answering.
"""
import logging
import requests
from typing import Dict, Any, Optional, List

from src.retriever import DocumentRetriever
from src.prompts import create_rag_prompt, create_no_context_prompt, format_response_with_sources
from src.config import settings

# HuggingFace router — OpenAI-compatible chat completions endpoint
_HF_API_URL = "https://router.huggingface.co/v1/chat/completions"

logger = logging.getLogger(__name__)


class RAGPipeline:
    """
    Orchestrates the RAG pipeline: retrieve → generate → format.
    
    Features:
    - Smart retrieval with filtering
    - LLM generation via HuggingFace Inference API
    - Source attribution
    - Error handling with graceful degradation
    """
    
    def __init__(
        self,
        retriever: DocumentRetriever,
        llm_model: Optional[str] = None,
        min_similarity_score: float = 0.5
    ):
        """
        Initialize RAG pipeline.
        
        Args:
            retriever: Document retriever instance
            llm_model: Optional LLM model name override; falls back to
                settings.llm_model when omitted.
            min_similarity_score: Minimum retrieval score for a chunk to be
                treated as relevant in query().
        """
        self.retriever = retriever
        self.llm_model = llm_model or settings.llm_model
        self.min_similarity_score = min_similarity_score
        
        self._api_url = _HF_API_URL
        self._headers = {
            "Authorization": f"Bearer {settings.hf_token}",
            "Content-Type": "application/json",
        }
        # Lazy %-style args so the message is only formatted when INFO is enabled.
        logger.info("LLM endpoint: %s model=%s", self._api_url, self.llm_model)
    
    def query(
        self,
        question: str,
        top_k: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Process a user query through the RAG pipeline.
        
        Args:
            question: User's question
            top_k: Number of chunks to retrieve
            filter_metadata: Optional metadata filters
            
        Returns:
            Dictionary with answer, sources, and metadata. On any internal
            failure a degraded response with confidence "error" is returned
            instead of raising, so callers always get a well-formed dict.
        """
        try:
            # Truncate the logged question so very long inputs don't flood logs.
            logger.info("Processing query: %s...", question[:100])
            
            # Step 1: Retrieve relevant context
            retrieved_chunks = self.retriever.retrieve(
                query=question,
                top_k=top_k,
                filter_metadata=filter_metadata
            )
            
            # Log raw scores for diagnostics
            scores = [round(c["score"], 4) for c in retrieved_chunks]
            logger.info("Raw chunk scores: %s", scores)

            # Filter by minimum similarity score
            relevant_chunks = [
                chunk for chunk in retrieved_chunks
                if chunk["score"] >= self.min_similarity_score
            ]

            logger.info(
                "Found %d relevant chunks (threshold: %s)",
                len(relevant_chunks), self.min_similarity_score
            )
            
            # Step 2: Generate answer — if nothing passed the threshold,
            # return a fixed fallback response rather than calling the LLM.
            if not relevant_chunks:
                answer = f"I couldn't find relevant information in the {settings.docs_name} documentation to answer this question. Could you rephrase or ask about a different topic?"
                return {
                    "answer": answer,
                    "sources": [],
                    "source_count": 0,
                    "confidence": "low",
                    "chunks_retrieved": 0
                }
            
            # Create prompt
            prompt = create_rag_prompt(question, relevant_chunks)
            
            # Generate answer
            answer = self._generate_answer(prompt)
            
            # Step 3: Format response
            response = format_response_with_sources(answer, relevant_chunks)
            
            # Add metadata
            response["confidence"] = self._estimate_confidence(relevant_chunks)
            response["chunks_retrieved"] = len(relevant_chunks)
            
            logger.info("Query processed successfully")
            return response
            
        except Exception as e:
            # Broad catch is deliberate: this is the top-level boundary and
            # must degrade gracefully. logger.exception records the traceback.
            logger.exception("Error processing query: %s", e)
            return {
                "answer": f"An error occurred while processing your question: {str(e)}",
                "sources": [],
                "source_count": 0,
                "confidence": "error",
                "chunks_retrieved": 0
            }
    
    def _generate_answer(self, prompt: str) -> str:
        """
        Generate answer using LLM.
        
        Args:
            prompt: Formatted prompt with context
            
        Returns:
            Generated answer text
            
        Raises:
            requests.HTTPError: On a non-2xx response from the endpoint.
            Exception: Any other failure is logged and re-raised so query()
                can produce its degraded response.
        """
        try:
            # Use OpenAI-compatible chat completions endpoint
            payload = {
                # ":fastest" suffix — presumably an HF router provider-selection
                # hint; confirm against the router's model-naming docs.
                "model": f"{self.llm_model}:fastest",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": settings.llm_max_tokens,
                "temperature": settings.llm_temperature,
                "top_p": 0.9,
            }
            response = requests.post(
                self._api_url,
                headers=self._headers,
                json=payload,
                timeout=60
            )
            response.raise_for_status()
            result = response.json()
            answer = result["choices"][0]["message"]["content"].strip()
            logger.debug("Generated answer (%d chars)", len(answer))

            return answer
            
        except Exception as e:
            logger.error("LLM generation failed: %s", e)
            raise
    
    def _estimate_confidence(self, chunks: List[Dict[str, Any]]) -> str:
        """
        Estimate confidence based on retrieval scores.
        
        Args:
            chunks: Retrieved chunks with scores
            
        Returns:
            Confidence level: "high" (mean score >= 0.75),
            "medium" (>= 0.6), or "low" (below 0.6 or no chunks).
        """
        if not chunks:
            return "low"
        
        avg_score = sum(chunk["score"] for chunk in chunks) / len(chunks)
        
        if avg_score >= 0.75:
            return "high"
        elif avg_score >= 0.6:
            return "medium"
        else:
            return "low"
    
    def get_stats(self) -> Dict[str, Any]:
        """Get pipeline statistics: model, threshold, and collection stats."""
        return {
            "llm_model": self.llm_model,
            "min_similarity_score": self.min_similarity_score,
            **self.retriever.get_collection_stats()
        }


def create_rag_pipeline(
    retriever: Optional[DocumentRetriever] = None
) -> RAGPipeline:
    """
    Factory function to create RAG pipeline.
    
    Builds a default retriever when none is supplied, then wires it into a
    pipeline configured from application settings.
    
    Args:
        retriever: Optional retriever override
        
    Returns:
        RAGPipeline instance
    """
    # Imported lazily to avoid a circular import at module load time.
    from src.retriever import create_retriever
    
    active_retriever = retriever if retriever is not None else create_retriever()
    return RAGPipeline(
        retriever=active_retriever,
        min_similarity_score=settings.min_similarity_score,
    )