import os
import time
import logging
import uuid
import numpy as np
from qdrant_client import QdrantClient, models
from datasets import load_dataset
from typing import List, Dict, Any
import ast
# NOTE(review): this is set AFTER `datasets` is imported above; the Hugging Face
# libraries typically read HF_HOME at import time, so this assignment may have no
# effect — consider moving it above the imports. (load_and_index_data also passes
# cache_dir explicitly, which limits the impact.) TODO confirm.
os.environ['HF_HOME'] = '/tmp/hf_cache'
# --- Setup Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configuration ---
# Read configuration from environment variables set in the Dockerfile
QDRANT_HOST = os.getenv("QDRANT_HOST", "127.0.0.1")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", 6333))
# Embeddings related params and config
HF_DATASET_NAME = os.getenv("EMBEDDING_DATASET")  # required: no default; None if unset
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "default_collection")
EMBEDDING_DIMENSION = int(os.getenv("EMBEDDING_DIM_SIZE", 1024))  # must match stored vectors
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 200))  # points per upsert request
# TOKENS AND SECRETS
embedding_token = os.getenv("INS_READ_TOKEN")  # passed to load_dataset as token=
qdrant_token = os.getenv("QDRANT__SERVICE__API_KEY")  # passed to QdrantClient as api_key=
# --- Core Functions ---
def get_qdrant_client_raw() -> QdrantClient:
    """Build and return a client for the locally running Qdrant server."""
    # The server itself is launched by start.sh; this function only connects.
    logger.info("Getting client connection")
    return QdrantClient(
        host=QDRANT_HOST,
        port=QDRANT_PORT,
        api_key=qdrant_token,
        https=False,
    )
def create_collection_if_not_exists(client: QdrantClient) -> bool:
    """Ensure the target collection exists on the server.

    Returns:
        True if the caller should SKIP data loading (the collection already
        existed, or creation failed so there is nothing to load into);
        False if the collection was freshly created and needs populating.
    """
    # Check if collection is present
    collections = client.get_collections().collections
    if COLLECTION_NAME in [c.name for c in collections]:
        logger.info(f"Collection '{COLLECTION_NAME}' already exists. Skipping data loading.")
        return True  # Collection found
    logger.info(f"Collection '{COLLECTION_NAME}' not found. Creating collection...")
    try:
        # Use create_collection rather than the deprecated recreate_collection:
        # we already know the collection is absent, so a plain create is
        # sufficient and non-destructive.
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=models.VectorParams(
                size=EMBEDDING_DIMENSION,
                distance=models.Distance.COSINE
            )
        )
        logger.info(f"Collection '{COLLECTION_NAME}' successfully created.")
        return False  # Collection was created (needs data loading)
    except Exception as e:
        logger.error(f"Failed to create collection: {e}")
        # Bug fix: the original returned False here, which told the caller to
        # upsert into a collection that was never created. Returning True
        # makes the caller skip data loading after a failed create.
        return True
def safe_parse_data(data: Any, expected_type: type, field_name: str, index: int) -> Any:
    """
    Return *data* coerced to *expected_type*.

    Values already of the expected type pass through unchanged; strings are
    parsed with ast.literal_eval and validated. Anything else raises.

    Raises:
        TypeError: if *data* is neither *expected_type* nor a string.
        ValueError: if a string fails to parse, or parses to the wrong type.
    """
    # Fast path: already the right type.
    if isinstance(data, expected_type):
        return data
    # Only string representations are eligible for parsing.
    if not isinstance(data, str):
        raise TypeError(
            f"Data type error for {field_name} at index {index}: "
            f"Expected {expected_type.__name__} or string representation, but got {type(data).__name__}"
        )
    try:
        parsed = ast.literal_eval(data)
        if not isinstance(parsed, expected_type):
            # Caught below and re-wrapped with field/index context.
            raise ValueError(f"Parsed type is {type(parsed)}, not expected {expected_type}")
        return parsed
    except (ValueError, SyntaxError, TypeError) as e:
        raise ValueError(
            f"Data parsing error for {field_name} at index {index}: "
            f"Input string '{data[:50]}...' is not a valid {expected_type.__name__} string. Error: {e}"
        )
