"""
Semantic Search Service for Dataset Discovery

Uses Gemini embeddings to find relevant datasets from a query, enabling
scalable discovery across 250+ datasets without context overflow.
"""

import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

try:
    from google import genai
    from google.genai import types  # noqa: F401  (kept from the original module; unused here)
except ImportError:
    # Without the SDK no embedding client can be built; SemanticSearch then
    # uses keyword matching, the same path taken when no API key is set.
    genai = None
    types = None

try:
    from dotenv import load_dotenv
except ImportError:
    # Environment variables set by other means still work without dotenv.
    load_dotenv = None

if load_dotenv is not None:
    load_dotenv()

logger = logging.getLogger(__name__)


class SemanticSearch:
    """
    Embedding-based semantic search for dataset discovery.

    Embeds dataset metadata (name, description, tags, columns) and finds the
    most relevant datasets for a user query using cosine similarity. When no
    embedding client is available (no API key, or the google-genai SDK is not
    installed), queries fall back to keyword matching over the cached texts.

    Every construction returns the same object (see ``__new__``); state is
    initialised only on the first one.
    """

    _instance = None
    EMBEDDINGS_FILE = Path(__file__).parent.parent / "data" / "embeddings.json"
    EMBEDDING_MODEL = "models/text-embedding-004"
    # Column names left out of the embedded text.
    _GENERIC_COLUMNS = ("geom", "geometry", "id", "fid")
    # Only this many leading columns are considered for the embedded text.
    _MAX_COLUMNS = 15

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.initialized = False
        return cls._instance

    def __init__(self):
        if self.initialized:
            return
        self.embeddings: Dict[str, List[float]] = {}
        self.metadata_cache: Dict[str, str] = {}  # table_name -> embedded text

        # Initialize Gemini client
        api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if genai is None:
            self.client = None
            logger.warning(
                "google-genai is not installed. "
                "Semantic search will use fallback keyword matching."
            )
        elif api_key:
            # Pass the key explicitly so the variable selected above is the
            # one actually used, rather than the SDK's own env lookup order.
            self.client = genai.Client(api_key=api_key)
        else:
            self.client = None
            logger.warning(
                "No API key found. Semantic search will use fallback keyword matching."
            )

        self._load_embeddings()
        self.initialized = True

    def _reset_cache(self) -> None:
        """Drop all in-memory embeddings and their source texts."""
        self.embeddings = {}
        self.metadata_cache = {}

    def _load_embeddings(self) -> None:
        """Load cached embeddings from disk, if the cache file exists.

        The in-memory cache is reset to empty (forcing re-embedding) when the
        file cannot be read or parsed, has an unexpected layout, or was
        written for a different ``EMBEDDING_MODEL``. Cache files without a
        stored model name (older format) are accepted as-is.
        """
        if not self.EMBEDDINGS_FILE.exists():
            return
        try:
            with open(self.EMBEDDINGS_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # JSONDecodeError is a ValueError
            logger.error("Failed to load embeddings: %s", e)
            self._reset_cache()
            return

        if not isinstance(data, dict):
            logger.error("Failed to load embeddings: unexpected cache format")
            self._reset_cache()
            return

        cached_model = data.get("model")
        if cached_model is not None and cached_model != self.EMBEDDING_MODEL:
            # Vectors from another model are not comparable with new queries.
            logger.warning(
                "Embedding cache was built with %s, not %s; discarding it.",
                cached_model,
                self.EMBEDDING_MODEL,
            )
            self._reset_cache()
            return

        self.embeddings = data.get("embeddings") or {}
        self.metadata_cache = data.get("metadata") or {}
        logger.info("Loaded %d cached embeddings.", len(self.embeddings))

    def _save_embeddings(self) -> None:
        """Persist the embeddings cache to disk (best effort; errors are logged).

        Data is written to a sibling ``.tmp`` file and then moved over
        ``EMBEDDINGS_FILE`` with ``os.replace``, so an interrupted write never
        leaves a truncated cache. The model name is stored alongside so a
        later model change invalidates the cache on load.
        """
        target = self.EMBEDDINGS_FILE
        tmp_path = target.with_name(target.name + ".tmp")
        try:
            target.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(
                    {
                        "model": self.EMBEDDING_MODEL,
                        "embeddings": self.embeddings,
                        "metadata": self.metadata_cache,
                    },
                    f,
                )
            os.replace(tmp_path, target)
        except (OSError, TypeError, ValueError) as e:
            logger.error("Failed to save embeddings: %s", e)
            return
        logger.info("Saved %d embeddings to cache.", len(self.embeddings))

    def _build_embedding_text(self, table_name: str, metadata: dict) -> str:
        """Build the text representation of a table that gets embedded.

        The text also acts as the cache key: a table is re-embedded whenever
        this text changes, so any change to the format here triggers
        re-embedding of the affected tables.
        """
        parts = [f"Table: {table_name}"]

        # Description (prefer semantic if available)
        desc = metadata.get("semantic_description") or metadata.get("description", "")
        if desc:
            parts.append(f"Description: {desc}")

        tags = metadata.get("tags") or []
        if tags:
            parts.append(f"Tags: {', '.join(map(str, tags))}")

        category = metadata.get("category", "")
        if category:
            parts.append(f"Category: {category}")

        # NOTE: the slice is taken before filtering generic columns, so fewer
        # than _MAX_COLUMNS may remain. Kept as-is because changing it would
        # alter cached texts and force re-embedding.
        columns = metadata.get("columns") or []
        meaningful_cols = [
            c for c in columns[: self._MAX_COLUMNS] if c not in self._GENERIC_COLUMNS
        ]
        if meaningful_cols:
            parts.append(f"Columns: {', '.join(map(str, meaningful_cols))}")

        data_type = metadata.get("data_type", "static")
        parts.append(f"Data type: {data_type}")

        return ". ".join(parts)

    def _embed_text(self, text: str) -> Optional[List[float]]:
        """Return the embedding vector for ``text``, or None if unavailable.

        None is returned when there is no client or when the API call fails
        (the failure is logged).
        """
        if not self.client:
            return None
        try:
            result = self.client.models.embed_content(
                model=self.EMBEDDING_MODEL,
                contents=text,  # Note: 'contents' not 'content'
            )
            return result.embeddings[0].values
        except Exception as e:  # network/API errors: degrade, don't crash
            logger.error("Embedding failed: %s", e)
            return None

    def _cosine_similarity(self, a, b) -> float:
        """Compute the cosine similarity of two vectors.

        Returns 0.0 when either vector has zero norm, or when the vectors
        have different shapes (e.g. a cached vector from another model).
        """
        a_np = np.asarray(a, dtype=float)
        b_np = np.asarray(b, dtype=float)
        if a_np.shape != b_np.shape:
            return 0.0
        norm_a = np.linalg.norm(a_np)
        norm_b = np.linalg.norm(b_np)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a_np, b_np) / (norm_a * norm_b))

    def _is_cached(self, table_name: str, text: str) -> bool:
        """True if ``table_name`` has a stored embedding built from ``text``."""
        return (
            self.metadata_cache.get(table_name) == text
            and table_name in self.embeddings
        )

    def _store_embedding(self, table_name: str, text: str) -> bool:
        """Embed ``text`` and record it for ``table_name``; True on success."""
        embedding = self._embed_text(text)
        if not embedding:
            return False
        self.embeddings[table_name] = embedding
        self.metadata_cache[table_name] = text
        return True

    def embed_table(self, table_name: str, metadata: dict) -> bool:
        """
        Embed a table's metadata for semantic search.

        Returns True if embedding was successful or already cached. The
        result is kept in memory only; ``embed_all_tables`` persists to disk.
        """
        text = self._build_embedding_text(table_name, metadata)
        if self._is_cached(table_name, text):
            return True
        return self._store_embedding(table_name, text)

    def embed_all_tables(self, catalog: Dict[str, dict]) -> int:
        """
        Embed all tables in the catalog.

        Tables whose embedded text is unchanged are skipped. The cache is
        saved to disk when at least one table was newly embedded.

        Returns number of newly embedded tables.
        """
        new_count = 0
        for table_name, metadata in catalog.items():
            text = self._build_embedding_text(table_name, metadata)
            if self._is_cached(table_name, text):
                continue
            if self._store_embedding(table_name, text):
                new_count += 1

        if new_count > 0:
            self._save_embeddings()
            logger.info("Embedded %d new tables.", new_count)

        return new_count

    def search(self, query: str, top_k: int = 15) -> List[Tuple[str, float]]:
        """
        Find the most relevant tables for a query.

        Returns list of (table_name, similarity_score) tuples, sorted by
        relevance (ties keep cache order). Falls back to keyword matching
        when the query cannot be embedded; returns [] when ``top_k`` <= 0 or
        nothing has been embedded yet.
        """
        if top_k <= 0:
            return []
        if not self.embeddings:
            logger.warning("No embeddings available. Returning empty results.")
            return []

        query_embedding = self._embed_text(query)
        if not query_embedding:
            return self._keyword_fallback(query, top_k)

        # Convert the query once rather than once per table.
        query_vec = np.asarray(query_embedding, dtype=float)
        scores = sorted(
            (
                (table_name, self._cosine_similarity(query_vec, table_embedding))
                for table_name, table_embedding in self.embeddings.items()
            ),
            key=lambda item: item[1],
            reverse=True,
        )
        return scores[:top_k]

    def search_table_names(self, query: str, top_k: int = 15) -> List[str]:
        """Convenience method that returns just table names."""
        return [name for name, _ in self.search(query, top_k)]

    def _keyword_fallback(self, query: str, top_k: int) -> List[Tuple[str, float]]:
        """
        Simple keyword matching fallback when embeddings unavailable.

        Scores each cached text by the fraction of query terms it contains
        (substring match, case-insensitive); tables with no match are omitted.
        """
        query_terms = query.lower().split()
        if not query_terms:
            return []

        scores = []
        for table_name, text in self.metadata_cache.items():
            text_lower = text.lower()
            hits = sum(1 for term in query_terms if term in text_lower)
            if hits > 0:
                scores.append((table_name, hits / len(query_terms)))

        scores.sort(key=lambda item: item[1], reverse=True)
        return scores[:top_k]

    def get_stats(self) -> dict:
        """Return statistics about the semantic search index."""
        return {
            "total_tables": len(self.embeddings),
            "cache_file": str(self.EMBEDDINGS_FILE),
            "cache_exists": self.EMBEDDINGS_FILE.exists(),
            "client_available": self.client is not None,
        }


# Singleton accessor
_semantic_search: Optional[SemanticSearch] = None


def get_semantic_search() -> SemanticSearch:
    """Get the singleton semantic search instance."""
    global _semantic_search
    if _semantic_search is None:
        _semantic_search = SemanticSearch()
    return _semantic_search