# text2sql_backend/src/embedding.py
# Last commit: cfea50e ("Update src/embedding.py") by LightRT.
import functools
import logging
import os
import uuid

from dotenv import load_dotenv
from fastembed import SparseTextEmbedding, TextEmbedding
from huggingface_hub import login
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, SparseVectorParams, VectorParams

from src.scheme import create_scheme
# Qdrant collection that holds the schema-description embeddings for all users.
COLLECTION_NAME = "Text2SQL"

# Populate os.environ from a local .env file (if present) before reading settings.
load_dotenv()
qdrant_api = os.getenv("QDRANT_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")

# Authenticate with the Hugging Face Hub only when a token is configured:
# huggingface_hub.login(token=None) falls back to an interactive prompt,
# which would block (or fail) when this module is imported on a server.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
logger = logging.getLogger(__name__)

# Output dimensionality of sentence-transformers/all-MiniLM-L6-v2; must match
# the size of the "dense" named vector the collection is created with.
DENSE_VECTOR_SIZE = 384


@functools.lru_cache(maxsize=1)
def _get_dense_model():
    """Load the dense sentence-embedding model once per process and reuse it."""
    return TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")


@functools.lru_cache(maxsize=1)
def _get_sparse_model():
    """Load the BM25 sparse-embedding model once per process and reuse it."""
    return SparseTextEmbedding(model_name="Qdrant/bm25")


def _ensure_collection(client):
    """Create COLLECTION_NAME (dense + sparse vectors) and its ``user_id`` index if needed."""
    if not client.collection_exists(COLLECTION_NAME):
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config={
                "dense": VectorParams(size=DENSE_VECTOR_SIZE, distance=Distance.COSINE),
            },
            sparse_vectors_config={"sparse": SparseVectorParams()},
        )
    try:
        # Keyword index on user_id, presumably so per-user filtered searches stay fast.
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="user_id",
            field_schema="keyword",
        )
    except Exception:
        # Best effort, as before: ingestion proceeds without the index, but the
        # failure is now recorded instead of being silently discarded.
        logger.warning(
            "Could not create 'user_id' payload index on collection %r",
            COLLECTION_NAME,
            exc_info=True,
        )


def _point_id(user_id, text):
    """Return a stable point ID so re-ingesting the same text for a user overwrites it."""
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"{user_id}\x00{text}"))


def create_embeddings(connection_url: str, user_id: str):
    """Embed the schema documents of a database and upsert them into Qdrant.

    Each document returned by ``create_scheme(connection_url)`` is embedded with
    a dense (MiniLM) and a sparse (BM25) model and stored as one point carrying
    ``user_id``, the document text and its ``table_name`` metadata in the payload.
    Point IDs are derived from ``(user_id, text)``, so calling this again for
    the same user and schema updates the existing points instead of
    duplicating them.

    Args:
        connection_url: Database connection string passed to ``create_scheme``.
        user_id: Owner of the documents; stored in each point's payload.

    Raises:
        Whatever ``create_scheme``, the embedding models or the Qdrant client
        raise. Only a failure to create the ``user_id`` payload index is
        tolerated (it is logged).
    """
    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)
    try:
        _ensure_collection(client)

        # assumes create_scheme returns a list of objects exposing .page_content
        # and .metadata (LangChain-Document-like) — TODO confirm in src.scheme
        docs = create_scheme(connection_url)
        if not docs:
            return

        texts = [doc.page_content for doc in docs]
        dense_vectors = list(_get_dense_model().embed(texts))
        sparse_vectors = list(_get_sparse_model().embed(texts))

        points = [
            PointStruct(
                id=_point_id(user_id, text),
                vector={
                    "dense": dense.tolist(),
                    "sparse": {
                        "indices": sparse.indices.tolist(),
                        "values": sparse.values.tolist(),
                    },
                },
                payload={
                    "user_id": user_id,
                    "text": text,
                    "table_name": doc.metadata.get("table_name"),
                },
            )
            for doc, text, dense, sparse in zip(docs, texts, dense_vectors, sparse_vectors)
        ]
        client.upsert(collection_name=COLLECTION_NAME, points=points)
    finally:
        # Release the client's HTTP/gRPC connections; a new client is built per call.
        client.close()