File size: 7,331 Bytes
70db68d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os
import time
import logging
import uuid
import numpy as np 
from qdrant_client import QdrantClient, models
from datasets import load_dataset
from typing import List, Dict, Any
import ast

# NOTE(review): HF_HOME is assigned after `datasets` has already been imported
# (see the import block above). The Hugging Face libraries usually resolve this
# variable when they are first imported, so confirm this override really takes
# effect; if it does not, the assignment has to move above the `datasets` import.
os.environ["HF_HOME"] = "/tmp/hf_cache"

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration ---
# Every setting below can be overridden through the environment (the Dockerfile
# is expected to provide them); the second argument is the fallback value.

# Location of the Qdrant server.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "127.0.0.1")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", 6333))

# Source dataset and target collection. HF_DATASET_NAME has no fallback and is
# None when EMBEDDING_DATASET is unset.
HF_DATASET_NAME = os.environ.get("EMBEDDING_DATASET")
COLLECTION_NAME = os.environ.get("COLLECTION_NAME", "default_collection")
EMBEDDING_DIMENSION = int(os.environ.get("EMBEDDING_DIM_SIZE", 1024))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", 200))

# Credentials (None when not provided): the token handed to load_dataset and the
# API key handed to the Qdrant client.
embedding_token = os.environ.get("INS_READ_TOKEN")
qdrant_token = os.environ.get("QDRANT__SERVICE__API_KEY")


# --- Core Functions ---

def get_qdrant_client_raw() -> QdrantClient:
    """Build a plain Qdrant client for the configured server.

    The client targets ``QDRANT_HOST``:``QDRANT_PORT``, authenticates with the
    optional ``qdrant_token`` API key, and is configured with ``https=False``.
    Presumably this is the server launched locally by start.sh — verify.
    """
    logger.info("Getting client connection")
    return QdrantClient(
        host=QDRANT_HOST,
        port=QDRANT_PORT,
        api_key=qdrant_token,
        https=False,
    )

def create_collection_if_not_exists(client: QdrantClient):
    """Ensure ``COLLECTION_NAME`` exists and report whether data loading can be skipped.

    Returns:
        True if the collection already exists and holds at least one point.
        False if the collection was just created, or exists but is empty, meaning
        the caller must load data into it.

    Raises:
        Whatever the Qdrant client raises while listing, counting or creating the
        collection. Creation failures are logged first and then re-raised, instead
        of being reported as "created".
    """
    existing_names = {c.name for c in client.get_collections().collections}
    if COLLECTION_NAME in existing_names:
        # load_and_index_data logs and swallows its own errors, so a failed first
        # run can leave an existing-but-empty collection. Skip loading only when
        # the collection actually contains points.
        point_count = client.count(collection_name=COLLECTION_NAME, exact=True).count
        if point_count > 0:
            logger.info(f"Collection '{COLLECTION_NAME}' already exists. Skipping data loading.")
            return True  # Collection found and populated
        logger.warning(f"Collection '{COLLECTION_NAME}' exists but is empty. Data will be loaded.")
        return False  # Collection present but needs data

    logger.info(f"Collection '{COLLECTION_NAME}' not found. Creating collection...")

    try:
        # create_collection never drops an existing collection, unlike the
        # deprecated recreate_collection. If another process created it in the
        # meantime, this fails instead of silently wiping that data.
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=models.VectorParams(
                size=EMBEDDING_DIMENSION,
                distance=models.Distance.COSINE,
            ),
        )
    except Exception as e:
        # Returning False here would make the caller download the dataset and
        # upsert into a collection that does not exist; surface the error instead.
        logger.exception(f"Failed to create collection: {e}")
        raise

    logger.info(f"Collection '{COLLECTION_NAME}' successfully created.")
    return False  # Collection was created (needs data loading)

def safe_parse_data(data: Any, expected_type: type, field_name: str, index: int) -> Any:
    """
    Return ``data`` as an instance of ``expected_type``.

    Values that already have the expected type are returned unchanged. Strings are
    decoded with ``ast.literal_eval`` (Python literals only, no arbitrary code) and
    accepted when the decoded value has the expected type.

    Args:
        data: Raw value taken from a dataset row.
        expected_type: Type the value must end up as (e.g. ``list`` or ``dict``).
        field_name: Column name, used only in error messages.
        index: Row index, used only in error messages.

    Raises:
        ValueError: ``data`` is a string that is not a valid literal, or whose
            literal is not of ``expected_type``. The underlying error is chained.
        TypeError: ``data`` is neither ``expected_type`` nor a string.
    """
    if isinstance(data, expected_type):
        return data

    if isinstance(data, str):
        try:
            parsed_data = ast.literal_eval(data)
            if not isinstance(parsed_data, expected_type):
                raise ValueError(f"Parsed type is {type(parsed_data)}, not expected {expected_type}")
        # literal_eval is documented to raise MemoryError and RecursionError on
        # malformed (e.g. deeply nested) input, besides ValueError/TypeError/SyntaxError;
        # normalize all of them to a ValueError that names the field and row.
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            # Only mark the preview as truncated when it actually is.
            preview = data if len(data) <= 50 else f"{data[:50]}..."
            raise ValueError(
                f"Data parsing error for {field_name} at index {index}: "
                f"Input string '{preview}' is not a valid {expected_type.__name__} string. Error: {e}"
            ) from e
        return parsed_data

    raise TypeError(
        f"Data type error for {field_name} at index {index}: "
        f"Expected {expected_type.__name__} or string representation, but got {type(data).__name__}"
    )

