Spaces:
Sleeping
Sleeping
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams, SparseVectorParams, PointStruct | |
| from fastembed import TextEmbedding, SparseTextEmbedding | |
| import uuid | |
| from dotenv import load_dotenv | |
| import os | |
| from src.scheme import create_scheme | |
# `login` is the Hub authentication helper that is actually called below.
from huggingface_hub import login

# Name of the single Qdrant collection shared by all users; individual
# users are distinguished by the `user_id` payload field.
COLLECTION_NAME = "Text2SQL"

# Load variables from a local .env file (if present) into the process
# environment before the settings below are read.
load_dotenv()
qdrant_api = os.getenv("QDRANT_API_KEY")
qdrant_url = os.getenv("QDRANT_URL")

# Authenticate with the Hugging Face Hub only when a token is configured:
# calling login() without a token falls back to an interactive prompt,
# which would stall a non-interactive deployment.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
def _ensure_collection(client) -> None:
    """Create the shared collection and its ``user_id`` payload index if needed."""
    import logging

    if not client.collection_exists(COLLECTION_NAME):
        client.create_collection(
            collection_name=COLLECTION_NAME,
            # size must match the output dimension of the dense embedding model.
            vectors_config={"dense": VectorParams(size=384, distance=Distance.COSINE)},
            sparse_vectors_config={"sparse": SparseVectorParams()},
        )

    # Index creation is best-effort: a failure is reported but must not
    # prevent the embeddings from being stored.
    try:
        client.create_payload_index(
            collection_name=COLLECTION_NAME,
            field_name="user_id",
            field_schema="keyword",
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not create 'user_id' payload index on collection %r",
            COLLECTION_NAME,
            exc_info=True,
        )


def create_embeddings(connection_url: str, user_id: str) -> None:
    """Embed the schema documents of a database and store them in Qdrant.

    The documents returned by ``create_scheme(connection_url)`` are embedded
    with a dense model and a sparse BM25 model, then upserted into
    ``COLLECTION_NAME`` as one point per document, tagged with ``user_id``.

    Args:
        connection_url: Passed unchanged to ``create_scheme``.
        user_id: Stored in each point's payload under ``user_id``.

    Returns:
        None. When ``create_scheme`` yields no documents nothing is upserted.
    """
    client = QdrantClient(api_key=qdrant_api, url=qdrant_url)
    dense_model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25")

    _ensure_collection(client)

    # Each document is expected to expose `page_content` and `metadata`.
    docs = create_scheme(connection_url)
    if not docs:
        # Nothing to index; avoid sending an empty upsert request.
        return

    texts = [doc.page_content for doc in docs]

    points = []
    # Both embedders yield one embedding per input text, in input order.
    for doc, dense, sparse in zip(docs, dense_model.embed(texts), sparse_model.embed(texts)):
        points.append(
            PointStruct(
                # NOTE(review): ids are random, so calling this again for the
                # same database adds new points rather than replacing old ones.
                id=str(uuid.uuid4()),
                vector={
                    "dense": dense.tolist(),
                    "sparse": {
                        "indices": sparse.indices.tolist(),
                        "values": sparse.values.tolist(),
                    },
                },
                payload={
                    "user_id": user_id,
                    "text": doc.page_content,
                    "table_name": doc.metadata.get("table_name"),
                },
            )
        )

    client.upsert(collection_name=COLLECTION_NAME, points=points)