File size: 3,533 Bytes
e0a827b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations

import json
from pathlib import Path
from typing import List, Tuple, Dict, Any

import torch
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

_MODEL_ID = "mabosaimi/bge-m3-text2tables"

# Loaded eagerly at import time so the first request pays no warm-up cost.
model: SentenceTransformer = SentenceTransformer(_MODEL_ID)

# Fixed corpus of metadata texts shipped alongside the model on the Hub;
# corpus_texts[i] is the text whose embedding is corpus_embeddings[i].
_corpus_text_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_texts.json")
with open(_corpus_text_file, "r", encoding="utf-8") as _f:
    corpus_texts: List[str] = json.load(_f)

# SECURITY: the .pt file comes from a remote hub. weights_only=True restricts
# torch.load to tensors/containers and blocks arbitrary-code execution through
# the pickle payload (the file is a plain tensor, so this is lossless here).
_corpus_emb_file = hf_hub_download(repo_id=_MODEL_ID, filename="corpus_embeddings.pt")
corpus_embeddings: torch.Tensor = torch.load(
    _corpus_emb_file, map_location="cpu", weights_only=True
)

# Optional local table schemas; a missing file degrades to an empty list so
# the service still starts without it.
_schemas_PATH = Path(__file__).parent / "schemas.json"
if _schemas_PATH.exists():
    with open(_schemas_PATH, "r", encoding="utf-8") as _cf:
        schemas: List[Dict[str, Any]] = json.load(_cf)
else:
    schemas = []


def get_model_id() -> str:
    """Expose the embedding model identifier used by this service.

    Keeps low-level model details behind a function so health/diagnostics
    endpoints can report basic service info without importing internals.
    """
    model_id = _MODEL_ID
    return model_id


def get_corpus_size() -> int:
    """Report how many entries the fixed metadata corpus contains."""
    corpus = corpus_texts
    return len(corpus)


def preprocess_text(query: str) -> str:
    """Normalize a natural-language query before embedding.

    Inputs:
    - query: Raw natural language string.

    Returns:
    - The query with leading and trailing whitespace removed; interior
      whitespace is preserved untouched.
    """
    cleaned = query.strip()
    return cleaned


def encode_text(query: str) -> torch.Tensor:
    """Turn a natural-language query into a normalized embedding.

    Inputs:
    - query: Natural language string to embed; it is whitespace-stripped
      via preprocess_text before encoding.

    Returns:
    - A torch.Tensor holding the L2-normalized embedding of the query
      (1 x D, as produced by the model's encode).
    """
    cleaned = preprocess_text(query)
    embedding = model.encode(
        cleaned,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return embedding


def semantic_search(query: str, top_k: int = 5) -> List[Tuple[float, str, int]]:
    """Rank the stored corpus by cosine similarity to a query.

    Inputs:
    - query: Natural language search string.
    - top_k: Maximum number of results to return; clamped to at least 1
      and at most the corpus size.

    Returns:
    - Tuples (score, text, index) in descending-similarity order, where
      score is the float cosine similarity, text the matched corpus entry,
      and index its stable integer position in the corpus.
    """
    embedding = encode_text(query)
    similarities = cos_sim(embedding, corpus_embeddings)[0]

    # Clamp the requested count into the valid [1, corpus size] range.
    limit = min(max(top_k, 1), len(corpus_texts))
    top_scores, top_indices = torch.topk(similarities, k=limit)

    results: List[Tuple[float, str, int]] = []
    for score, idx in zip(top_scores, top_indices):
        position = int(idx)
        results.append((float(score), corpus_texts[position], position))
    return results


def get_schemas(include_columns: bool = False) -> List[Dict[str, Any]]:
    """Return the locally loaded table schemas.

    Inputs:
    - include_columns: When True, return the full schema dicts as loaded
      from disk; otherwise return a trimmed view per table.

    Returns:
    - A list of table dicts. The trimmed view contains only
      {"table", "description"} (description defaults to ""); the full view
      is the original loaded structure. Empty list when no schemas exist.
    """
    if not schemas:
        return []
    if not include_columns:
        minimal = []
        for entry in schemas:
            minimal.append(
                {"table": entry["table"], "description": entry.get("description", "")}
            )
        return minimal
    return schemas