| """ | |
| Load HuggingFace datasets and ingest into Qdrant | |
| No local file dependencies - uses only HF datasets | |
| """ | |
import os
import sys
import hashlib
import json
import uuid  # used to build stable, collision-free point ids
from pathlib import Path
from typing import Dict

# Add src directory to path
sys.path.insert(0, str(Path(__file__).parent))

from qdrant_setup import QdrantSetup
from embedding_generator import EmbeddingGenerator
from datasets import load_dataset
from qdrant_client.http import models

# Hash file to track ingested documents
HASH_FILE = "hf_datasets_hashes.json"


def get_dataset_hashes() -> Dict[str, str]:
    """Load existing dataset hashes"""
    if os.path.exists(HASH_FILE):
        with open(HASH_FILE, "r") as f:
            return json.load(f)
    return {}


def save_dataset_hashes(hashes: Dict[str, str]):
    """Save dataset hashes"""
    with open(HASH_FILE, "w") as f:
        json.dump(hashes, f, indent=2)
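
# For reference, hf_datasets_hashes.json ends up shaped like this
# (keys are dataset specs, values are compute_dataset_hash digests):
#   {
#     "miracl/miracl-corpus:hi:train": "<md5 hex digest>"
#   }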

def compute_dataset_hash(dataset_name: str, config: str, split: str, count: int) -> str:
    """Compute a hash for the dataset to detect changes"""
    # Hash the dataset identity together with its document count so a
    # re-run can notice when the upstream dataset has grown or shrunk
    info = f"{dataset_name}:{config}:{split}:{count}"
    return hashlib.md5(info.encode()).hexdigest()

def parse_dataset_spec(spec: str) -> tuple:
    """Parse dataset specification: name:config:split"""
    parts = spec.strip().split(":")
    if len(parts) == 3:
        return parts[0], parts[1], parts[2]
    elif len(parts) == 2:
        return parts[0], parts[1], "train"
    else:
        return parts[0], None, "train"

def load_and_ingest_dataset(qdrant_client, collection_name: str, embedding_func,
                            dataset_name: str, config: str, split: str):
    """Load a dataset from HuggingFace and ingest into Qdrant"""
    print(f"Loading dataset: {dataset_name} (config={config}, split={split})")
    try:
        # Load dataset
        if config:
            dataset = load_dataset(dataset_name, config, split=split)
        else:
            dataset = load_dataset(dataset_name, split=split)
        print(f"  Loaded {len(dataset)} documents")

        # Prepare documents for ingestion
        texts_to_ingest = []
        metadatas_to_ingest = []
        for item in dataset:
            # Extract text - handle different dataset formats
            text = None
            if "text" in item:
                text = item["text"]
            elif "content" in item:
                text = item["content"]
            elif "passage" in item:
                text = item["passage"]
            elif "document" in item:
                text = item["document"]

            if text and isinstance(text, str) and text.strip():
                texts_to_ingest.append(text)
                # Extract metadata
                metadata = {
                    "title": item.get("title", "") or "",
                    "author": item.get("author", "") or "",
                    "genre": item.get("genre", "") or "",
                    "source": f"{dataset_name}:{config}:{split}",
                }
                # Add language info if available
                if "language" in item:
                    metadata["language"] = item["language"]
                metadatas_to_ingest.append(metadata)

        if not texts_to_ingest:
            print("  No valid texts found in dataset")
            return 0
        print(f"  Found {len(texts_to_ingest)} valid texts to ingest")

        # Ingest in batches
        batch_size = 100
        ingested_count = 0
        for i in range(0, len(texts_to_ingest), batch_size):
            batch_texts = texts_to_ingest[i:i + batch_size]
            batch_metadatas = metadatas_to_ingest[i:i + batch_size]

            # Generate embeddings
            embeddings = [embedding_func(text) for text in batch_texts]

            # Create points; a deterministic UUID per (source, index) keeps ids
            # unique across datasets (plain sequential ints would restart at 0
            # for every dataset and silently overwrite earlier points)
            points = []
            for j, (text, metadata, embedding) in enumerate(zip(batch_texts, batch_metadatas, embeddings)):
                point = models.PointStruct(
                    id=str(uuid.uuid5(uuid.NAMESPACE_URL, f"{metadata['source']}:{i + j}")),
                    vector=embedding,
                    payload={
                        "full_text": text,
                        **metadata
                    }
                )
                points.append(point)

            # Upload to Qdrant
            qdrant_client.upsert(
                collection_name=collection_name,
                points=points
            )
            ingested_count += len(batch_texts)
            print(f"  Ingested {ingested_count}/{len(texts_to_ingest)} documents")

        print(f"  ✓ Successfully ingested {ingested_count} documents")
        return ingested_count
    except Exception as e:
        print(f"  Error loading dataset: {e}")
        return 0
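
# Each stored point carries a payload of roughly this shape ("language" is
# present only when the source dataset exposes it):
#   {"full_text": "<document text>", "title": "", "author": "", "genre": "",
#    "source": "miracl/miracl-corpus:hi:train"}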

def main():
    """Main function to load and ingest all configured datasets"""
    print("=" * 60)
    print("HuggingFace Dataset Loader for Simple RAG")
    print("=" * 60)

    # Get datasets from environment
    hf_datasets = os.getenv("HF_DATASETS", "")
    if not hf_datasets:
        print("No HF_DATASETS environment variable set.")
        print("Set HF_DATASETS to load datasets (e.g., miracl/miracl-corpus:hi:train)")
        return

    # Initialize Qdrant
    print("\nInitializing Qdrant...")
    qdrant_setup = QdrantSetup()
    qdrant_client = qdrant_setup.get_client()
    collection_name = qdrant_setup.get_collection_name()
    # Create collection if it does not exist
    qdrant_setup.create_collection()

    # Initialize embedding generator
    print("Initializing embedding generator...")
    embedding_func = EmbeddingGenerator().get_embedding

    # Load existing hashes
    dataset_hashes = get_dataset_hashes()

    # Parse and process datasets
    dataset_specs = [spec.strip() for spec in hf_datasets.split(",")]
    total_ingested = 0
    for spec in dataset_specs:
        if not spec:
            continue
        dataset_name, config, split = parse_dataset_spec(spec)
        dataset_key = f"{dataset_name}:{config}:{split}"

        # Check if already ingested
        if dataset_key in dataset_hashes:
            print(f"\nSkipping {dataset_key} (already ingested)")
            continue

        print(f"\nProcessing: {dataset_key}")
        # Load and ingest
        count = load_and_ingest_dataset(
            qdrant_client, collection_name, embedding_func,
            dataset_name, config, split
        )
        if count > 0:
            # Record the dataset so subsequent runs skip it
            dataset_hashes[dataset_key] = compute_dataset_hash(
                dataset_name, config, split, count
            )
            save_dataset_hashes(dataset_hashes)
            total_ingested += count

    print("\n" + "=" * 60)
    print(f"Total documents ingested: {total_ingested}")
    print("=" * 60)


if __name__ == "__main__":
    main()
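
# Example invocation - a sketch that assumes QdrantSetup can reach a running
# Qdrant instance via your environment's settings, and that this file is
# saved as load_hf_datasets.py (the real filename may differ):
#
#   export HF_DATASETS="miracl/miracl-corpus:hi:train"
#   python load_hf_datasets.py
#
# Multiple datasets are comma-separated, e.g.:
#   export HF_DATASETS="miracl/miracl-corpus:hi:train,miracl/miracl-corpus:en:train"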