Spaces:
Sleeping
Sleeping
"""
Vector Store for SQL Examples
Handles storage and retrieval of SQL examples using ChromaDB for similarity search.
"""
| import os | |
| import json | |
| import pickle | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from pathlib import Path | |
| import chromadb | |
| from chromadb.config import Settings | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from loguru import logger | |
class VectorStore:
    """Persistent store of SQL examples with similarity search via ChromaDB.

    Each example is stored as one document made of a ``Question: ...`` line
    followed by a ``Table columns: ...`` line, in a ChromaDB collection
    configured for cosine distance. The question, SQL and related fields are
    kept as metadata so search hits can be returned without a second lookup.

    NOTE(review): ``self.embedding_model`` is loaded in ``__init__`` but no
    method of this class passes it to the collection, so the vectors are
    presumably produced by the collection's own embedding function. Confirm
    whether the ``embedding_model`` argument is meant to drive the embeddings.
    """

    # Collection settings shared by initial creation and by clear_collection().
    _COLLECTION_METADATA = {"hnsw:space": "cosine"}

    def __init__(self,
                 persist_directory: str = "./data/vector_store",
                 embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
                 collection_name: str = "sql_examples"):
        """
        Initialize the vector store.

        Args:
            persist_directory: Directory to persist the vector store
                (created if it does not exist)
            embedding_model: Sentence transformer model name; loaded and kept
                on ``self.embedding_model``
            collection_name: Name of the ChromaDB collection
        """
        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        self.embedding_model = SentenceTransformer(embedding_model)
        self.collection_name = collection_name

        # Initialize ChromaDB client
        self.client = chromadb.PersistentClient(
            path=str(self.persist_directory),
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
            ),
        )

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata=dict(self._COLLECTION_METADATA),
        )
        logger.info(f"Vector store initialized at {self.persist_directory}")

    @staticmethod
    def _format_headers(table_headers: Any) -> Any:
        """Return table headers as one comma-separated string.

        A string is returned unchanged, any other iterable is joined with
        ``", "``, and a non-iterable value is returned as-is.
        """
        if isinstance(table_headers, str):
            return table_headers
        try:
            return ", ".join(str(header) for header in table_headers)
        except TypeError:
            return table_headers

    @staticmethod
    def _build_document(question: Any, headers: Any) -> str:
        """Build the text that is embedded for both stored examples and queries."""
        return f"Question: {question}\nTable columns: {headers}"

    @staticmethod
    def _example_from_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a stored metadata record into the public example dict."""
        return {
            "question": metadata["question"],
            "sql": metadata["sql"],
            "table_headers": metadata["table_headers"],
            "difficulty": metadata.get("difficulty", "medium"),
            "category": metadata.get("category", "general"),
        }

    @classmethod
    def _build_records(cls,
                       examples: List[Dict[str, Any]],
                       start_index: int) -> Tuple[List[str], List[str], List[Dict[str, Any]]]:
        """Turn examples into parallel (ids, documents, metadatas) lists.

        IDs are ``example_<n>`` with ``n`` counting up from ``start_index``.
        """
        ids: List[str] = []
        documents: List[str] = []
        metadatas: List[Dict[str, Any]] = []
        for index, example in enumerate(examples, start=start_index):
            question = example["question"]
            headers = cls._format_headers(example["table_headers"])
            ids.append(f"example_{index}")
            documents.append(cls._build_document(question, headers))
            # Store metadata for filtering and retrieval
            metadatas.append({
                "question": question,
                "sql": example["sql"],
                "table_headers": headers,
                "difficulty": example.get("difficulty", "medium"),
                "category": example.get("category", "general"),
                "example_id": index,
            })
        return ids, documents, metadatas

    def add_examples(self, examples: List[Dict[str, Any]]) -> None:
        """
        Add SQL examples to the vector store.

        IDs continue from the current collection size, so a later call
        appends to earlier batches instead of reusing ``example_0``,
        ``example_1``, ... and colliding with records already stored.
        Adding the same examples twice therefore stores them twice.

        Args:
            examples: List of dictionaries with keys: question, sql,
                table_headers, and optionally difficulty and category

        Raises:
            KeyError: If an example lacks question, sql or table_headers
        """
        if not examples:
            return

        start_index = self.collection.count()
        ids, documents, metadatas = self._build_records(examples, start_index)

        # Add to collection
        self.collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
        logger.info(f"Added {len(examples)} examples to vector store")

    def search_similar(self,
                       query: str,
                       table_headers: List[str],
                       top_k: int = 5,
                       similarity_threshold: float = 0.7) -> List[Dict[str, Any]]:
        """
        Search for similar SQL examples.

        Args:
            query: Natural language question
            table_headers: List of table column names (a pre-joined string is
                also accepted, as in add_examples)
            top_k: Number of top results to return
            similarity_threshold: Minimum similarity score (1 - cosine distance)

        Returns:
            List of similar examples with scores, best match first; empty when
            top_k is not positive, the collection is empty, or nothing reaches
            the threshold
        """
        if top_k <= 0:
            return []
        stored = self.collection.count()
        if stored == 0:
            return []

        # Same document layout as add_examples so queries and stored
        # examples are embedded from identically shaped text.
        search_text = self._build_document(query, self._format_headers(table_headers))

        # Ask for extra candidates for filtering, but never more than exist.
        results = self.collection.query(
            query_texts=[search_text],
            n_results=min(top_k * 2, stored),
            include=["metadatas", "distances"],
        )
        metadatas = (results.get("metadatas") or [[]])[0]
        distances = (results.get("distances") or [[]])[0]

        similar_examples = []
        for metadata, distance in zip(metadatas, distances):
            # Convert distance to similarity score (cosine distance -> similarity)
            similarity_score = 1 - distance
            if similarity_score < similarity_threshold:
                continue
            example = self._example_from_metadata(metadata)
            example["similarity_score"] = similarity_score
            similar_examples.append(example)

        # Sort by similarity score and return top_k
        similar_examples.sort(key=lambda item: item["similarity_score"], reverse=True)
        return similar_examples[:top_k]

    def get_example_by_id(self, example_id: str) -> Optional[Dict[str, Any]]:
        """Get a specific example by ID.

        Returns:
            The example dict, or None when the ID is unknown or the lookup
            fails (failures are logged, not raised)
        """
        try:
            result = self.collection.get(ids=[example_id])
            if result["metadatas"]:
                return self._example_from_metadata(result["metadatas"][0])
        except Exception as e:
            logger.error(f"Error retrieving example {example_id}: {e}")
        return None

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about the vector store.

        Returns:
            Dict with total_examples, collection_name and persist_directory,
            or ``{"error": <message>}`` if the count could not be read
        """
        try:
            count = self.collection.count()
            return {
                "total_examples": count,
                "collection_name": self.collection_name,
                "persist_directory": str(self.persist_directory),
            }
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            return {"error": str(e)}

    def clear_collection(self) -> None:
        """Clear all examples by deleting and re-creating the collection.

        Best effort: failures are logged and not raised.
        """
        try:
            self.client.delete_collection(self.collection_name)
            self.collection = self.client.create_collection(
                name=self.collection_name,
                metadata=dict(self._COLLECTION_METADATA),
            )
            logger.info("Collection cleared successfully")
        except Exception as e:
            logger.error(f"Error clearing collection: {e}")

    def export_examples(self, filepath: str) -> None:
        """Export all examples to a JSON file.

        Args:
            filepath: Destination path; written as UTF-8, indented JSON

        Best effort: failures are logged and not raised.
        """
        try:
            results = self.collection.get()
            examples = [
                self._example_from_metadata(metadata)
                for metadata in results["metadatas"]
            ]
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(examples, f, indent=2, ensure_ascii=False)
            logger.info(f"Exported {len(examples)} examples to {filepath}")
        except Exception as e:
            logger.error(f"Error exporting examples: {e}")