def load_and_index_data(client: QdrantClient):
    """Pulls pre-embedded data from HF and indexes it into Qdrant using client.upsert.

    Best-effort: any failure is logged and swallowed so the surrounding
    initialization flow can continue.
    """
    logger.info(f"Starting data loading from Hugging Face dataset: {HF_DATASET_NAME}")
    try:
        dataset = load_dataset(HF_DATASET_NAME, token=embedding_token, cache_dir="/tmp/datasets_cache")
        dataset = dataset['train']
        vector_column_name = os.getenv("VECTOR_COLUMN", "vector")
        logger.info(f"Loaded {len(dataset)} documents. Preparing points for indexing...")
        # Bug fix: the type-inspection logging below indexed dataset[0],
        # which raised IndexError on an empty split.
        if len(dataset) == 0:
            logger.warning("Dataset split 'train' is empty; nothing to index.")
            return
        logger.info(f"Dataset id of type {type(dataset[0]['id'])}")
        logger.info(f"Dataset vector of type {type(dataset[0][vector_column_name])}")
        logger.info(f"Dataset payload of type {type(dataset[0]['payload'])}")
        points_to_upsert: List[models.PointStruct] = []
        # Fetch each row once: indexing dataset[i] repeatedly decodes the
        # row on every access, so the original's three lookups per row were
        # needlessly expensive.
        for i, row in enumerate(dataset):
            # Vectors arrive either as lists or string representations.
            vector = safe_parse_data(row.get(vector_column_name), list, vector_column_name, i)
            # Payloads arrive either as dicts or string representations.
            payload = safe_parse_data(row.get('payload'), dict, 'payload', i)
            # Basic validation against the configured collection dimension.
            if len(vector) != EMBEDDING_DIMENSION:
                raise ValueError(f"Vector at index {i} has incorrect size: {len(vector)} (Expected: {EMBEDDING_DIMENSION})")
            points_to_upsert.append(
                models.PointStruct(
                    id=row['id'],
                    vector=vector,
                    payload=payload
                )
            )
        # Upsert points in batches to bound individual request size.
        for start in range(0, len(points_to_upsert), BATCH_SIZE):
            client.upsert(
                collection_name=COLLECTION_NAME,
                points=points_to_upsert[start:start + BATCH_SIZE]
            )
        logger.info("Data Upsert and indexing complete.")
    except Exception as e:
        logger.error(f"Data indexing failed. Check dataset structure or network access. Error: {e}")
def main_initialization():
    """Main execution block: wait for Qdrant, then create and populate the collection."""
    max_retries = 10
    # Wait for the local Qdrant server (started by start.sh) to become ready
    for attempt in range(max_retries):
        try:
            client = get_qdrant_client_raw()
            # Simple check to see if the server is responsive
            client.get_collections()
            logger.info("Qdrant server is reachable.")
            break
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1}/{max_retries}: Qdrant not ready yet. Retrying in 5 seconds. ({e})")
            # Bug fix: only sleep when another attempt will follow; the
            # original also slept 5 seconds after the final failed attempt
            # before giving up.
            if attempt + 1 < max_retries:
                time.sleep(5)
    else:
        logger.error("Qdrant server failed to start within the maximum retry limit. Exiting initialization.")
        return
    # True means the collection already exists (data previously indexed) or
    # creation failed; False means it was just created and needs data.
    collection_exists = create_collection_if_not_exists(client)
    if not collection_exists:
        # We no longer need to initialize an embedding model
        load_and_index_data(client)
    else:
        logger.info("Collection already populated, skipping data indexing.")
if __name__ == "__main__":
    main_initialization()