|
|
import os |
|
|
import time |
|
|
import logging |
|
|
import uuid |
|
|
import numpy as np |
|
|
from qdrant_client import QdrantClient, models |
|
|
from datasets import load_dataset |
|
|
from typing import List, Dict, Any |
|
|
import ast |
|
|
|
|
|
# Redirect the Hugging Face cache into /tmp so downloads land on a writable,
# ephemeral path (container-friendly) instead of the user home directory.
os.environ['HF_HOME'] = '/tmp/hf_cache'


logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)


# Qdrant connection settings (overridable via environment).
QDRANT_HOST = os.getenv("QDRANT_HOST", "127.0.0.1")

QDRANT_PORT = int(os.getenv("QDRANT_PORT", 6333))


# Source dataset and target collection configuration.
# NOTE(review): HF_DATASET_NAME has no default — if EMBEDDING_DATASET is
# unset, load_dataset(None) will fail later; confirm it is always provided.
HF_DATASET_NAME = os.getenv("EMBEDDING_DATASET")

COLLECTION_NAME = os.getenv("COLLECTION_NAME", "default_collection")

# Expected size of every vector; rows with a different length are rejected.
EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIM_SIZE", 1024))

# Number of points sent per upsert request.
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 200))


# Credentials: HF read token for the dataset, API key for the Qdrant service.
embedding_token = os.getenv("INS_READ_TOKEN")

qdrant_token = os.getenv("QDRANT__SERVICE__API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_qdrant_client_raw() -> QdrantClient:
    """Initialize the raw Qdrant client.

    Builds a plain-HTTP client from the module-level host/port/API-key
    settings and returns it without checking connectivity.
    """
    logger.info("Getting client connection")
    return QdrantClient(
        host=QDRANT_HOST,
        port=QDRANT_PORT,
        api_key=qdrant_token,
        https=False,
    )
|
|
|
|
|
def create_collection_if_not_exists(client: QdrantClient) -> bool:
    """Check if the collection exists and create it if it doesn't.

    Args:
        client: Connected Qdrant client.

    Returns:
        True if the collection already existed (data loading can be skipped),
        False if it was just created (and still needs to be populated).

    Raises:
        Exception: re-raises any client error from collection creation, so the
            caller never proceeds to index into a collection that was not made.
    """
    collections = client.get_collections().collections
    if COLLECTION_NAME in [c.name for c in collections]:
        logger.info(f"Collection '{COLLECTION_NAME}' already exists. Skipping data loading.")
        return True

    logger.info(f"Collection '{COLLECTION_NAME}' not found. Creating collection...")

    try:
        # create_collection (not the deprecated recreate_collection): the
        # existence check above guarantees we never clobber existing data.
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=models.VectorParams(
                size=EMBEDDING_DIMENSION,
                distance=models.Distance.COSINE,
            ),
        )
    except Exception as e:
        # BUG FIX: the original returned False here, which told the caller the
        # collection was freshly created and made it upsert into a collection
        # that does not exist. Log and propagate instead.
        logger.error(f"Failed to create collection: {e}")
        raise

    logger.info(f"Collection '{COLLECTION_NAME}' successfully created.")
    return False
|
|
|
|
|
def safe_parse_data(data: Any, expected_type: type, field_name: str, index: int) -> Any:
    """
    Safely parse *data* into *expected_type*.

    Values already of the expected type pass through untouched; strings are
    decoded with ast.literal_eval and validated against the expected type.
    Anything else is rejected.

    Raises:
        ValueError: the string could not be parsed, or parsed to the wrong type.
        TypeError: *data* is neither the expected type nor a string.
    """
    if isinstance(data, expected_type):
        return data

    # Guard clause: only strings get a parse attempt.
    if not isinstance(data, str):
        raise TypeError(
            f"Data type error for {field_name} at index {index}: "
            f"Expected {expected_type.__name__} or string representation, but got {type(data).__name__}"
        )

    try:
        parsed = ast.literal_eval(data)
        if not isinstance(parsed, expected_type):
            # Raised inside the try on purpose: the except below wraps it
            # into the uniform "Data parsing error" message.
            raise ValueError(f"Parsed type is {type(parsed)}, not expected {expected_type}")
    except (ValueError, SyntaxError, TypeError) as e:
        raise ValueError(
            f"Data parsing error for {field_name} at index {index}: "
            f"Input string '{data[:50]}...' is not a valid {expected_type.__name__} string. Error: {e}"
        )
    return parsed
