| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import List, Tuple, Dict, Any | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from sentence_transformers import SentenceTransformer | |
| from sentence_transformers.util import cos_sim | |
# Embedding model used for all query and corpus encodings.
_MODEL_ID = "mabosaimi/bge-m3-text2tables"

model: SentenceTransformer = SentenceTransformer(_MODEL_ID)

# Fixed metadata corpus: texts and their precomputed embeddings are fetched
# from the same hub repo as the model so they stay in sync with it.
_corpus_text_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_texts.json")
with open(_corpus_text_file, "r", encoding="utf-8") as _f:
    corpus_texts: List[str] = json.load(_f)

_corpus_emb_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_embeddings.pt")
# weights_only=True: the file is expected to contain a plain tensor, so
# restrict unpickling to tensor data instead of arbitrary Python objects
# (torch.load on a downloaded pickle is otherwise a code-execution vector).
corpus_embeddings: torch.Tensor = torch.load(
    _corpus_emb_file, map_location="cpu", weights_only=True
)

# Optional local table schemas; a missing file degrades to an empty list
# rather than failing at import time.
_schemas_PATH = Path(__file__).parent / "schemas.json"
if _schemas_PATH.exists():
    with open(_schemas_PATH, "r", encoding="utf-8") as _cf:
        schemas: List[Dict[str, Any]] = json.load(_cf)
else:
    schemas = []
def get_model_id() -> str:
    """Report which embedding model backs this service.

    Exposes only the model identifier so health/diagnostics endpoints can
    surface basic service info without leaking low-level model details.
    """
    return _MODEL_ID
def get_corpus_size() -> int:
    """Report how many entries the fixed metadata corpus contains."""
    size = len(corpus_texts)
    return size
def preprocess_text(query: str) -> str:
    """Normalize a raw query string prior to embedding.

    Args:
        query: Natural language string to clean up.

    Returns:
        The query with leading and trailing whitespace removed.
    """
    cleaned = query.strip()
    return cleaned
def encode_text(query: str) -> torch.Tensor:
    """Embed a natural language query with the module-level model.

    Args:
        query: Natural language string to embed.

    Returns:
        A 1 x D torch.Tensor holding the normalized query embedding.
    """
    cleaned = preprocess_text(query)
    embedding = model.encode(
        cleaned,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return embedding
def semantic_search(query: str, top_k: int = 5) -> List[Tuple[float, str, int]]:
    """Rank corpus entries by cosine similarity to a query.

    Args:
        query: Natural language search string.
        top_k: Maximum number of results to return; clamped to the range
            [1, corpus size].

    Returns:
        A list of (score, text, index) tuples sorted by descending similarity,
        where score is the float cosine similarity, text is the matched corpus
        entry, and index is its stable integer position in the corpus.
    """
    query_embedding = encode_text(query)
    # cos_sim yields a 1 x N similarity matrix; take its single row.
    scores = cos_sim(query_embedding, corpus_embeddings)[0]
    k = min(max(top_k, 1), len(corpus_texts))
    values, indices = torch.topk(scores, k=k)
    # .tolist() converts both tensors to plain Python floats/ints in one pass,
    # replacing the per-element indexing of a range(len(...)) loop.
    return [
        (score, corpus_texts[idx], idx)
        for score, idx in zip(values.tolist(), indices.tolist())
    ]
def get_schemas(include_columns: bool = False) -> List[Dict[str, Any]]:
    """Return the locally loaded table schemas.

    Args:
        include_columns: When True, return each schema in its original full
            structure (including column metadata); otherwise return a minimal
            view with only the table name and description.

    Returns:
        A list of table dicts. With include_columns=False each dict contains
        {"table", "description"}; with True, the original entries are returned.
    """
    if not schemas:
        return []
    if include_columns:
        return schemas
    minimal: List[Dict[str, Any]] = []
    for entry in schemas:
        minimal.append(
            {"table": entry["table"], "description": entry.get("description", "")}
        )
    return minimal