| from sentence_transformers import SentenceTransformer |
| from torch.nn.functional import cosine_similarity |
| import torch |
|
|
class SQLMetadataRetriever:
    """Rank stored schema descriptions by similarity to a free-text query.

    Documents are embedded once with the ``all-MiniLM-L6-v2`` sentence
    transformer in :meth:`add_documents`; :meth:`retrieve` embeds the query
    and returns the documents with the highest cosine similarity.
    """

    def __init__(self):
        """Load the embedding model and start with no indexed documents."""
        self.model = SentenceTransformer("all-MiniLM-L6-v2")
        # Indexed documents, positionally aligned with the rows of
        # ``self.embeddings``.
        self.docs = []
        # Tensor of document embeddings, or None while nothing is indexed.
        self.embeddings = None

    def add_documents(self, docs):
        """Store and embed schema documents, replacing any previous ones.

        Args:
            docs: An iterable of document strings. A single string is
                treated as one document rather than as a sequence of
                characters.
        """
        if isinstance(docs, str):
            docs = [docs]
        # Snapshot the input: this accepts one-shot iterables and keeps the
        # stored documents aligned with their embeddings even if the caller
        # mutates its own list afterwards.
        docs = list(docs)

        if not docs:
            # Nothing to encode; reset to the "no documents" state so that
            # retrieve() reports the problem with its usual ValueError.
            self.docs = []
            self.embeddings = None
            return

        self.docs = docs
        self.embeddings = self.model.encode(docs, convert_to_tensor=True)

    def retrieve(self, query, top_k=1):
        """Return the stored documents most similar to ``query``.

        Args:
            query: The text to match against the stored documents.
            top_k: Maximum number of documents to return. Values larger
                than the number of stored documents are clamped to it.

        Returns:
            A list of at most ``top_k`` documents, best match first.

        Raises:
            ValueError: If no documents have been indexed, or if ``top_k``
                is negative.
        """
        # Validate before touching the model so misuse fails fast without
        # paying for a query encoding.
        if self.embeddings is None or self.embeddings.shape[0] == 0:
            raise ValueError("No embeddings found. Did you call add_documents()?")
        if top_k < 0:
            raise ValueError("top_k must be non-negative")

        top_k = min(top_k, self.embeddings.shape[0])
        if top_k == 0:
            return []

        query_embedding = self.model.encode(query, convert_to_tensor=True)

        # Repeat the query vector once per document row so both operands
        # have the same shape, then score every document in one call.
        query_expanded = query_embedding.unsqueeze(0).expand_as(self.embeddings)
        scores = cosine_similarity(query_expanded, self.embeddings, dim=1)

        top_indices = torch.topk(scores, top_k).indices.tolist()
        return [self.docs[i] for i in top_indices]
|
|
|
|
| |
if __name__ == "__main__":
    # Small demo: index two table descriptions, then look up the one that
    # best matches a natural-language question.
    team_table_doc = "Table team: columns are id (Unique team identifier), full_name (Full team name, e.g., 'Los Angeles Lakers'), abbreviation (3-letter team code, e.g., 'LAL'), city, state, year_founded."
    game_table_doc = "Table game: columns are game_date (Date of the game), team_id_home, team_id_away (Unique IDs of home and away teams), team_name_home, team_name_away (Full names of the teams), pts_home, pts_away (Points scored), wl_home (W/L result), reb_home, reb_away (Total rebounds), ast_home, ast_away (Total assists), fgm_home, fg_pct_home (Field goals), fg3m_home (Three-pointers), ftm_home (Free throws), tov_home (Turnovers), and other game-related statistics."

    schema_retriever = SQLMetadataRetriever()
    schema_retriever.add_documents([team_table_doc, game_table_doc])

    user_question = "What is the most assists by the Celtics in a home game?"
    matches = schema_retriever.retrieve(user_question, top_k=1)
    best_match = matches[0]
    print("Top match:", best_match)
|
|