from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
from fastembed import TextEmbedding, SparseTextEmbedding
import uuid
from dotenv import load_dotenv
import os
from src.scheme import create_scheme

COLLECTION_NAME = "Text2SQL"

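# Optional sketch (hypothetical helper, not used by create_embeddings): every point
# stored below carries a `user_id` payload, and that field gets a keyword index, so a
# search can be limited to one user's tables. This builds the condition in the JSON
# shape used by Qdrant's REST API; with the Python client the same condition would be
# models.Filter(must=[models.FieldCondition(key="user_id", match=models.MatchValue(value=...))]).
def user_filter(user_id: str) -> dict:
    """Return a Qdrant REST-style filter matching points whose `user_id` equals `user_id`."""
    return {"must": [{"key": "user_id", "match": {"value": user_id}}]}
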
load_dotenv()

qdrant_api = os.getenv("QDRANT_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")
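
# Optional sketch (hypothetical helper): os.getenv returns None when a variable is
# unset, so a missing QDRANT_URL or QDRANT_API_KEY would typically only show up later
# as a connection or authentication error. A helper like this fails immediately with a
# clear message and could be called as require_env("QDRANT_URL") in place of os.getenv.
def require_env(name: str) -> str:
    """Return the value of environment variable `name`, raising if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value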

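# Optional sketch (hypothetical helper): create_embeddings gives every point a random
# uuid4, so indexing the same database twice stores a second copy of each table.
# Deriving the ID with uuid5 from the user and table name makes it repeatable, so a
# later upsert overwrites the earlier point instead of adding a duplicate. This assumes
# table names are unique per user; if one user indexes several databases, a database
# identifier would also need to be part of the key.
def stable_point_id(user_id: str, table_name: str) -> str:
    """Return a deterministic UUID string for the (user_id, table_name) pair."""
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"text2sql/{user_id}/{table_name}"))
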
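def create_embeddings(connection_url: str, user_id: str) -> None:
    """Embed the schema of the database at `connection_url` and upsert it into Qdrant.

    Each schema document returned by `create_scheme` is stored as one point with a
    dense (MiniLM) vector and a sparse (BM25) vector, tagged with `user_id` so that
    searches can be filtered by user.
    """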
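    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)

    # Dense model: produces 384-dimensional embeddings, matching the "dense" vector size below.
    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # Sparse model: BM25 term weights for keyword (lexical) matching.
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")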

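    # Create the collection on first use with one named dense and one named sparse vector.
    # Qdrant/bm25 only produces the term-frequency part of BM25; the "idf" modifier lets
    # Qdrant supply the inverse-document-frequency part at query time. It only applies
    # when the collection is created; an existing collection keeps its settings.
    if not client.collection_exists(COLLECTION_NAME):
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config={"dense": VectorParams(size=384, distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams(modifier="idf")},
        )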
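
    # Keyword index on user_id so per-user filtered searches stay fast.
    # Create it only if the collection does not already have it.
    payload_schema = client.get_collection(COLLECTION_NAME).payload_schema
    if "user_id" not in payload_schema:
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="user_id",
            field_schema="keyword",
        )
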
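    # Schema documents; page_content is embedded and metadata carries "table_name".
    docs = create_scheme(connection_url)
    if not docs:
        return
    texts = [doc.page_content for doc in docs]

    # embed() returns a generator with one embedding per input text, in input order.
    dense_vectors = list(dense_model.embed(texts))
    sparse_vectors = list(sparse_model.embed(texts))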

    points = []
    for doc, dense_vector, sparse_embedding in zip(docs, dense_vectors, sparse_vectors):
        point = PointStruct(
            id=str(uuid.uuid4()),
            vector={
                "dense": dense_vector.tolist(),
                # Sparse vectors are sent as parallel lists of token indices and weights.
                "sparse": {
                    "indices": sparse_embedding.indices.tolist(),
                    "values": sparse_embedding.values.tolist(),
                },
            },
            payload={
                "user_id": user_id,
                "text": doc.page_content,
                "table_name": doc.metadata.get("table_name"),
            },
        )
        points.append(point)

    client.upsert(collection_name=COLLECTION_NAME, points=points)
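

# Optional sketch (hypothetical helper): the upsert above sends every point in a single
# request, which can grow large for databases with many tables. Splitting the list into
# fixed-size chunks keeps each request bounded; each chunk could then be passed to
# client.upsert(collection_name=COLLECTION_NAME, points=chunk). qdrant_client's
# upload_points method, which takes a batch_size argument, is another option.
def chunked(items: list, size: int) -> list:
    """Split `items` into consecutive lists of at most `size` elements."""
    if size <= 0:
        raise ValueError("size must be positive")
    return [items[start:start + size] for start in range(0, len(items), size)]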