|
|
""" |
|
|
Semantic Search Service for Dataset Discovery |
|
|
|
|
|
Uses Gemini embeddings to find relevant datasets from a query, |
|
|
enabling scalable discovery across 250+ datasets without context overflow. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
from google import genai |
|
|
from google.genai import types |
|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
# Load variables from a local .env file so GEMINI_API_KEY / GOOGLE_API_KEY
# can be supplied without exporting them in the shell.
load_dotenv()

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class SemanticSearch:
    """
    Embedding-based semantic search for dataset discovery.

    Embeds dataset metadata (name, description, tags, columns) and
    finds the most relevant datasets for a user query using cosine similarity.

    Implemented as a lazy singleton: every ``SemanticSearch()`` call returns
    the same object, so the embedding cache is shared process-wide.
    """

    _instance = None
    # On-disk cache holding both the embedding vectors and the exact text
    # each vector was computed from (used to detect stale entries).
    EMBEDDINGS_FILE = Path(__file__).parent.parent / "data" / "embeddings.json"
    EMBEDDING_MODEL = "models/text-embedding-004"

    def __new__(cls):
        # Lazy singleton; `initialized` guards against __init__ re-running
        # on subsequent SemanticSearch() calls.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.initialized = False
        return cls._instance

    def __init__(self):
        # __init__ executes on every SemanticSearch() call; bail out after
        # the first successful initialization.
        if self.initialized:
            return

        # table_name -> embedding vector
        self.embeddings: Dict[str, List[float]] = {}
        # table_name -> exact text that was embedded (staleness check and
        # corpus for the keyword fallback)
        self.metadata_cache: Dict[str, str] = {}

        api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if api_key:
            # Pass the detected key explicitly instead of relying on the
            # SDK's own env-var discovery (older google-genai versions only
            # look at GOOGLE_API_KEY, which would ignore GEMINI_API_KEY).
            self.client = genai.Client(api_key=api_key)
        else:
            self.client = None
            logger.warning("No API key found. Semantic search will use fallback keyword matching.")

        self._load_embeddings()
        self.initialized = True

    def _load_embeddings(self) -> None:
        """Load cached embeddings from disk; reset the cache on any failure."""
        if not self.EMBEDDINGS_FILE.exists():
            return
        try:
            with open(self.EMBEDDINGS_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
            self.embeddings = data.get("embeddings", {})
            self.metadata_cache = data.get("metadata", {})
            logger.info("Loaded %d cached embeddings.", len(self.embeddings))
        except Exception as e:
            # A corrupt/unreadable cache is not fatal — start from scratch.
            logger.error("Failed to load embeddings: %s", e)
            self.embeddings = {}
            self.metadata_cache = {}

    def _save_embeddings(self) -> None:
        """Persist the embedding and metadata caches to disk (best effort)."""
        try:
            self.EMBEDDINGS_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(self.EMBEDDINGS_FILE, 'w', encoding='utf-8') as f:
                json.dump({
                    "embeddings": self.embeddings,
                    "metadata": self.metadata_cache
                }, f)
            logger.info("Saved %d embeddings to cache.", len(self.embeddings))
        except Exception as e:
            logger.error("Failed to save embeddings: %s", e)

    def _build_embedding_text(self, table_name: str, metadata: dict) -> str:
        """Build the text representation of a table that gets embedded.

        Concatenates name, description, tags, category, a sample of
        meaningful column names, and the data type into a single string.
        """
        parts = [f"Table: {table_name}"]

        # Prefer the curated semantic description when present.
        desc = metadata.get("semantic_description") or metadata.get("description", "")
        if desc:
            parts.append(f"Description: {desc}")

        tags = metadata.get("tags", [])
        if tags:
            parts.append(f"Tags: {', '.join(tags)}")

        category = metadata.get("category", "")
        if category:
            parts.append(f"Category: {category}")

        columns = metadata.get("columns", [])
        # Cap at 15 columns and drop generic geometry/id columns that carry
        # no semantic signal.
        meaningful_cols = [c for c in columns[:15] if c not in ('geom', 'geometry', 'id', 'fid')]
        if meaningful_cols:
            parts.append(f"Columns: {', '.join(meaningful_cols)}")

        data_type = metadata.get("data_type", "static")
        parts.append(f"Data type: {data_type}")

        return ". ".join(parts)

    def _embed_text(self, text: str) -> Optional[List[float]]:
        """Return the embedding vector for `text`, or None when no client is
        configured or the API call fails."""
        if not self.client:
            return None

        try:
            result = self.client.models.embed_content(
                model=self.EMBEDDING_MODEL,
                contents=text
            )
            return result.embeddings[0].values
        except Exception as e:
            logger.error("Embedding failed: %s", e)
            return None

    def _cosine_similarity(self, a: List[float], b: List[float]) -> float:
        """Compute cosine similarity between two vectors (0.0 for zero vectors)."""
        a_np = np.array(a)
        b_np = np.array(b)

        norm_a = np.linalg.norm(a_np)
        norm_b = np.linalg.norm(b_np)

        # Guard against division by zero for degenerate (all-zero) vectors.
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(np.dot(a_np, b_np) / (norm_a * norm_b))

    def _is_cached(self, table_name: str, text: str) -> bool:
        """True if `table_name` already has an embedding built from exactly
        this text (i.e. the cached vector is not stale)."""
        return self.metadata_cache.get(table_name) == text

    def embed_table(self, table_name: str, metadata: dict) -> bool:
        """
        Embed a table's metadata for semantic search.

        Returns True if embedding was successful or already cached.
        """
        text = self._build_embedding_text(table_name, metadata)

        # Skip the API call when the cached vector was built from identical text.
        if self._is_cached(table_name, text):
            return True

        embedding = self._embed_text(text)
        if embedding:
            self.embeddings[table_name] = embedding
            self.metadata_cache[table_name] = text
            return True

        return False

    def embed_all_tables(self, catalog: Dict[str, dict]) -> int:
        """
        Embed all tables in the catalog, saving the cache if anything changed.

        Returns number of newly embedded tables.
        """
        new_count = 0

        for table_name, metadata in catalog.items():
            # Pre-check the cache so already-embedded tables are not counted
            # as new (embed_table returns True for cached entries as well).
            text = self._build_embedding_text(table_name, metadata)
            if self._is_cached(table_name, text):
                continue

            if self.embed_table(table_name, metadata):
                new_count += 1

        if new_count > 0:
            self._save_embeddings()
            logger.info("Embedded %d new tables.", new_count)

        return new_count

    def search(self, query: str, top_k: int = 15) -> List[Tuple[str, float]]:
        """
        Find the most relevant tables for a query.

        Returns list of (table_name, similarity_score) tuples, sorted by
        relevance. Falls back to keyword matching when the query cannot be
        embedded (no client / API error).
        """
        if not self.embeddings:
            logger.warning("No embeddings available. Returning empty results.")
            return []

        query_embedding = self._embed_text(query)
        if not query_embedding:
            return self._keyword_fallback(query, top_k)

        scores = [
            (table_name, self._cosine_similarity(query_embedding, table_embedding))
            for table_name, table_embedding in self.embeddings.items()
        ]
        scores.sort(key=lambda item: item[1], reverse=True)

        return scores[:top_k]

    def search_table_names(self, query: str, top_k: int = 15) -> List[str]:
        """Convenience method that returns just table names."""
        return [name for name, _ in self.search(query, top_k)]

    def _keyword_fallback(self, query: str, top_k: int) -> List[Tuple[str, float]]:
        """
        Simple keyword matching fallback when embeddings are unavailable.

        Scores each table by the fraction of query terms appearing in its
        cached metadata text; tables with no matches are excluded.
        """
        query_terms = query.lower().split()
        scores = []

        for table_name, text in self.metadata_cache.items():
            text_lower = text.lower()
            score = sum(1 for term in query_terms if term in text_lower)
            if score > 0:
                scores.append((table_name, score / len(query_terms)))

        scores.sort(key=lambda item: item[1], reverse=True)
        return scores[:top_k]

    def get_stats(self) -> dict:
        """Return statistics about the semantic search index."""
        return {
            "total_tables": len(self.embeddings),
            "cache_file": str(self.EMBEDDINGS_FILE),
            "cache_exists": self.EMBEDDINGS_FILE.exists(),
            "client_available": self.client is not None
        }
|
|
|
|
|
|
|
|
|
|
|
_semantic_search: Optional[SemanticSearch] = None |
|
|
|
|
|
|
|
|
def get_semantic_search() -> SemanticSearch:
    """Return the process-wide SemanticSearch instance, creating it lazily."""
    global _semantic_search
    if _semantic_search is not None:
        return _semantic_search
    _semantic_search = SemanticSearch()
    return _semantic_search
|
|
|