File size: 2,210 Bytes
52adb86
 
 
 
 
 
0f75eeb
52adb86
 
 
 
 
0f75eeb
52adb86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from functools import lru_cache

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client import models
from fastembed import TextEmbedding, SparseTextEmbedding
from langchain_community.utilities import SQLDatabase
from huggingface_hub import login

# Load variables from a .env file into the process environment; this must run
# before the os.getenv() calls below so they can see those values.
load_dotenv()

# Qdrant connection settings. Each is None when its variable is not set.
qdrant_api = os.getenv("QDRANT_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")
# Authenticate with the Hugging Face Hub as an import-time side effect.
# NOTE(review): when HF_TOKEN is unset this passes token=None, which presumably
# makes `login` fall back to an interactive prompt — confirm this is acceptable
# for non-interactive deployments.
login(token=os.getenv("HF_TOKEN"))
# Name of the Qdrant collection that retrieve() queries.
COLLECTION_NAME = "Text2SQL"

@lru_cache(maxsize=1)
def _get_qdrant_client():
    """Build the Qdrant client once from the module-level settings and reuse it."""
    return QdrantClient(api_key=qdrant_api, url=qdrant_url)


@lru_cache(maxsize=1)
def _get_embedding_models():
    """Load the dense and sparse embedding models once; return (dense, sparse)."""
    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")
    return dense_model, sparse_model


def retrieve(user_id: str, query: str, connection_url: str) -> str:
    """Return the schema text of the tables most relevant to *query*.

    Runs a hybrid (dense + sparse) search over the ``COLLECTION_NAME`` Qdrant
    collection, restricted to points whose ``user_id`` payload field equals
    *user_id*, fuses both result lists with Reciprocal Rank Fusion, and then
    asks the SQL database at *connection_url* for the table info of every
    distinct table named in the hits.

    Args:
        user_id: Value the ``user_id`` payload field must match.
        query: Natural-language question used to build both query vectors.
        connection_url: Database URI passed to ``SQLDatabase.from_uri``.

    Returns:
        A string of the form ``"Dialect : <dialect>\\n <table info>"``.
    """
    # Client and models are expensive to construct (model weights are loaded),
    # so they are created on first use and shared across calls.
    client = _get_qdrant_client()
    dense_model, sparse_model = _get_embedding_models()

    dense_query_vector = next(iter(dense_model.embed([query])))

    # BM25 weights queries differently from documents, so the query side must
    # use query_embed() rather than the document-side embed().
    sparse_query = next(iter(sparse_model.query_embed([query])))
    sparse_query_vector = models.SparseVector(
        indices=sparse_query.indices,
        values=sparse_query.values,
    )

    # The same per-user filter is applied to both prefetch branches so that
    # neither branch can surface another user's points.
    user_filter = models.Filter(
        must=[
            models.FieldCondition(
                key="user_id",
                match=models.MatchValue(value=user_id),
            )
        ]
    )

    results = client.query_points(
        collection_name=COLLECTION_NAME,
        prefetch=[
            models.Prefetch(
                query=dense_query_vector,
                using="dense",
                limit=10,
                filter=user_filter,
            ),
            models.Prefetch(
                query=sparse_query_vector,
                using="sparse",
                limit=10,
                filter=user_filter,
            ),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=10,
    )

    # Collect distinct table names in ranking order. Points with no payload or
    # without a "table_name" entry are skipped instead of raising.
    candidate_names = (
        (point.payload or {}).get("table_name") for point in results.points
    )
    tables = list(dict.fromkeys(name for name in candidate_names if name))

    # sample_rows_in_table_info=0 keeps the output to schema definitions only.
    db = SQLDatabase.from_uri(connection_url, sample_rows_in_table_info=0)

    dialect = db.dialect

    final_schemes = f"Dialect : {dialect}\n {db.get_table_info(table_names=tables)}"

    return final_schemes