| """ | |
| Main ingestion script for loading Nuinamath dataset into Qdrant. | |
| """ | |
import logging
import os
import time

from datasets import load_dataset
from dotenv import load_dotenv
from tqdm import tqdm

# Load environment variables
load_dotenv()

# Configuration settings
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
DATASET_NAME = "AI-MO/NuminaMath-CoT"
DATASET_SPLIT = "train"
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
VECTOR_SIZE = 384  # all-MiniLM-L6-v2 produces 384-dimensional embeddings
DISTANCE_METRIC = "Cosine"
BATCH_SIZE = 100
MAX_SAMPLES = None  # set to an int to ingest only the first N samples

# Validation
if not QDRANT_URL or not QDRANT_API_KEY:
    raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")

from utils import EmbeddingGenerator, batch_process_dataset
from qdrant_manager import QdrantManager
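
# Assumed interfaces for the local helper modules above (defined in utils.py and
# qdrant_manager.py; inferred from how they are used below, not verified here):
#   EmbeddingGenerator(model_name).embed_text(texts) -> list of embedding vectors
#   batch_process_dataset(dataset, batch_size) -> list of batches, where each
#       batch item is a dict containing at least a 'text' field
#   QdrantManager(url, api_key) with create_collection(), upsert_points(), and
#       get_collection_info() methods wrapping the qdrant-client API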

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def main():
    """Main ingestion pipeline."""
    try:
        # Initialize components
        logger.info("Initializing components...")
        embedding_generator = EmbeddingGenerator(EMBEDDING_MODEL)
        qdrant_manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)

        # Load dataset
        logger.info(f"Loading dataset: {DATASET_NAME}")
        if MAX_SAMPLES:
            dataset = load_dataset(DATASET_NAME, split=f"{DATASET_SPLIT}[:{MAX_SAMPLES}]")
            logger.info(f"Loaded {len(dataset)} samples (limited)")
        else:
            dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
            logger.info(f"Loaded full dataset: {len(dataset)} samples")

        # Create Qdrant collection
        logger.info(f"Creating collection: {QDRANT_COLLECTION}")
        success = qdrant_manager.create_collection(
            collection_name=QDRANT_COLLECTION,
            vector_size=VECTOR_SIZE,
            distance=DISTANCE_METRIC,
        )
        if not success:
            logger.error("Failed to create collection")
            return

        # Process dataset in batches
        logger.info("Processing dataset in batches...")
        batches = batch_process_dataset(dataset, BATCH_SIZE)
        total_processed = 0
        total_batches = len(batches)

        for batch_idx, batch_data in enumerate(tqdm(batches, desc="Processing batches")):
            try:
                # Extract texts for embedding
                texts = [item['text'] for item in batch_data]

                # Generate embeddings
                logger.info(f"Generating embeddings for batch {batch_idx + 1}/{total_batches}")
                embeddings = embedding_generator.embed_text(texts)

                # Upsert to Qdrant
                logger.info(f"Uploading batch {batch_idx + 1} to Qdrant...")
                qdrant_manager.upsert_points(
                    collection_name=QDRANT_COLLECTION,
                    points_data=batch_data,
                    embeddings=embeddings,
                )

                total_processed += len(batch_data)
                logger.info(f"Progress: {total_processed}/{len(dataset)} items processed")

                # Small delay to avoid overwhelming the API
                time.sleep(0.5)
            except Exception as e:
                logger.error(f"Error processing batch {batch_idx + 1}: {e}")
                continue

        # Final summary
        logger.info("Ingestion completed!")
        logger.info(f"Total items processed: {total_processed}")

        # Get collection info
        collection_info = qdrant_manager.get_collection_info(QDRANT_COLLECTION)
        if collection_info:
            logger.info(f"Collection status: {collection_info.status}")
            logger.info(f"Vectors count: {collection_info.vectors_count}")

    except Exception as e:
        logger.error(f"Fatal error in ingestion pipeline: {e}")
        raise


if __name__ == "__main__":
    main()