text2sql_backend / src /retrieval.py
LightRT's picture
Project Completion Commit
52adb86
raw
history blame
2.14 kB
import os
from functools import lru_cache

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client import models
from fastembed import TextEmbedding, SparseTextEmbedding
from langchain_community.utilities import SQLDatabase
# Load variables from a .env file (python-dotenv) before the lookups below.
load_dotenv()

# Qdrant credentials and endpoint; each is None when its variable is unset.
qdrant_api = os.environ.get("QDRANT_API_KEY")
qdrant_url = os.environ.get("QDRANT_URL")

# Name of the Qdrant collection that retrieve() searches.
COLLECTION_NAME = "Text2SQL"
@lru_cache(maxsize=1)
def _dense_model() -> "TextEmbedding":
    """Return the dense embedding model, building it only on first use."""
    return TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")


@lru_cache(maxsize=1)
def _sparse_model() -> "SparseTextEmbedding":
    """Return the BM25 sparse embedding model, building it only on first use."""
    return SparseTextEmbedding(model_name="Qdrant/bm25")


def retrieve(user_id: str, query: str, connection_url: str) -> str:
    """Return the dialect and table schemas most relevant to ``query``.

    The query is embedded with a dense and a sparse model, both searches are
    run against ``COLLECTION_NAME`` restricted to points whose ``user_id``
    payload field equals ``user_id``, and the two result lists are fused with
    reciprocal rank fusion. The distinct ``table_name`` payload values of the
    fused hits (in ranking order) are then looked up in the database at
    ``connection_url``.

    Args:
        user_id: Value the ``user_id`` payload field must match.
        query: Natural-language question to search with.
        connection_url: Database URI passed to ``SQLDatabase.from_uri``.

    Returns:
        A string of the form ``"Dialect : <dialect>\\n <table info>"``.

    Raises:
        Errors raised by the Qdrant client or by the database layer propagate
        unchanged.
    """
    # The embedding models are cached by the helpers above so repeated calls
    # do not reload them.
    # NOTE(review): fastembed also exposes query_embed() for query-side
    # encoding - confirm whether it should be used here instead of embed().
    dense_embedding = next(iter(_dense_model().embed([query])))
    sparse_embedding = next(iter(_sparse_model().embed([query])))

    # Hand Qdrant plain Python lists rather than NumPy arrays.
    dense_query_vector = dense_embedding.tolist()
    sparse_query_vector = models.SparseVector(
        indices=sparse_embedding.indices.tolist(),
        values=sparse_embedding.values.tolist(),
    )

    user_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="user_id",
                match=models.MatchValue(value=user_id),
            )
        ]
    )

    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)
    try:
        results = client.query_points(
            collection_name=COLLECTION_NAME,
            prefetch=[
                models.Prefetch(
                    query=dense_query_vector,
                    using="dense",
                    limit=10,
                    filter=user_filter,
                ),
                models.Prefetch(
                    query=sparse_query_vector,
                    using="sparse",
                    limit=10,
                    filter=user_filter,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=10,
        )
    finally:
        # Release the client's connections even when the query fails.
        client.close()

    # Points without a payload or without a table name cannot contribute a
    # table, so they are skipped instead of raising.
    candidate_names = (
        (point.payload or {}).get("table_name") for point in results.points
    )
    # dict.fromkeys removes duplicates while keeping the ranking order.
    tables = list(dict.fromkeys(name for name in candidate_names if name))

    # sample_rows_in_table_info=0 requests schema text without sample rows.
    # NOTE(review): the SQLDatabase (and its engine) is created per call and
    # never disposed - consider reusing one instance per connection_url.
    db = SQLDatabase.from_uri(connection_url, sample_rows_in_table_info=0)
    return f"Dialect : {db.dialect}\n {db.get_table_info(table_names=tables)}"