import os
from functools import lru_cache
from typing import List

from dotenv import load_dotenv
from fastembed import SparseTextEmbedding, TextEmbedding
from langchain_community.utilities import SQLDatabase
from qdrant_client import QdrantClient, models

load_dotenv()

qdrant_api = os.getenv("QDRANT_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")
COLLECTION_NAME = "Text2SQL"

# Embedding models used to encode the query. Presumably these must be the same
# models that were used to index COLLECTION_NAME — verify against the indexer.
DENSE_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
SPARSE_MODEL_NAME = "Qdrant/bm25"
# Candidates fetched by each prefetch branch, and size of the fused result.
_TOP_K = 10


@lru_cache(maxsize=1)
def _get_qdrant_client() -> QdrantClient:
    """Return a Qdrant client created on first use and reused afterwards."""
    # Reusing one client avoids opening (and leaking) a new connection per call.
    return QdrantClient(api_key=qdrant_api, url=qdrant_url)


@lru_cache(maxsize=1)
def _get_dense_model() -> TextEmbedding:
    """Load the dense embedding model once and reuse it across calls."""
    # Constructing the model loads its weights; doing that per query is costly.
    return TextEmbedding(model_name=DENSE_MODEL_NAME)


@lru_cache(maxsize=1)
def _get_sparse_model() -> SparseTextEmbedding:
    """Load the BM25 sparse embedding model once and reuse it across calls."""
    return SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)


def _search_table_names(user_id: str, query: str) -> List[str]:
    """Return distinct table names relevant to *query*, restricted to *user_id*.

    Runs a hybrid search on COLLECTION_NAME: a dense ("dense" vector) and a
    sparse BM25 ("sparse" vector) prefetch, each filtered to points whose
    ``user_id`` payload equals *user_id*, fused with Reciprocal Rank Fusion.
    Names are returned in fused-rank order with duplicates removed.
    """
    # embed() yields one embedding per input document; we pass exactly one.
    dense_query_vector = next(iter(_get_dense_model().embed([query])))
    sparse_query = next(iter(_get_sparse_model().embed([query])))
    sparse_query_vector = models.SparseVector(
        indices=sparse_query.indices,
        values=sparse_query.values,
    )

    # Applied to both prefetch branches so one user never sees another's tables.
    user_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="user_id",
                match=models.MatchValue(value=user_id),
            )
        ]
    )

    results = _get_qdrant_client().query_points(
        collection_name=COLLECTION_NAME,
        prefetch=[
            models.Prefetch(
                query=dense_query_vector,
                using="dense",
                limit=_TOP_K,
                filter=user_filter,
            ),
            models.Prefetch(
                query=sparse_query_vector,
                using="sparse",
                limit=_TOP_K,
                filter=user_filter,
            ),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=_TOP_K,
    )

    # Several points may reference the same table; dict.fromkeys de-duplicates
    # while keeping the first (highest-ranked) occurrence order.
    return list(
        dict.fromkeys(point.payload["table_name"] for point in results.points)
    )


def retrieve(user_id: str, query: str, connection_url: str) -> str:
    """Build a schema description of the database tables relevant to *query*.

    Args:
        user_id: Only indexed points whose ``user_id`` payload equals this
            value are searched.
        query: Natural-language question used for the hybrid search.
        connection_url: Database URI passed to ``SQLDatabase.from_uri``.

    Returns:
        ``"Dialect : <dialect>"`` followed by a newline, a single space, and
        ``SQLDatabase.get_table_info`` output for the retrieved tables
        (with no sample rows).

    Raises:
        Presumably ``ValueError`` from ``get_table_info`` if a retrieved table
        name is not present in the database — confirm for the installed
        langchain version.
    """
    tables = _search_table_names(user_id, query)

    # NOTE(review): a new SQLDatabase (and underlying engine) is built on every
    # call; consider caching per connection_url if this path is hot.
    db = SQLDatabase.from_uri(connection_url, sample_rows_in_table_info=0)
    dialect = db.dialect
    return f"Dialect : {dialect}\n {db.get_table_info(table_names=tables)}"