| """ | |
| SQL Retriever for RAG System | |
| Intelligent retrieval of relevant SQL examples based on question similarity and table schema analysis. | |
| """ | |
| import re | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from collections import defaultdict | |
| import numpy as np | |
| from loguru import logger | |
| from .vector_store import VectorStore | |


class SQLRetriever:
    """Intelligent SQL example retriever with schema-aware filtering."""

    def __init__(self, vector_store: VectorStore):
        """
        Initialize the SQL retriever.

        Args:
            vector_store: Initialized vector store instance
        """
        self.vector_store = vector_store
        self.schema_cache = {}  # Cache for table schema analysis

    def retrieve_examples(self,
                          question: str,
                          table_headers: List[str],
                          top_k: int = 5,
                          use_schema_filtering: bool = True) -> List[Dict[str, Any]]:
        """
        Retrieve relevant SQL examples using multiple retrieval strategies.

        Args:
            question: Natural language question
            table_headers: List of table column names
            top_k: Number of examples to retrieve
            use_schema_filtering: Whether to use schema-aware filtering

        Returns:
            List of retrieved examples with relevance scores
        """
        # Strategy 1: Vector similarity search
        vector_results = self.vector_store.search_similar(
            query=question,
            table_headers=table_headers,
            top_k=top_k * 2,  # Get more candidates than needed so filtering can prune
            similarity_threshold=0.6
        )
        if not vector_results:
            logger.warning("No vector search results found")
            return []

        # Strategy 2: Schema-aware filtering and ranking
        if use_schema_filtering:
            filtered_results = self._apply_schema_filtering(
                vector_results, question, table_headers
            )
        else:
            filtered_results = vector_results

        # Strategy 3: Question type classification and boosting
        enhanced_results = self._enhance_with_question_analysis(
            filtered_results, question, table_headers
        )

        # Strategy 4: Final ranking and selection
        final_results = self._final_ranking(
            enhanced_results, question, table_headers, top_k
        )

        logger.info(f"Retrieved {len(final_results)} relevant examples")
        return final_results
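
    # Each returned dict is expected to look roughly like the sketch below.
    # The keys shown are the ones this class reads or writes; the vector
    # store payload may carry additional fields (e.g. the SQL text itself):
    #   {
    #       "question": "How many employees are there?",
    #       "table_headers": ["employee_id", "name", "salary"],
    #       "category": "aggregation",
    #       "similarity_score": 0.82,   # from the vector store
    #       "final_score": 0.79,        # set by _final_ranking
    #   }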

    def _apply_schema_filtering(self,
                                results: List[Dict[str, Any]],
                                question: str,
                                table_headers: List[str]) -> List[Dict[str, Any]]:
        """Apply schema-aware filtering to improve relevance."""
        filtered_results = []

        # Analyze current table schema
        current_schema = self._analyze_schema(table_headers)

        for result in results:
            # Analyze example table schema
            example_headers = result["table_headers"]
            if isinstance(example_headers, str):
                example_headers = [h.strip() for h in example_headers.split(",")]
            example_schema = self._analyze_schema(example_headers)

            # Calculate schema similarity
            schema_similarity = self._calculate_schema_similarity(
                current_schema, example_schema
            )

            # Boost score based on schema similarity
            result["schema_similarity"] = schema_similarity
            result["enhanced_score"] = (
                result["similarity_score"] * 0.7 +
                schema_similarity * 0.3
            )

            # Filter out examples with very low schema similarity
            if schema_similarity > 0.3:
                filtered_results.append(result)

        return filtered_results

    def _analyze_schema(self, table_headers: List[str]) -> Dict[str, Any]:
        """Analyze table schema for intelligent matching."""
        if not table_headers:
            return {}

        # Reuse cached analysis when the same header set is seen again
        cache_key = tuple(table_headers)
        if cache_key in self.schema_cache:
            return self.schema_cache[cache_key]

        schema_info = {
            "column_count": len(table_headers),
            "column_types": {},
            "has_numeric": False,
            "has_text": False,
            "has_date": False,
            "has_boolean": False,
            "primary_key_candidates": [],
            "foreign_key_candidates": []
        }

        for header in table_headers:
            header_lower = header.lower()

            # Detect key columns based on naming patterns
            if any(word in header_lower for word in ['id', 'key', 'pk', 'fk']):
                if 'id' in header_lower:
                    schema_info["primary_key_candidates"].append(header)
                if 'fk' in header_lower or 'foreign' in header_lower:
                    schema_info["foreign_key_candidates"].append(header)

            # Detect data types (a later match overwrites an earlier one)
            if any(word in header_lower for word in ['age', 'count', 'number', 'price', 'salary', 'amount']):
                schema_info["has_numeric"] = True
                schema_info["column_types"][header] = "numeric"
            if any(word in header_lower for word in ['name', 'title', 'description', 'text', 'comment']):
                schema_info["has_text"] = True
                schema_info["column_types"][header] = "text"
            if any(word in header_lower for word in ['date', 'time', 'created', 'updated', 'birth']):
                schema_info["has_date"] = True
                schema_info["column_types"][header] = "date"
            if any(word in header_lower for word in ['is_', 'has_', 'active', 'enabled', 'status']):
                schema_info["has_boolean"] = True
                schema_info["column_types"][header] = "boolean"

        self.schema_cache[cache_key] = schema_info
        return schema_info
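
    # For example, headers ["employee_id", "name", "salary", "hire_date"]
    # would be analyzed roughly as:
    #   column_count: 4
    #   primary_key_candidates: ["employee_id"]
    #   has_numeric: True ("salary"), has_text: True ("name"),
    #   has_date: True ("hire_date"), has_boolean: False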

    def _calculate_schema_similarity(self,
                                     schema1: Dict[str, Any],
                                     schema2: Dict[str, Any]) -> float:
        """Calculate similarity between two table schemas."""
        if not schema1 or not schema2:
            return 0.0

        # Column count similarity
        count_diff = abs(schema1.get("column_count", 0) - schema2.get("column_count", 0))
        count_similarity = max(0, 1 - (count_diff / max(schema1.get("column_count", 1), 1)))

        # Data type similarity
        type_similarity = 0.0
        if schema1.get("has_numeric") == schema2.get("has_numeric"):
            type_similarity += 0.25
        if schema1.get("has_text") == schema2.get("has_text"):
            type_similarity += 0.25
        if schema1.get("has_date") == schema2.get("has_date"):
            type_similarity += 0.25
        if schema1.get("has_boolean") == schema2.get("has_boolean"):
            type_similarity += 0.25

        # Primary key similarity
        pk_similarity = 0.0
        if (schema1.get("primary_key_candidates") and
                schema2.get("primary_key_candidates")):
            pk_similarity = 0.2

        # Weighted combination
        final_similarity = (
            count_similarity * 0.4 +
            type_similarity * 0.4 +
            pk_similarity * 0.2
        )
        return final_similarity
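
    # Worked example: schemas with 4 vs. 5 columns, identical type flags,
    # and primary-key candidates on both sides:
    #   count_similarity = 1 - 1/4             = 0.75
    #   type_similarity  = 4 * 0.25            = 1.00
    #   pk_similarity                          = 0.20
    #   final = 0.75*0.4 + 1.0*0.4 + 0.2*0.2   = 0.74
    # (Since pk_similarity caps at 0.2 before the 0.2 weight is applied,
    # the combined score never exceeds 0.84.)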

    def _enhance_with_question_analysis(self,
                                        results: List[Dict[str, Any]],
                                        question: str,
                                        table_headers: List[str]) -> List[Dict[str, Any]]:
        """Enhance results with question type analysis."""
        # Analyze the incoming question once, not per result
        question_type = self._classify_question_type(question)
        question_complexity = self._assess_question_complexity(question)

        for result in results:
            # When schema filtering is skipped, "enhanced_score" has not been
            # set yet, so fall back to the raw similarity score
            if "enhanced_score" not in result:
                result["enhanced_score"] = result.get("similarity_score", 0.0)

            # Boost examples that match the question type
            if question_type in result.get("category", "").lower():
                result["enhanced_score"] *= 1.2

            # Boost examples with similar complexity (both scores lie in [0, 1])
            example_complexity = self._assess_question_complexity(result["question"])
            complexity_match = 1 - abs(question_complexity - example_complexity)
            result["enhanced_score"] *= (0.9 + complexity_match * 0.1)

        return results

    def _classify_question_type(self, question: str) -> str:
        """Classify the type of SQL question."""
        question_lower = question.lower()

        if any(word in question_lower for word in ['count', 'how many', 'number of']):
            return "aggregation"
        elif any(word in question_lower for word in ['average', 'mean', 'sum', 'total']):
            return "aggregation"
        elif any(word in question_lower for word in ['group by', 'grouped', 'by department', 'by category']):
            return "grouping"
        elif any(word in question_lower for word in ['join', 'combine', 'merge', 'connect']):
            return "join"
        elif any(word in question_lower for word in ['order by', 'sort', 'rank', 'top', 'highest', 'lowest']):
            return "sorting"
        elif any(word in question_lower for word in ['where', 'filter', 'condition']):
            return "filtering"
        else:
            return "general"

    def _assess_question_complexity(self, question: str) -> float:
        """Assess the complexity of a question (0-1 scale)."""
        complexity_score = 0.0

        # Length complexity
        word_count = len(question.split())
        if word_count > 20:
            complexity_score += 0.3
        elif word_count > 10:
            complexity_score += 0.2

        # Keyword complexity
        complex_keywords = ['join', 'group by', 'having', 'subquery', 'union', 'intersect']
        for keyword in complex_keywords:
            if keyword in question.lower():
                complexity_score += 0.15

        # Interrogative form adds a small amount
        if '?' in question:
            complexity_score += 0.1

        return min(1.0, complexity_score)
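
    # Worked example: "Join the two tables and count matching rows?"
    #   8 words          -> no length bonus
    #   contains 'join'  -> +0.15
    #   ends with '?'    -> +0.10
    #   complexity       =  0.25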

    def _final_ranking(self,
                       results: List[Dict[str, Any]],
                       question: str,
                       table_headers: List[str],
                       top_k: int) -> List[Dict[str, Any]]:
        """Final ranking and selection of examples."""
        if not results:
            return []

        # Sort by enhanced score
        results.sort(key=lambda x: x.get("enhanced_score", 0), reverse=True)

        # Ensure diversity in results
        diverse_results = []
        seen_categories = set()
        for result in results:
            if len(diverse_results) >= top_k:
                break
            category = result.get("category", "general")
            if category not in seen_categories or len(diverse_results) < top_k // 2:
                diverse_results.append(result)
                seen_categories.add(category)

        # Fill remaining slots with the highest-scoring examples
        if len(diverse_results) < top_k:
            for result in results:
                if result not in diverse_results and len(diverse_results) < top_k:
                    diverse_results.append(result)

        # Final formatting
        for result in diverse_results:
            result["final_score"] = result.get("enhanced_score", result.get("similarity_score", 0))
            # Remove internal scoring fields
            result.pop("enhanced_score", None)
            result.pop("schema_similarity", None)

        return diverse_results[:top_k]
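
    # Diversity pass example with top_k = 4 and sorted categories
    # [agg, agg, sort, agg, join]: the first two "agg" results are kept
    # (fewer than top_k // 2 = 2 selected so far), "sort" and "join" are
    # kept as new categories, and the third "agg" is skipped.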

    def get_retrieval_stats(self) -> Dict[str, Any]:
        """Get statistics about the retrieval system."""
        vector_stats = self.vector_store.get_statistics()
        return {
            "vector_store_stats": vector_stats,
            "schema_cache_size": len(self.schema_cache),
            "retrieval_strategies": [
                "vector_similarity",
                "schema_filtering",
                "question_analysis",
                "diversity_ranking"
            ]
        }
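

# --- Usage sketch (illustrative only) ---
# Because of the relative import above, run this demo as a module, e.g.
# `python -m <package>.sql_retriever`. The VectorStore constructor is
# assumed here to take no required arguments; adapt to the real signature.
if __name__ == "__main__":
    store = VectorStore()  # assumption: default constructor
    retriever = SQLRetriever(store)
    examples = retriever.retrieve_examples(
        question="How many employees earn more than 50000?",
        table_headers=["employee_id", "name", "salary", "department"],
        top_k=3,
    )
    for ex in examples:
        logger.info(f"{ex.get('final_score', 0.0):.3f}  {ex.get('question', '')}")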