# LLM-based query classification for routing between RAG and general chat.
# DEPENDENCIES
import json

from typing import Dict, Optional

from config.models import LLMProvider
from config.settings import get_settings
from config.logging_config import get_logger
from generation.llm_client import get_llm_client
# Setup Settings and Logging
settings = get_settings()
logger = get_logger(__name__)
class QueryClassifier:
    """
    LLM-based query classifier that intelligently routes queries to:

    1. General/Conversational (no document context needed)
    2. RAG/Document-based (needs retrieval from documents)

    Uses the LLM itself for classification (via a strict-JSON prompt)
    instead of hardcoded keyword patterns.
    """

    def __init__(self, provider: Optional[LLMProvider] = None, model_name: Optional[str] = None):
        """
        Initialize the classifier and its dedicated LLM client.

        Arguments:
        ----------
        provider { LLMProvider } : LLM provider (defaults to LLMProvider.OLLAMA)

        model_name { str } : Model name (defaults to settings.OLLAMA_MODEL)
        """
        self.logger = logger
        self.provider = provider or LLMProvider.OLLAMA
        self.model_name = model_name or settings.OLLAMA_MODEL
        # Initialize LLM client for classification
        self.llm_client = get_llm_client(provider = self.provider,
                                         model_name = self.model_name,
                                         )
        # Classification prompt: instructs the model to answer with ONLY a
        # JSON verdict, which _parse_llm_response() decodes below.
        self.system_prompt = """
You are a query classification system for a RAG (Retrieval-Augmented Generation) system.
Your job is to determine if a user query should be answered using the user's uploaded documents.
**IMPORTANT CONTEXT**: The user has uploaded documents to the system. All queries related to the content of those uploaded documents should use RAG.
Classify queries into TWO categories:
**RAG (Document-based)** - Use when ANY of these are true:
1. Query asks about ANY content that could be in the uploaded documents
2. Query asks factual questions that could be answered from document content
3. Query asks for lists, summaries, or analysis of information
4. Query mentions specific details, data, statistics, names, dates, or facts
5. Query asks "what", "how", "list", "explain", "summarize", "compare", "analyze" about any topic
6. Query could reasonably be answered by searching through documents
7. **CRITICAL**: When documents are uploaded, DEFAULT TO RAG for most factual/content queries
**GENERAL (Conversational)** - Use ONLY when MOST of these are true:
1. Query is purely conversational (greetings, thanks, casual chat)
2. Query asks about the RAG system itself or its functionality
3. Query asks for general knowledge that is NOT specific to uploaded documents
4. Query is a meta-question about how to use the system
5. Query contains NO request for factual information from documents
**EXAMPLES FOR ANY DOCUMENT TYPE**:
- For business documents: "What sales channels does the company use?" → RAG
- For research papers: "What were the study's findings?" → RAG
- For legal documents: "What are the key clauses?" → RAG
- For technical manuals: "How do I configure the system?" → RAG
- For personal documents: "What dates are mentioned?" → RAG
- "Hi, how are you?" → GENERAL
- "How do I upload a document?" → GENERAL
- "What is the capital of France?" → GENERAL (unless geography documents were uploaded)
**KEY RULES**:
1. When documents exist, assume queries are about them unless clearly not
2. When in doubt, classify as RAG (safer to search than hallucinate)
3. If query could be answered from document content, use RAG
4. Only use GENERAL for purely conversational or system-related queries
Respond with ONLY a JSON object (no markdown, no extra text):
{
"type": "rag" or "general",
"confidence": 0.0 to 1.0,
"reason": "brief explanation"
}
"""

    async def classify(self, query: str, has_documents: bool = True) -> Dict:
        """
        Classify a query using LLM

        Arguments:
        ----------
        query { str } : User query

        has_documents { bool } : Whether documents are available in the system

        Returns:
        --------
        { dict } : Classification result with keys "type", "confidence",
                   "reason", "suggested_action", "is_llm_classified"
                   (plus "error" when classification fails)
        """
        try:
            # If no documents are available, everything should be general
            if not has_documents:
                return {"type"              : "general",
                        "confidence"        : 1.0,
                        "reason"            : "No documents available in system",
                        "suggested_action"  : "respond_with_general_llm",
                        "is_llm_classified" : False,
                        }
            # Build classification prompt
            user_prompt = f"""
Query: "{query}"
System status: {"Documents are available" if has_documents else "No documents uploaded"}
Classify this query. Remember: if uncertain, prefer RAG.
"""
            messages = [{"role"    : "system",
                         "content" : self.system_prompt,
                         },
                        {"role"    : "user",
                         "content" : user_prompt,
                         }
                        ]
            # Get LLM classification (use low temperature for consistency)
            llm_response = await self.llm_client.generate(messages = messages,
                                                          temperature = 0.1,  # Low temperature for consistent classification
                                                          max_tokens = 150,
                                                          )
            response_text = llm_response.get("content", "").strip()
            # Parse JSON response (falls back to text heuristics on bad JSON)
            classification = self._parse_llm_response(response_text)
            # Add suggested action based on classification; anything other
            # than an explicit "general" verdict routes to RAG (safe default)
            if classification["type"] == "general":
                classification["suggested_action"] = "respond_with_general_llm"
            else:
                classification["suggested_action"] = "respond_with_rag"
            classification["is_llm_classified"] = True
            logger.info(f"LLM classified query as: {classification['type']} (confidence: {classification['confidence']:.2f})")
            logger.debug(f"Classification reason: {classification['reason']}")
            return classification
        except Exception as e:
            logger.error(f"LLM classification failed: {e}, defaulting to RAG")
            # On error, default to RAG (safer to try document search)
            return {"type"              : "rag",
                    "confidence"        : 0.5,
                    "reason"            : f"Classification failed: {str(e)}, defaulting to RAG",
                    "suggested_action"  : "respond_with_rag",
                    "is_llm_classified" : False,
                    "error"             : str(e)
                    }

    def _parse_llm_response(self, response_text: str) -> Dict:
        """
        Parse LLM JSON response

        Arguments:
        ----------
        response_text { str } : LLM response text

        Returns:
        --------
        { dict } : Parsed classification (always contains "type",
                   "confidence", "reason"; never raises)
        """
        try:
            # Remove markdown code blocks if present
            if ("```json" in response_text):
                response_text = response_text.split("```json")[1].split("```")[0].strip()
            elif ("```" in response_text):
                response_text = response_text.split("```")[1].split("```")[0].strip()
            # Parse JSON
            result = json.loads(response_text)
            # BUGFIX: json.loads can legally return a list/str/number; the old
            # code then crashed on result.get() with an uncaught AttributeError
            if not isinstance(result, dict):
                raise ValueError(f"Invalid type in response: {result!r}")
            # Validate required fields
            if ("type" not in result) or (result["type"] not in ["rag", "general"]):
                raise ValueError(f"Invalid type in response: {result.get('type')}")
            # Set defaults for missing fields
            result.setdefault("confidence", 0.8)
            result.setdefault("reason", "LLM classification")
            # Clamp confidence to valid range. NOTE: float() raises TypeError
            # for non-numeric values (e.g. an explicit JSON null, which
            # setdefault does NOT replace) — handled by the except below.
            result["confidence"] = max(0.0, min(1.0, float(result["confidence"])))
            return result
        # BUGFIX: TypeError added so a non-numeric confidence falls through to
        # the heuristic fallback instead of escaping to classify()'s handler
        except (json.JSONDecodeError, ValueError, KeyError, TypeError) as e:
            logger.warning(f"Failed to parse LLM response: {e}")
            logger.debug(f"Raw response: {response_text}")
            # Try to extract type from text if JSON parsing fails
            response_lower = response_text.lower()
            if (("general" in response_lower) and ("rag" not in response_lower)):
                return {"type"       : "general",
                        "confidence" : 0.6,
                        "reason"     : "Parsed from non-JSON response",
                        }
            else:
                # Default to RAG if parsing fails
                return {"type"       : "rag",
                        "confidence" : 0.6,
                        "reason"     : "Failed to parse response, defaulting to RAG",
                        }
# Global classifier instance (lazily constructed on first request)
_query_classifier = None


def get_query_classifier(provider: LLMProvider = None, model_name: str = None) -> QueryClassifier:
    """
    Get global query classifier instance

    Builds the shared QueryClassifier on first call; later calls return the
    cached instance (their provider/model_name arguments have no effect).

    Arguments:
    ----------
    provider { LLMProvider } : LLM provider

    model_name { str } : Model name

    Returns:
    --------
    { QueryClassifier } : QueryClassifier instance
    """
    global _query_classifier
    # Guard clause: reuse the existing singleton when already built
    if _query_classifier is not None:
        return _query_classifier
    _query_classifier = QueryClassifier(provider = provider, model_name = model_name)
    return _query_classifier