def _row_to_point(row: Dict[str, Any], index: int, vector_column_name: str) -> "models.PointStruct":
    """Convert one dataset row into a validated Qdrant point; raises on malformed data."""
    # Vectors are expected as List[float]; string-encoded lists are decoded.
    vector = safe_parse_data(row.get(vector_column_name), list, vector_column_name, index)
    # Payloads are expected as Dict[str, Any]; string-encoded dicts are decoded.
    payload = safe_parse_data(row.get('payload'), dict, 'payload', index)

    if len(vector) != EMBEDDING_DIMENSION:
        raise ValueError(f"Vector at index {index} has incorrect size: {len(vector)} (Expected: {EMBEDDING_DIMENSION})")

    return models.PointStruct(id=row['id'], vector=vector, payload=payload)


def load_and_index_data(client: QdrantClient):
    """Pull pre-embedded rows from Hugging Face and upsert them into Qdrant.

    Reads the whole ``train`` split of ``HF_DATASET_NAME``. Each row must provide
    an ``id``, a vector column (named by the ``VECTOR_COLUMN`` env var, default
    ``"vector"``) of length ``EMBEDDING_DIMENSION``, and a ``payload`` dict.
    Every row is converted and validated before anything is written, so a
    malformed row aborts the load before any point is upserted. Points are then
    upserted in chunks of ``BATCH_SIZE``.

    Any error is logged with its traceback and not re-raised, so the caller keeps
    running even if indexing fails.
    """
    logger.info(f"Starting data loading from Hugging Face dataset: {HF_DATASET_NAME}")

    try:
        # Check this before downloading: a zero step makes range() raise, and a
        # negative one would silently upsert nothing yet still log success.
        if BATCH_SIZE <= 0:
            raise ValueError(f"BATCH_SIZE must be a positive integer, got {BATCH_SIZE}")

        # The full "train" split is loaded (no slicing).
        dataset = load_dataset(HF_DATASET_NAME, token=embedding_token, cache_dir="/tmp/datasets_cache")
        dataset = dataset['train']
        vector_column_name = os.getenv("VECTOR_COLUMN", "vector")
        logger.info(f"Loaded {len(dataset)} documents. Preparing points for indexing...")

        if len(dataset) == 0:
            # Nothing to index, and dataset[0] below would raise IndexError.
            logger.warning("Dataset split 'train' is empty; nothing to index.")
            return

        # Log the raw column types of the first row to help diagnose schema issues.
        first_row = dataset[0]
        logger.info(f"Dataset id of type {type(first_row.get('id'))}")
        logger.info(f"Dataset vector of type {type(first_row.get(vector_column_name))}")
        logger.info(f"Dataset payload of type {type(first_row.get('payload'))}")

        # Iterate the rows once instead of re-indexing dataset[i] for every field.
        points_to_upsert: List[models.PointStruct] = [
            _row_to_point(row, i, vector_column_name)
            for i, row in enumerate(dataset)
        ]

        for start in range(0, len(points_to_upsert), BATCH_SIZE):
            client.upsert(
                collection_name=COLLECTION_NAME,
                points=points_to_upsert[start:start + BATCH_SIZE],
            )
        logger.info("Data Upsert and indexing complete.")

    except Exception as e:
        # Deliberately best-effort: record the full traceback and let the caller continue.
        logger.exception(f"Data indexing failed. Check dataset structure or network access. Error: {e}")


def main_initialization(max_retries: int = 10, retry_delay: float = 5.0):
    """Wait for Qdrant, ensure the collection exists, and index data when needed.

    Args:
        max_retries: Number of reachability probes before giving up.
        retry_delay: Seconds to wait between failed probes. No wait follows the
            final attempt.

    Returns early (after logging an error) if the server never becomes reachable.
    """
    # Wait for the local Qdrant server (presumably started by start.sh) to become ready.
    for attempt in range(1, max_retries + 1):
        try:
            client = get_qdrant_client_raw()
            # Cheap request to confirm the server actually answers.
            client.get_collections()
            logger.info("Qdrant server is reachable.")
            break
        except Exception as e:
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt}/{max_retries}: Qdrant not ready yet. "
                    f"Retrying in {retry_delay:g} seconds. ({e})"
                )
                time.sleep(retry_delay)
            else:
                # Last attempt: no point sleeping before giving up.
                logger.warning(f"Attempt {attempt}/{max_retries}: Qdrant not ready yet. ({e})")
    else:
        # Loop finished without `break` (also covers max_retries <= 0).
        logger.error("Qdrant server failed to start within the maximum retry limit. Exiting initialization.")
        return

    # True means the collection already exists with data; False means it must be loaded.
    collection_exists = create_collection_if_not_exists(client)

    if not collection_exists:
        load_and_index_data(client)
    else:
        logger.info("Collection already populated, skipping data indexing.")


if __name__ == "__main__":
    # Run the one-shot initialization (wait for Qdrant, ensure the collection,
    # index data if needed) only when executed as a script, not on import.
    main_initialization()