# Residue from the file-viewer page this module was copied from
# (Spaces: Sleeping; file size: 2,295 bytes; revisions 52adb86, cfea50e).
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct
from fastembed import TextEmbedding, SparseTextEmbedding
import uuid
from dotenv import load_dotenv
import os
from src.scheme import create_scheme
# `login` is the name this module actually calls below. The previous import
# asked huggingface_hub for `FloatingPointError`, a Python builtin it does not
# export, while `login` itself was never imported.
from huggingface_hub import login

# Name of the Qdrant collection that create_embeddings reads and writes.
COLLECTION_NAME = "Text2SQL"

# Load variables from a .env file (if any) before reading them from the
# process environment.
load_dotenv()
qdrant_api = os.getenv("QDRANT_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")

# Authenticate with the Hugging Face Hub only when a token is configured.
# NOTE(review): login(token=None) is documented to fall back to an interactive
# prompt, which is presumably unwanted in a non-interactive deployment.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
def create_embeddings(connection_url: str, user_id: str) -> None:
    """Embed a database schema and upsert it into Qdrant for one user.

    The schema documents returned by ``create_scheme(connection_url)`` are
    embedded twice -- with a dense model and with a sparse BM25 model -- and
    stored as one point per document in ``COLLECTION_NAME``. The collection is
    created on first use, and a keyword payload index on ``user_id`` is
    requested on every call.

    Args:
        connection_url: Passed unchanged to ``create_scheme``.
        user_id: Stored in every point's payload under ``"user_id"``.

    Returns:
        None. The function is called for its side effects on Qdrant.
    """
    import logging  # function-scope import keeps this block self-contained

    logger = logging.getLogger(__name__)

    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)
    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")

    if not client.collection_exists(COLLECTION_NAME):
        client.create_collection(
            collection_name=COLLECTION_NAME,
            # size must match the output dimension of the dense model above.
            vectors_config={"dense": VectorParams(size=384, distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams()},
        )

    # Best-effort: a failure here must not abort indexing, but it should be
    # visible in the logs instead of being silently discarded.
    try:
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="user_id",
            field_schema="keyword",
        )
    except Exception:
        logger.warning(
            "Could not create the 'user_id' payload index on %s",
            COLLECTION_NAME,
            exc_info=True,
        )

    docs = create_scheme(connection_url)
    if not docs:
        # Nothing to embed; avoid sending an upsert request with no points.
        logger.info("No schema documents produced; nothing to upsert.")
        return

    texts = [doc.page_content for doc in docs]
    dense_vectors = list(dense_model.embed(texts))
    sparse_vectors = list(sparse_model.embed(texts))

    # NOTE(review): ids are random, so calling this again for the same schema
    # adds new points rather than overwriting the earlier ones.
    points = [
        PointStruct(
            id=str(uuid.uuid4()),
            vector={
                "dense": dense.tolist(),
                "sparse": {
                    "indices": sparse.indices.tolist(),
                    "values": sparse.values.tolist(),
                },
            },
            payload={
                "user_id": user_id,
                "text": doc.page_content,
                "table_name": doc.metadata.get("table_name"),
            },
        )
        for doc, dense, sparse in zip(docs, dense_vectors, sparse_vectors)
    ]

    client.upsert(collection_name=COLLECTION_NAME, points=points)