"""
Main RAG pipeline orchestration.
Coordinates retrieval and generation for question answering.
"""
import logging
import requests
from typing import Dict, Any, Optional, List
from src.retriever import DocumentRetriever
from src.prompts import create_rag_prompt, create_no_context_prompt, format_response_with_sources
from src.config import settings
# HuggingFace router — OpenAI-compatible chat completions endpoint
_HF_API_URL = "https://router.huggingface.co/v1/chat/completions"
logger = logging.getLogger(__name__)


class RAGPipeline:
    """
    Orchestrates the RAG pipeline: retrieve → generate → format.

    Features:
    - Smart retrieval with filtering
    - LLM generation via HuggingFace Inference API
    - Source attribution
    - Error handling with graceful degradation
    """

    def __init__(
        self,
        retriever: DocumentRetriever,
        llm_model: Optional[str] = None,
        min_similarity_score: float = 0.5
    ):
        """
        Initialize RAG pipeline.

        Args:
            retriever: Document retriever instance
            llm_model: Optional LLM model name override
            min_similarity_score: Minimum score for relevant results
        """
        self.retriever = retriever
        self.llm_model = llm_model or settings.llm_model
        self.min_similarity_score = min_similarity_score
        self._api_url = _HF_API_URL
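        # Bearer auth against the HuggingFace router; assumes settings.hf_token
        # holds a valid access token that is allowed to call the Inference API.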
        self._headers = {
            "Authorization": f"Bearer {settings.hf_token}",
            "Content-Type": "application/json",
        }
        logger.info(f"LLM endpoint: {self._api_url} model={self.llm_model}")

    def query(
        self,
        question: str,
        top_k: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Process a user query through the RAG pipeline.

        Args:
            question: User's question
            top_k: Number of chunks to retrieve
            filter_metadata: Optional metadata filters

        Returns:
            Dictionary with answer, sources, and metadata
        """
        try:
            logger.info(f"Processing query: {question[:100]}...")

            # Step 1: Retrieve relevant context
            retrieved_chunks = self.retriever.retrieve(
                query=question,
                top_k=top_k,
                filter_metadata=filter_metadata
            )

            # Log raw scores for diagnostics
            scores = [round(c["score"], 4) for c in retrieved_chunks]
            logger.info(f"Raw chunk scores: {scores}")

            # Filter by minimum similarity score
            relevant_chunks = [
                chunk for chunk in retrieved_chunks
                if chunk["score"] >= self.min_similarity_score
            ]
            logger.info(f"Found {len(relevant_chunks)} relevant chunks (threshold: {self.min_similarity_score})")

            # Step 2: Generate answer
            if not relevant_chunks:
                answer = f"I couldn't find relevant information in the {settings.docs_name} documentation to answer this question. Could you rephrase or ask about a different topic?"
                return {
                    "answer": answer,
                    "sources": [],
                    "source_count": 0,
                    "confidence": "low",
                    "chunks_retrieved": 0
                }

            # Create prompt
            prompt = create_rag_prompt(question, relevant_chunks)

            # Generate answer
            answer = self._generate_answer(prompt)

            # Step 3: Format response
            response = format_response_with_sources(answer, relevant_chunks)

            # Add metadata
            response["confidence"] = self._estimate_confidence(relevant_chunks)
            response["chunks_retrieved"] = len(relevant_chunks)

            logger.info("Query processed successfully")
            return response

        except Exception as e:
            logger.error(f"Error processing query: {e}", exc_info=True)
            return {
                "answer": f"An error occurred while processing your question: {str(e)}",
                "sources": [],
                "source_count": 0,
                "confidence": "error",
                "chunks_retrieved": 0
            }

    def _generate_answer(self, prompt: str) -> str:
        """
        Generate answer using LLM.

        Args:
            prompt: Formatted prompt with context

        Returns:
            Generated answer text
        """
        try:
            # Use OpenAI-compatible chat completions endpoint
            payload = {
                "model": f"{self.llm_model}:fastest",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": settings.llm_max_tokens,
                "temperature": settings.llm_temperature,
                "top_p": 0.9,
            }

            response = requests.post(
                self._api_url,
                headers=self._headers,
                json=payload,
                timeout=60
            )
            response.raise_for_status()

            result = response.json()
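            # OpenAI-compatible response shape:
            #   {"choices": [{"message": {"content": "..."}}], ...}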
            answer = result["choices"][0]["message"]["content"].strip()
            logger.debug(f"Generated answer ({len(answer)} chars)")
            return answer

        except Exception as e:
            logger.error(f"LLM generation failed: {e}")
            raise

    def _estimate_confidence(self, chunks: List[Dict[str, Any]]) -> str:
        """
        Estimate confidence based on retrieval scores.

        Args:
            chunks: Retrieved chunks with scores

        Returns:
            Confidence level: "high", "medium", or "low"
        """
        if not chunks:
            return "low"

        avg_score = sum(chunk["score"] for chunk in chunks) / len(chunks)

        if avg_score >= 0.75:
            return "high"
        elif avg_score >= 0.6:
            return "medium"
        else:
            return "low"

    def get_stats(self) -> Dict[str, Any]:
        """Get pipeline statistics."""
        return {
            "llm_model": self.llm_model,
            "min_similarity_score": self.min_similarity_score,
            **self.retriever.get_collection_stats()
        }


def create_rag_pipeline(
    retriever: Optional[DocumentRetriever] = None
) -> RAGPipeline:
    """
    Factory function to create a RAG pipeline.

    Args:
        retriever: Optional retriever override

    Returns:
        RAGPipeline instance
    """
    from src.retriever import create_retriever

    if retriever is None:
        retriever = create_retriever()

    return RAGPipeline(
        retriever=retriever,
        min_similarity_score=settings.min_similarity_score
    )
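

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline itself: assumes
    # settings.hf_token is configured and the retriever's collection has already
    # been ingested; the question below is only an illustrative placeholder.
    logging.basicConfig(level=logging.INFO)

    demo_pipeline = create_rag_pipeline()
    demo_result = demo_pipeline.query("What does this documentation cover?")

    print(demo_result["answer"])
    print(f"Confidence: {demo_result['confidence']} ({demo_result['chunks_retrieved']} chunks used)")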