"""
Main ingestion script for loading Nuinamath dataset into Qdrant.
"""
import logging
import os
from datasets import load_dataset
from tqdm import tqdm
import time
from dotenv import load_dotenv

# Load environment variables from a local .env file so credentials stay out of source.
load_dotenv()

# Configuration settings (connection details come from the environment).
QDRANT_URL = os.getenv("QDRANT_URL")  # Qdrant server endpoint
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")  # API key for the Qdrant instance
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")  # target collection name
DATASET_NAME = "AI-MO/NuminaMath-CoT"  # Hugging Face dataset identifier
DATASET_SPLIT = "train"
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model name
VECTOR_SIZE = 384  # embedding dimension; presumably matches EMBEDDING_MODEL's output — confirm
DISTANCE_METRIC = "Cosine"
BATCH_SIZE = 100  # items embedded/upserted per batch
MAX_SAMPLES = None  # set to an int to cap the number of ingested samples; None = full split

# Validation: fail fast if required credentials are missing.
if not QDRANT_URL or not QDRANT_API_KEY:
    raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")

# Project-local helpers (NOTE(review): imported after the credential check above —
# looks deliberate so config errors surface before heavier imports; confirm).
from utils import EmbeddingGenerator, batch_process_dataset
from qdrant_manager import QdrantManager

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def main():
    """Run the full ingestion pipeline: load the dataset, create the Qdrant
    collection, then embed and upsert every batch.

    Reads all configuration from module-level constants. Logs and re-raises
    any fatal error; per-batch failures are logged and skipped.
    """
    try:
        # Initialize components
        logger.info("Initializing components...")
        embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
        qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)

        dataset = _load_dataset()

        # Create Qdrant collection
        logger.info("Creating collection: %s", QDRANT_COLLECTION)
        success = qdrant_manager.create_collection(
            collection_name=QDRANT_COLLECTION,
            vector_size=VECTOR_SIZE,
            distance=DISTANCE_METRIC,
        )
        if not success:
            logger.error("Failed to create collection")
            return

        total_processed = _process_batches(dataset, embedding_generator, qdrant_manager)

        # Final summary
        logger.info("Ingestion completed!")
        logger.info("Total items processed: %d", total_processed)

        # Get collection info (may be falsy if the lookup failed)
        collection_info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
        if collection_info:
            logger.info("Collection status: %s", collection_info.status)
            logger.info("Vectors count: %s", collection_info.vectors_count)
    except Exception:
        # logger.exception records the full traceback, not just str(e)
        logger.exception("Fatal error in ingestion pipeline")
        raise


def _load_dataset():
    """Load the configured dataset split, honoring the MAX_SAMPLES cap."""
    logger.info("Loading dataset: %s", DATASET_NAME)
    if MAX_SAMPLES:
        # Hugging Face split-slicing syntax limits how much is downloaded/loaded.
        dataset = load_dataset(DATASET_NAME, split=f"{DATASET_SPLIT}[:{MAX_SAMPLES}]")
        logger.info("Loaded %d samples (limited)", len(dataset))
    else:
        dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
        logger.info("Loaded full dataset: %d samples", len(dataset))
    return dataset


def _process_batches(dataset, embedding_generator, qdrant_manager):
    """Embed and upsert `dataset` batch by batch; return the count processed.

    A failure in one batch is logged (with traceback) and skipped so the
    remaining batches still run.
    """
    logger.info("Processing dataset in batches...")
    batches = batch_process_dataset(dataset, BATCH_SIZE)
    total_processed = 0
    total_batches = len(batches)
    dataset_size = len(dataset)
    for batch_idx, batch_data in enumerate(tqdm(batches, desc="Processing batches")):
        try:
            # Extract the raw texts to embed from each record.
            texts = [item['text'] for item in batch_data]

            logger.info("Generating embeddings for batch %d/%d", batch_idx + 1, total_batches)
            embeddings = embedding_generator.embed_text(texts)

            logger.info("Uploading batch %d to Qdrant...", batch_idx + 1)
            qdrant_manager.upsert_points(
                collection_name=QDRANT_COLLECTION,
                points_data=batch_data,
                embeddings=embeddings,
            )

            total_processed += len(batch_data)
            logger.info("Progress: %d/%d items processed", total_processed, dataset_size)

            # Small delay to avoid overwhelming the API
            time.sleep(0.5)
        except Exception:
            logger.exception("Error processing batch %d", batch_idx + 1)
            continue
    return total_processed
# Entry point: run the ingestion pipeline only when executed as a script.
if __name__ == "__main__":
    main()