# initialize_qdrant.py (repo: test_giz)
# Uploaded by Romulan12 — "Upload 5 files", commit 70db68d (verified)
import os
import time
import logging
import uuid
import numpy as np
from qdrant_client import QdrantClient, models
from datasets import load_dataset
from typing import List, Dict, Any
import ast
os.environ['HF_HOME'] = '/tmp/hf_cache'
# --- Setup Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configuration ---
# Read configuration from environment variables set in the Dockerfile
QDRANT_HOST = os.getenv("QDRANT_HOST", "127.0.0.1")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", 6333))
# Embeddings related params and config
HF_DATASET_NAME = os.getenv("EMBEDDING_DATASET")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "default_collection")
EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIM_SIZE", 1024))
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 200))
# TOKENS AND SECRETS
embedding_token = os.getenv("INS_READ_TOKEN")
qdrant_token = os.getenv("QDRANT__SERVICE__API_KEY")
# --- Core Functions ---
def get_qdrant_client_raw() -> QdrantClient:
"""Initialize the raw Qdrant client."""
# Connecting to the Qdrant server started locally via start.sh
logger.info("Getting client connection")
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT, api_key = qdrant_token, https =False)
return client
def create_collection_if_not_exists(client: QdrantClient):
"""Checks if the collection exists and creates it if it doesn't."""
# Check if collection is present
collections = client.get_collections().collections
if COLLECTION_NAME in [c.name for c in collections]:
logger.info(f"Collection '{COLLECTION_NAME}' already exists. Skipping data loading.")
return True # Collection found
logger.info(f"Collection '{COLLECTION_NAME}' not found. Creating collection...")
try:
# Create the collection with the correct vector size
client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=models.VectorParams(
size=EMBEDDING_DIMENSION,
distance=models.Distance.COSINE
)
)
logger.info(f"Collection '{COLLECTION_NAME}' successfully created.")
return False # Collection was created (needs data loading)
except Exception as e:
logger.error(f"Failed to create collection: {e}")
return False
def safe_parse_data(data: Any, expected_type: type, field_name: str, index: int) -> Any:
"""
Safely parses input data, converting strings to the expected type (list/dict)
using ast.literal_eval if necessary, or raises an error.
"""
if isinstance(data, expected_type):
return data
if isinstance(data, str):
try:
parsed_data = ast.literal_eval(data)
if isinstance(parsed_data, expected_type):
return parsed_data
else:
raise ValueError(f"Parsed type is {type(parsed_data)}, not expected {expected_type}")
except (ValueError, SyntaxError, TypeError) as e:
raise ValueError(
f"Data parsing error for {field_name} at index {index}: "
f"Input string '{data[:50]}...' is not a valid {expected_type.__name__} string. Error: {e}"
)
raise TypeError(
f"Data type error for {field_name} at index {index}: "
f"Expected {expected_type.__name__} or string representation, but got {type(data).__name__}"
)
def load_and_index_data(client: QdrantClient):
"""Pulls pre-embedded data from HF and indexes it into Qdrant using client.upsert."""
logger.info(f"Starting data loading from Hugging Face dataset: {HF_DATASET_NAME}")
try:
# Load a small slice of the dataset for demonstration/speed
dataset = load_dataset(HF_DATASET_NAME, token = embedding_token, cache_dir="/tmp/datasets_cache")
dataset = dataset['train']
points_to_upsert: List[models.PointStruct] = []
vector_column_name = os.getenv("VECTOR_COLUMN", "vector")
logger.info(f"Loaded {len(dataset)} documents. Preparing points for indexing...")
logger.info(f"Dataset id of type {type(dataset[0]['id'])}")
logger.info(f"Dataset vector of type {type(dataset[0][vector_column_name])}")
logger.info(f"Dataset payload of type {type(dataset[0]['payload'])}")
for i in range(0, len(dataset)):
# 1. Getting Column name for vectors
# Use safe_parse_data for vectors (expected List[float])
vector = safe_parse_data(
dataset[i].get(vector_column_name),
list,
vector_column_name,
i
)
# Use safe_parse_data for payloads (expected Dict[str, Any])
payload = safe_parse_data(
dataset[i].get('payload'),
dict,
'payload',
i
)
# Basic validation
if len(vector) != EMBEDDING_DIMENSION:
raise ValueError(f"Vector at index {i} has incorrect size: {len(vector)} (Expected: {EMBEDDING_DIMENSION})")
## 2. Create the Qdrant PointStruct
points_to_upsert.append(
models.PointStruct(
id=dataset[i]['id'],
vector=vector,
payload=payload
# vector=ast.literal_eval(dataset[i][vector_column_name]),
# payload=ast.literal_eval(dataset[i]['payload'])
)
)
# 3. Upsert points in batches (optional, but good practice for large sets)
for i in range(0, len(points_to_upsert), BATCH_SIZE):
client.upsert(
collection_name=COLLECTION_NAME,
points=points_to_upsert[i:i+BATCH_SIZE]
)
logger.info("Data Upsert and indexing complete.")
except Exception as e:
logger.error(f"Data indexing failed. Check dataset structure or network access. Error: {e}")
def main_initialization():
"""Main execution block."""
max_retries = 10
# Wait for the local Qdrant server (started by start.sh) to become ready
for attempt in range(max_retries):
try:
client = get_qdrant_client_raw()
# Simple check to see if the server is responsive
client.get_collections()
logger.info("Qdrant server is reachable.")
break
except Exception as e:
logger.warning(f"Attempt {attempt + 1}/{max_retries}: Qdrant not ready yet. Retrying in 5 seconds. ({e})")
time.sleep(5)
else:
logger.error("Qdrant server failed to start within the maximum retry limit. Exiting initialization.")
return
# Check if the collection exists (meaning data was previously indexed)
collection_exists = create_collection_if_not_exists(client)
# If the collection was just created (or if it never existed), load the data
if not collection_exists:
# We no longer need to initialize an embedding model
load_and_index_data(client)
else:
logger.info("Collection already populated, skipping data indexing.")
if __name__ == "__main__":
main_initialization()