File size: 3,533 Bytes
e0a827b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from __future__ import annotations
import json
from pathlib import Path
from typing import List, Tuple, Dict, Any
import torch
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
# Hugging Face repo that hosts both the fine-tuned embedding model and the
# pre-computed corpus artifacts (texts + embeddings) used for retrieval.
_MODEL_ID = "mabosaimi/bge-m3-text2tables"
# Loaded eagerly at import time; first run downloads the model (network I/O).
model: SentenceTransformer = SentenceTransformer(_MODEL_ID)
# Corpus texts: the human-readable entries that search results are drawn from.
_corpus_text_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_texts.json")
with open(_corpus_text_file, "r", encoding="utf-8") as _f:
    corpus_texts: List[str] = json.load(_f)
# Pre-computed embeddings aligned index-for-index with `corpus_texts`.
# NOTE(review): torch.load on a pickled file can execute arbitrary code;
# this artifact comes from the model repo — confirm it is trusted, or
# consider passing weights_only=True if the installed torch supports it.
_corpus_emb_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_embeddings.pt")
corpus_embeddings: torch.Tensor = torch.load(_corpus_emb_file, map_location="cpu")
# Optional local table-schema metadata shipped next to this module; the
# service degrades to an empty schema list when the file is absent.
_schemas_PATH = Path(__file__).parent / "schemas.json"
if _schemas_PATH.exists():
    with open(_schemas_PATH, "r", encoding="utf-8") as _cf:
        schemas: List[Dict[str, Any]] = json.load(_cf)
else:
    schemas = []
def get_model_id() -> str:
    """Expose the identifier of the embedding model in use.

    Health/diagnostics endpoints can surface basic service info through this
    accessor without exposing low-level model internals to API consumers.
    """
    return _MODEL_ID
def get_corpus_size() -> int:
    """Report how many entries the fixed metadata corpus contains."""
    return len(corpus_texts)
def preprocess_text(query: str) -> str:
    """Normalize a natural language string by trimming surrounding whitespace.

    Inputs:
    - query: Natural language string to be preprocessed.
    Returns:
    - The preprocessed string.
    """
    cleaned = query.strip()
    return cleaned
def encode_text(query: str) -> torch.Tensor:
    """Turn a natural language query into a normalized embedding tensor.

    Inputs:
    - query: Natural language string to be embedded.
    Returns:
    - A 1 x D torch.Tensor representing the normalized embedding of the query.
    """
    cleaned_query = preprocess_text(query)
    embedding = model.encode(
        cleaned_query,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return embedding
def semantic_search(query: str, top_k: int = 5) -> List[Tuple[float, str, int]]:
    """Rank corpus entries by semantic similarity to a query.

    Inputs:
    - query: Natural language search string.
    - top_k: Maximum number of results to return (capped at corpus size).
    Returns:
    - A list of tuples (score, text, index) sorted by descending similarity,
      where:
        - score is a float cosine similarity.
        - text is the matched corpus entry.
        - index is the integer position in the corpus (stable identifier).
    """
    query_embedding = encode_text(query)
    similarity = cos_sim(query_embedding, corpus_embeddings)[0]
    # Clamp the requested count to [1, corpus size] before the top-k lookup.
    limit = min(max(top_k, 1), len(corpus_texts))
    top_scores, top_indices = torch.topk(similarity, k=limit)
    results: List[Tuple[float, str, int]] = []
    for score, raw_idx in zip(top_scores.tolist(), top_indices.tolist()):
        position = int(raw_idx)
        results.append((float(score), corpus_texts[position], position))
    return results
def get_schemas(include_columns: bool = False) -> List[Dict[str, Any]]:
    """Return the local schemas.

    Inputs:
    - include_columns: When True, include full column metadata; otherwise
      return a minimal view with table name and description only.
    Returns:
    - List of table dicts. If include_columns is False, each dict contains
      {"table", "description"}. If True, it includes the original structure.
    """
    if not schemas:
        return []
    if include_columns:
        return schemas
    # Project each table entry down to its name and (optional) description.
    minimal_view: List[Dict[str, Any]] = []
    for entry in schemas:
        minimal_view.append(
            {"table": entry["table"], "description": entry.get("description", "")}
        )
    return minimal_view
|