# indicRAG/backend/src/load_huggingface_dataset.py
"""
Load HuggingFace datasets and ingest into Qdrant
No local file dependencies - uses only HF datasets
"""
import os
import sys
import hashlib
import json
import uuid
from pathlib import Path
from typing import Dict
# Add src directory to path
sys.path.insert(0, str(Path(__file__).parent))
from qdrant_setup import QdrantSetup
from embedding_generator import EmbeddingGenerator
from datasets import load_dataset
# Hash file to track ingested documents
HASH_FILE = "hf_datasets_hashes.json"
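# The hash file is written to the current working directory; if it is deleted
# (or the working directory is not persistent), already-ingested datasets will
# simply be ingested again on the next run.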
def get_dataset_hashes() -> Dict[str, str]:
    """Load existing dataset hashes"""
    if os.path.exists(HASH_FILE):
        with open(HASH_FILE, "r") as f:
            return json.load(f)
    return {}
def save_dataset_hashes(hashes: Dict[str, str]):
    """Save dataset hashes"""
    with open(HASH_FILE, "w") as f:
        json.dump(hashes, f, indent=2)
def compute_dataset_hash(dataset_name: str, config: str, split: str, num_docs: int) -> str:
    """Compute a hash for the dataset to detect changes"""
    # Hash the dataset identity plus the number of documents ingested from it
    info = f"{dataset_name}:{config}:{split}:{num_docs}"
    return hashlib.md5(info.encode()).hexdigest()
def parse_dataset_spec(spec: str) -> tuple:
    """Parse dataset specification: name:config:split"""
    parts = spec.strip().split(":")
    if len(parts) == 3:
        return parts[0], parts[1], parts[2]
    elif len(parts) == 2:
        return parts[0], parts[1], "train"
    else:
        return parts[0], None, "train"
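# Illustrative spec -> (name, config, split) mappings (the first mirrors the
# example printed in main(); the others use placeholder names):
#   "miracl/miracl-corpus:hi:train" -> ("miracl/miracl-corpus", "hi", "train")
#   "some-dataset:hi"               -> ("some-dataset", "hi", "train")
#   "some-dataset"                  -> ("some-dataset", None, "train")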
def load_and_ingest_dataset(qdrant_client, collection_name: str, embedding_func,
                            dataset_name: str, config: str, split: str):
    """Load a dataset from HuggingFace and ingest into Qdrant"""
    print(f"Loading dataset: {dataset_name} (config={config}, split={split})")
    try:
        # Load dataset
        if config:
            dataset = load_dataset(dataset_name, config, split=split)
        else:
            dataset = load_dataset(dataset_name, split=split)
        print(f" Loaded {len(dataset)} documents")
        # Prepare documents for ingestion
        texts_to_ingest = []
        metadatas_to_ingest = []
        for item in dataset:
            # Extract text - handle different dataset formats
            text = None
            if "text" in item:
                text = item["text"]
            elif "content" in item:
                text = item["content"]
            elif "passage" in item:
                text = item["passage"]
            elif "document" in item:
                text = item["document"]
            if text and isinstance(text, str) and text.strip():
                texts_to_ingest.append(text)
                # Extract metadata
                metadata = {
                    "title": item.get("title", "") or "",
                    "author": item.get("author", "") or "",
                    "genre": item.get("genre", "") or "",
                    "source": f"{dataset_name}:{config}:{split}",
                }
                # Add language info if available
                if "language" in item:
                    metadata["language"] = item["language"]
                metadatas_to_ingest.append(metadata)
        if not texts_to_ingest:
            print(" No valid texts found in dataset")
            return 0
        print(f" Found {len(texts_to_ingest)} valid texts to ingest")
        # Ingest in batches
        batch_size = 100
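        # Fixed-size batches keep each embedding pass and upsert request small;
        # 100 points per request is an arbitrary default, not a tuned value.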
        ingested_count = 0
        from qdrant_client.http import models
        for i in range(0, len(texts_to_ingest), batch_size):
            batch_texts = texts_to_ingest[i:i + batch_size]
            batch_metadatas = metadatas_to_ingest[i:i + batch_size]
            # Generate embeddings
            embeddings = []
            for text in batch_texts:
                embedding = embedding_func(text)
                embeddings.append(embedding)
            # Create points
            points = []
            for j, (text, metadata, embedding) in enumerate(zip(batch_texts, batch_metadatas, embeddings)):
                # Deterministic UUID per document: sequential integer IDs restart
                # at 0 for every dataset, so loading a second dataset into the
                # same collection would silently overwrite earlier points.
                point_id = str(uuid.uuid5(
                    uuid.NAMESPACE_URL,
                    f"{dataset_name}:{config}:{split}:{i + j}"
                ))
                point = models.PointStruct(
                    id=point_id,
                    vector=embedding,
                    payload={
                        "full_text": text,
                        **metadata
                    }
                )
                points.append(point)
            # Upload to Qdrant
            qdrant_client.upsert(
                collection_name=collection_name,
                points=points
            )
            ingested_count += len(batch_texts)
            print(f" Ingested {ingested_count}/{len(texts_to_ingest)} documents")
        print(f" ✓ Successfully ingested {ingested_count} documents")
        return ingested_count
    except Exception as e:
        print(f" Error loading dataset: {e}")
        return 0
def main():
    """Main function to load and ingest all configured datasets"""
    print("=" * 60)
    print("HuggingFace Dataset Loader for Simple RAG")
    print("=" * 60)
    # Get datasets from environment
    hf_datasets = os.getenv("HF_DATASETS", "")
    if not hf_datasets:
        print("No HF_DATASETS environment variable set.")
        print("Set HF_DATASETS to load datasets (e.g., miracl/miracl-corpus:hi:train)")
        return
    # Initialize Qdrant
    print("\nInitializing Qdrant...")
    qdrant_setup = QdrantSetup()
    qdrant_client = qdrant_setup.get_client()
    collection_name = qdrant_setup.get_collection_name()
    # Create collection if not exists
    qdrant_setup.create_collection()
    # Initialize embedding generator
    print("Initializing embedding generator...")
    embedding_func = EmbeddingGenerator().get_embedding
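    # Assumption: EmbeddingGenerator.get_embedding(text) returns a flat vector
    # (list of floats) whose dimension matches the collection created by QdrantSetup.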
    # Load existing hashes
    dataset_hashes = get_dataset_hashes()
    # Parse and process datasets
    dataset_specs = [spec.strip() for spec in hf_datasets.split(",")]
    total_ingested = 0
    for spec in dataset_specs:
        if not spec:
            continue
        dataset_name, config, split = parse_dataset_spec(spec)
        dataset_key = f"{dataset_name}:{config}:{split}"
        # Check if already ingested
        if dataset_key in dataset_hashes:
            print(f"\nSkipping {dataset_key} (already ingested)")
            continue
        print(f"\nProcessing: {dataset_key}")
        # Load and ingest
        count = load_and_ingest_dataset(
            qdrant_client, collection_name, embedding_func,
            dataset_name, config, split
        )
        if count > 0:
            # Record the hash so this dataset is skipped on future runs
            dataset_hashes[dataset_key] = compute_dataset_hash(
                dataset_name, config, split, count
            )
            save_dataset_hashes(dataset_hashes)
            total_ingested += count
    print("\n" + "=" * 60)
    print(f"Total documents ingested: {total_ingested}")
    print("=" * 60)
if __name__ == "__main__":
    main()