Spaces:
Running
Running
# DEPENDENCIES
import json
from typing import Any, Dict, Optional

from config.models import LLMProvider
from config.settings import get_settings
from config.logging_config import get_logger
from generation.llm_client import get_llm_client
# Setup Settings and Logging
# Module-level singletons: application settings and a module-scoped logger.
settings = get_settings()
logger = get_logger(__name__)
class QueryClassifier:
    """
    LLM-based query classifier that intelligently routes queries to:
    1. General/Conversational (no document context needed)
    2. RAG/Document-based (needs retrieval from documents)

    Uses the LLM itself for classification instead of hardcoded patterns.
    """
    def __init__(self, provider: Optional[LLMProvider] = None, model_name: Optional[str] = None):
        """
        Initialize the classifier and its underlying LLM client.

        Arguments:
        ----------
        provider   { LLMProvider } : LLM provider (defaults to LLMProvider.OLLAMA)
        model_name { str }         : Model name (defaults to settings.OLLAMA_MODEL)
        """
        self.logger = logger
        self.provider = provider or LLMProvider.OLLAMA
        self.model_name = model_name or settings.OLLAMA_MODEL
        # Initialize LLM client for classification
        self.llm_client = get_llm_client(provider=self.provider,
                                         model_name=self.model_name,
                                         )
        # Classification prompt (instructs the model to answer with bare JSON only)
        self.system_prompt = """
You are a query classification system for a RAG (Retrieval-Augmented Generation) system.
Your job is to determine if a user query should be answered using the user's uploaded documents.
**IMPORTANT CONTEXT**: The user has uploaded documents to the system. All queries related to the content of those uploaded documents should use RAG.
Classify queries into TWO categories:
**RAG (Document-based)** - Use when ANY of these are true:
1. Query asks about ANY content that could be in the uploaded documents
2. Query asks factual questions that could be answered from document content
3. Query asks for lists, summaries, or analysis of information
4. Query mentions specific details, data, statistics, names, dates, or facts
5. Query asks "what", "how", "list", "explain", "summarize", "compare", "analyze" about any topic
6. Query could reasonably be answered by searching through documents
7. **CRITICAL**: When documents are uploaded, DEFAULT TO RAG for most factual/content queries
**GENERAL (Conversational)** - Use ONLY when MOST of these are true:
1. Query is purely conversational (greetings, thanks, casual chat)
2. Query asks about the RAG system itself or its functionality
3. Query asks for general knowledge that is NOT specific to uploaded documents
4. Query is a meta-question about how to use the system
5. Query contains NO request for factual information from documents
**EXAMPLES FOR ANY DOCUMENT TYPE**:
- For business documents: "What sales channels does the company use?" → RAG
- For research papers: "What were the study's findings?" → RAG
- For legal documents: "What are the key clauses?" → RAG
- For technical manuals: "How do I configure the system?" → RAG
- For personal documents: "What dates are mentioned?" → RAG
- "Hi, how are you?" → GENERAL
- "How do I upload a document?" → GENERAL
- "What is the capital of France?" → GENERAL (unless geography documents were uploaded)
**KEY RULES**:
1. When documents exist, assume queries are about them unless clearly not
2. When in doubt, classify as RAG (safer to search than hallucinate)
3. If query could be answered from document content, use RAG
4. Only use GENERAL for purely conversational or system-related queries
Respond with ONLY a JSON object (no markdown, no extra text):
{
"type": "rag" or "general",
"confidence": 0.0 to 1.0,
"reason": "brief explanation"
}
"""

    async def classify(self, query: str, has_documents: bool = True) -> Dict[str, Any]:
        """
        Classify a query using LLM

        Arguments:
        ----------
        query         { str }  : User query
        has_documents { bool } : Whether documents are available in the system

        Returns:
        --------
        { dict } : Classification result with keys: type, confidence, reason,
                   suggested_action, is_llm_classified (and "error" on failure)
        """
        try:
            # If no documents are available, everything should be general
            if not has_documents:
                return {"type"             : "general",
                        "confidence"       : 1.0,
                        "reason"           : "No documents available in system",
                        "suggested_action" : "respond_with_general_llm",
                        "is_llm_classified": False,
                        }
            # Build classification prompt
            user_prompt = f"""
Query: "{query}"
System status: {"Documents are available" if has_documents else "No documents uploaded"}
Classify this query. Remember: if uncertain, prefer RAG.
"""
            messages = [{"role"   : "system",
                         "content": self.system_prompt,
                         },
                        {"role"   : "user",
                         "content": user_prompt,
                         }
                        ]
            # Get LLM classification (use low temperature for consistency)
            llm_response = await self.llm_client.generate(messages=messages,
                                                          temperature=0.1,  # Low temperature for consistent classification
                                                          max_tokens=150,
                                                          )
            response_text = llm_response.get("content", "").strip()
            # Parse JSON response
            classification = self._parse_llm_response(response_text)
            # Add suggested action based on classification
            if classification["type"] == "rag":
                classification["suggested_action"] = "respond_with_rag"
            elif classification["type"] == "general":
                classification["suggested_action"] = "respond_with_general_llm"
            else:
                # Default to RAG if uncertain (defensive; _parse_llm_response validates type)
                classification["suggested_action"] = "respond_with_rag"
            classification["is_llm_classified"] = True
            # Lazy %-style args so formatting is skipped when the level is disabled
            logger.info("LLM classified query as: %s (confidence: %.2f)",
                        classification["type"], classification["confidence"])
            logger.debug("Classification reason: %s", classification["reason"])
            return classification

        except Exception as e:
            # Boundary handler: any failure falls back to RAG (safer to try document search)
            logger.error(f"LLM classification failed: {e}, defaulting to RAG")
            return {"type"             : "rag",
                    "confidence"       : 0.5,
                    "reason"           : f"Classification failed: {str(e)}, defaulting to RAG",
                    "suggested_action" : "respond_with_rag",
                    "is_llm_classified": False,
                    "error"            : str(e)
                    }

    def _parse_llm_response(self, response_text: str) -> Dict[str, Any]:
        """
        Parse LLM JSON response

        Arguments:
        ----------
        response_text { str } : LLM response text

        Returns:
        --------
        { dict } : Parsed classification (type, confidence, reason)
        """
        try:
            # Remove markdown code blocks if present
            if "```json" in response_text:
                response_text = response_text.split("```json")[1].split("```")[0].strip()
            elif "```" in response_text:
                response_text = response_text.split("```")[1].split("```")[0].strip()
            # Parse JSON
            result = json.loads(response_text)
            # Validate required fields
            if ("type" not in result) or (result["type"] not in ("rag", "general")):
                raise ValueError(f"Invalid type in response: {result.get('type')}")
            # Set defaults for missing fields
            result.setdefault("confidence", 0.8)
            result.setdefault("reason", "LLM classification")
            # Clamp confidence to valid range.
            # TypeError is included in the handler below: float(None) / float({}) raises
            # TypeError, not ValueError, and must hit the same fallback path.
            result["confidence"] = max(0.0, min(1.0, float(result["confidence"])))
            return result

        except (json.JSONDecodeError, ValueError, TypeError, KeyError) as e:
            logger.warning(f"Failed to parse LLM response: {e}")
            logger.debug(f"Raw response: {response_text}")
            # Try to extract type from text if JSON parsing fails
            response_lower = response_text.lower()
            if ("general" in response_lower) and ("rag" not in response_lower):
                return {"type"      : "general",
                        "confidence": 0.6,
                        "reason"    : "Parsed from non-JSON response",
                        }
            else:
                # Default to RAG if parsing fails
                return {"type"      : "rag",
                        "confidence": 0.6,
                        "reason"    : "Failed to parse response, defaulting to RAG",
                        }
# Global classifier instance (lazily created by get_query_classifier)
_query_classifier = None


def get_query_classifier(provider: Optional[LLMProvider] = None, model_name: Optional[str] = None) -> QueryClassifier:
    """
    Get global query classifier instance

    NOTE: the instance is created lazily on the first call; `provider` and
    `model_name` passed on later calls are ignored because the cached
    singleton is returned unchanged.

    Arguments:
    ----------
    provider   { LLMProvider } : LLM provider (used only on first call)
    model_name { str }         : Model name (used only on first call)

    Returns:
    --------
    { QueryClassifier } : QueryClassifier instance
    """
    global _query_classifier

    if _query_classifier is None:
        _query_classifier = QueryClassifier(provider=provider,
                                            model_name=model_name,
                                            )
    return _query_classifier