|
|
|
|
|
def load_and_index_data(client: QdrantClient):
    """Pulls pre-embedded data from HF and indexes it into Qdrant using client.upsert.

    Reads every row of the dataset's 'train' split, parses the vector and
    payload columns (which may arrive as string representations), validates
    each vector's length against EMBEDDING_DIMENSION, and upserts the points
    in batches of BATCH_SIZE.

    Args:
        client: Connected Qdrant client with the target collection created.
    """
    logger.info(f"Starting data loading from Hugging Face dataset: {HF_DATASET_NAME}")

    try:
        dataset = load_dataset(HF_DATASET_NAME, token=embedding_token, cache_dir="/tmp/datasets_cache")
        dataset = dataset['train']

        # Column holding the embedding; configurable for datasets that use a
        # different name.
        vector_column_name = os.getenv("VECTOR_COLUMN", "vector")

        logger.info(f"Loaded {len(dataset)} documents. Preparing points for indexing...")
        logger.info(f"Dataset id of type {type(dataset[0]['id'])}")
        logger.info(f"Dataset vector of type {type(dataset[0][vector_column_name])}")
        logger.info(f"Dataset payload of type {type(dataset[0]['payload'])}")

        points_to_upsert: List[models.PointStruct] = []

        # PERF FIX: iterate rows once with enumerate. The original indexed
        # dataset[i] three times per iteration, and each index access decodes
        # the row again.
        for i, row in enumerate(dataset):
            vector = safe_parse_data(row.get(vector_column_name), list, vector_column_name, i)
            payload = safe_parse_data(row.get('payload'), dict, 'payload', i)

            if len(vector) != EMBEDDING_DIMENSION:
                raise ValueError(f"Vector at index {i} has incorrect size: {len(vector)} (Expected: {EMBEDDING_DIMENSION})")

            points_to_upsert.append(
                models.PointStruct(
                    id=row['id'],
                    vector=vector,
                    payload=payload
                )
            )

        # Upsert in fixed-size batches to bound per-request size.
        for start in range(0, len(points_to_upsert), BATCH_SIZE):
            client.upsert(
                collection_name=COLLECTION_NAME,
                points=points_to_upsert[start:start + BATCH_SIZE]
            )
        logger.info("Data Upsert and indexing complete.")

    except Exception as e:
        # NOTE(review): errors are logged but swallowed, so the caller cannot
        # tell that indexing failed — consider re-raising if callers must know.
        logger.error(f"Data indexing failed. Check dataset structure or network access. Error: {e}")
|
|
|
|
|
|
|
|
def main_initialization():
    """Main execution block.

    Waits for the Qdrant server to come up (bounded retries), ensures the
    target collection exists, and populates it if it was newly created.
    """
    max_retries = 10

    for attempt in range(max_retries):
        try:
            client = get_qdrant_client_raw()
            # Cheap round-trip to prove the server is actually reachable.
            client.get_collections()
            logger.info("Qdrant server is reachable.")
            break
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1}/{max_retries}: Qdrant not ready yet. Retrying in 5 seconds. ({e})")
            # BUG FIX: only sleep when another attempt remains; the original
            # wasted 5 seconds after the final failure before giving up.
            if attempt + 1 < max_retries:
                time.sleep(5)
    else:
        # for/else: reached only when the loop exhausted without break.
        logger.error("Qdrant server failed to start within the maximum retry limit. Exiting initialization.")
        return

    # True => collection already present (assumed populated); False => fresh.
    collection_exists = create_collection_if_not_exists(client)

    if not collection_exists:
        load_and_index_data(client)
    else:
        logger.info("Collection already populated, skipping data indexing.")
|
|
|
|
|
|
|
|
# Run the one-shot initialization when executed as a script.
if __name__ == "__main__":

    main_initialization()