phase 1 - data storage in qdrant and retrieval (commit 6874d8b)
"""
Main ingestion script for loading Nuinamath dataset into Qdrant.
"""
import logging
import os
import time

from datasets import load_dataset
from dotenv import load_dotenv
from tqdm import tqdm
# Load environment variables
load_dotenv()
# Configuration settings
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "numinamath")
DATASET_NAME = "AI-MO/NuminaMath-CoT"
DATASET_SPLIT = "train"
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
VECTOR_SIZE = 384  # all-MiniLM-L6-v2 produces 384-dimensional embeddings
DISTANCE_METRIC = "Cosine"
BATCH_SIZE = 100  # records embedded and upserted per request
MAX_SAMPLES = None  # set to an int to ingest only the first N samples
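# Example .env (placeholder values, not real credentials):
#
#   QDRANT_URL=https://<your-cluster>.cloud.qdrant.io:6333
#   QDRANT_API_KEY=<your-api-key>
#   QDRANT_COLLECTION=numinamath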
# Validation
if not QDRANT_URL or not QDRANT_API_KEY:
raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")
# Local helpers are imported after the credential check so that a missing .env
# fails fast, before pulling in the heavier ML dependencies they rely on.
from utils import EmbeddingGenerator, batch_process_dataset
from qdrant_manager import QdrantManager
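# Interfaces assumed by this script (reconstructed from the call sites in
# main(); reference sketches are at the bottom of this file):
#   EmbeddingGenerator(model_name).embed_text(texts) -> one vector per text
#   batch_process_dataset(dataset, batch_size) -> list of batches of payload
#       dicts, each carrying at least a 'text' field
#   QdrantManager(url, api_key) with create_collection / upsert_points /
#       get_collection_info methods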
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def main():
"""Main ingestion pipeline."""
try:
# Initialize components
logger.info("Initializing components...")
embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)
# Load dataset
logger.info(f"Loading dataset: {DATASET_NAME}")
        if MAX_SAMPLES:
            # Hugging Face split-slicing syntax ("train[:N]") loads only the first N rows
            dataset = load_dataset(DATASET_NAME, split=f"{DATASET_SPLIT}[:{MAX_SAMPLES}]")
            logger.info(f"Loaded {len(dataset)} samples (limited to MAX_SAMPLES)")
        else:
            dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
            logger.info(f"Loaded full dataset: {len(dataset)} samples")
# Create Qdrant collection
logger.info(f"Creating collection: {QDRANT_COLLECTION}")
success = qdrant_manager.create_collection(
collection_name=QDRANT_COLLECTION,
vector_size=VECTOR_SIZE,
distance=DISTANCE_METRIC
)
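        # Depending on QdrantManager's implementation, create_collection may
        # return False when the collection already exists; drop or rename the
        # existing collection before re-running ingestion if so.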
if not success:
logger.error("Failed to create collection")
return
# Process dataset in batches
logger.info("Processing dataset in batches...")
batches = batch_process_dataset(dataset, BATCH_SIZE)
total_processed = 0
total_batches = len(batches)
for batch_idx, batch_data in enumerate(tqdm(batches, desc="Processing batches")):
try:
# Extract texts for embedding
texts = [item['text'] for item in batch_data]
# Generate embeddings
logger.info(f"Generating embeddings for batch {batch_idx + 1}/{total_batches}")
embeddings = embedding_generator.embed_text(texts)
# Upsert to Qdrant
logger.info(f"Uploading batch {batch_idx + 1} to Qdrant...")
qdrant_manager.upsert_points(
collection_name=QDRANT_COLLECTION,
points_data=batch_data,
embeddings=embeddings
)
total_processed += len(batch_data)
logger.info(f"Progress: {total_processed}/{len(dataset)} items processed")
# Small delay to avoid overwhelming the API
time.sleep(0.5)
except Exception as e:
logger.error(f"Error processing batch {batch_idx + 1}: {e}")
continue
# Final summary
logger.info("Ingestion completed!")
logger.info(f"Total items processed: {total_processed}")
# Get collection info
collection_info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
if collection_info:
logger.info(f"Collection status: {collection_info.status}")
logger.info(f"Vectors count: {collection_info.vectors_count}")
except Exception as e:
logger.error(f"Fatal error in ingestion pipeline: {e}")
raise
if __name__ == "__main__":
main()
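
# ---------------------------------------------------------------------------
# Reference sketches (hypothetical). The two local modules imported above are
# not included in this commit; the code below reconstructs their interfaces
# from the call sites in main(). Treat these as minimal sketches under the
# stated assumptions, not the actual contents of utils.py / qdrant_manager.py.
# ---------------------------------------------------------------------------
#
# utils.py -- minimal sketch, assuming sentence-transformers and that each
# NuminaMath-CoT row carries 'problem' and 'solution' fields:
#
#     from sentence_transformers import SentenceTransformer
#
#     class EmbeddingGenerator:
#         def __init__(self, model_name: str):
#             self.model = SentenceTransformer(model_name)
#
#         def embed_text(self, texts):
#             # all-MiniLM-L6-v2 yields one 384-dim vector per input string
#             return self.model.encode(texts, show_progress_bar=False).tolist()
#
#     def batch_process_dataset(dataset, batch_size):
#         # Flatten each row into a payload dict with the 'text' field that
#         # main() embeds, then group rows into lists of <= batch_size items.
#         batches, current = [], []
#         for idx, row in enumerate(dataset):
#             current.append({
#                 "id": idx,
#                 "text": f"{row['problem']}\n{row['solution']}",
#                 "problem": row["problem"],
#                 "solution": row["solution"],
#             })
#             if len(current) == batch_size:
#                 batches.append(current)
#                 current = []
#         if current:
#             batches.append(current)
#         return batches
#
# qdrant_manager.py -- minimal sketch, assuming qdrant-client:
#
#     from qdrant_client import QdrantClient
#     from qdrant_client.models import Distance, PointStruct, VectorParams
#
#     class QdrantManager:
#         def __init__(self, url: str, api_key: str):
#             self.client = QdrantClient(url=url, api_key=api_key)
#
#         def create_collection(self, collection_name, vector_size, distance):
#             try:
#                 self.client.create_collection(
#                     collection_name=collection_name,
#                     vectors_config=VectorParams(
#                         size=vector_size, distance=Distance(distance)),
#                 )
#                 return True
#             except Exception:
#                 return False
#
#         def upsert_points(self, collection_name, points_data, embeddings):
#             points = [
#                 PointStruct(id=item["id"], vector=vec, payload=item)
#                 for item, vec in zip(points_data, embeddings)
#             ]
#             self.client.upsert(collection_name=collection_name, points=points)
#
#         def get_collection_info(self, collection_name):
#             try:
#                 return self.client.get_collection(collection_name)
#             except Exception:
#                 return None
#
#         def search(self, collection_name, query_vector, limit=5):
#             # Retrieval side of "phase 1": nearest neighbours by the
#             # collection's distance metric (cosine here)
#             return self.client.search(
#                 collection_name=collection_name,
#                 query_vector=query_vector,
#                 limit=limit,
#             )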