Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| from qdrant_client import QdrantClient | |
| from qdrant_client import models | |
| from fastembed import TextEmbedding, SparseTextEmbedding | |
| from langchain_community.utilities import SQLDatabase | |
# Load variables from a .env file (python-dotenv) before the environment
# is read below, so the order of these statements matters.
load_dotenv()

# Qdrant connection settings; each is None when its variable is not set.
qdrant_url = os.environ.get("QDRANT_URL")
qdrant_api = os.environ.get("QDRANT_API_KEY")

# Name of the Qdrant collection that retrieve() queries.
COLLECTION_NAME = "Text2SQL"
# Embedding models are loaded once and reused: constructing them inside
# retrieve() would repeat the model initialisation on every request.
_EMBEDDING_MODELS = {}


def _get_embedding_models():
    """Return the ``(dense, sparse)`` embedding models, loading them on first use."""
    if "models" not in _EMBEDDING_MODELS:
        # Both models are built before the cache entry is written, so a
        # failure in either constructor never leaves a half-filled cache.
        _EMBEDDING_MODELS["models"] = (
            TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2"),
            SparseTextEmbedding(model_name="Qdrant/bm25"),
        )
    return _EMBEDDING_MODELS["models"]


def _unique_table_names(points):
    """Collect distinct ``table_name`` payload values, keeping first-seen order."""
    # A point may carry no payload, or a payload without ``table_name``;
    # such points are skipped instead of aborting the whole lookup.
    names = ((point.payload or {}).get("table_name") for point in points)
    return list(dict.fromkeys(name for name in names if name))


def retrieve(user_id: str, query: str, connection_url: str) -> str:
    """Return the SQL dialect and schema of the tables most relevant to *query*.

    The query text is embedded with a dense and a sparse model. Both
    vectors are sent to the ``COLLECTION_NAME`` Qdrant collection as
    prefetch queries (10 candidates each), restricted to points whose
    ``user_id`` payload field equals *user_id*. The two candidate lists are
    fused with reciprocal rank fusion and the top 10 points are kept. The
    distinct ``table_name`` payload values of those points are then looked
    up in the database at *connection_url*.

    Args:
        user_id: Value matched against the ``user_id`` payload field.
        query: Natural-language text to embed and search with.
        connection_url: Database URI passed to ``SQLDatabase.from_uri``.

    Returns:
        A string of the form ``"Dialect : <dialect>\\n <table info>"``,
        where the table info covers only the retrieved tables and contains
        no sample rows.
    """
    dense_model, sparse_model = _get_embedding_models()
    dense_query_vector = next(iter(dense_model.embed([query])))
    sparse_query = next(iter(sparse_model.embed([query])))
    sparse_query_vector = models.SparseVector(
        indices=sparse_query.indices,
        values=sparse_query.values,
    )

    # The same per-user filter is applied to both prefetch branches so that
    # neither branch can surface another user's points.
    user_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="user_id",
                match=models.MatchValue(value=user_id),
            )
        ]
    )

    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)
    try:
        results = client.query_points(
            collection_name=COLLECTION_NAME,
            prefetch=[
                models.Prefetch(
                    query=dense_query_vector,
                    limit=10,
                    using="dense",
                    filter=user_filter,
                ),
                models.Prefetch(
                    query=sparse_query_vector,
                    using="sparse",
                    limit=10,
                    filter=user_filter,
                ),
            ],
            query=models.FusionQuery(fusion=models.Fusion.RRF),
            limit=10,
        )
    finally:
        # A client is created per call, so release its connections even
        # when the query fails.
        client.close()

    tables = _unique_table_names(results.points)

    # sample_rows_in_table_info=0 keeps row data out of the returned schema.
    # NOTE(review): a new SQLDatabase (and presumably a new engine) is
    # created on every call and never disposed — confirm whether this
    # should be cached or closed by the caller.
    db = SQLDatabase.from_uri(connection_url, sample_rows_in_table_info=0)
    dialect = db.dialect
    final_schemes = f"Dialect : {dialect}\n {db.get_table_info(table_names=tables)}"
    return final_schemes