# RAG-Pipeline-Optimizer / utils/dataset_loader.py
# (HuggingFace upload by puji4ml, commit 2b22a59)
"""
utils/dataset_loader.py - Load Natural Questions + Wikipedia Dataset
======================================================================
Clean implementation for loading:
- Natural Questions: Q&A dataset with answers extracted from Wikipedia
- Wikipedia: Standard Wikipedia dataset chunked into ~100 word passages
"""
from typing import List, Dict, Optional
from dataclasses import dataclass
from pathlib import Path
import json
from datasets import load_dataset
from tqdm import tqdm
@dataclass
class Question:
    """A single Q&A record extracted from Natural Questions.

    Attributes:
        id: Unique question identifier (dataset id, or a generated ``nq_<idx>`` fallback).
        question: Natural-language question text.
        answers: Accepted short-answer strings (deduplicated, whitespace-stripped).
        context: Supporting passage text (first ~100 document tokens); may be empty.
        has_answer: True when at least one answer was extracted for this question.
    """
    id: str
    question: str
    answers: List[str]
    context: str = ""
    has_answer: bool = True
# =============================================================================
# NATURAL QUESTIONS DATASET
# =============================================================================
class NaturalQuestionsDataset:
    """Load the Natural Questions validation split and extract Q&A pairs.

    Downloads via HuggingFace ``datasets``, then walks the items to build
    :class:`Question` objects, keeping only questions with at least one
    short answer. The helpers are deliberately defensive about the data
    layout (list-vs-dict annotations, direct text vs token offsets) because
    the on-disk schema varies between dataset versions.
    """

    def __init__(self, cache_dir: str = "./data/datasets"):
        """Create the loader; ``cache_dir`` is where HuggingFace caches the download."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.dataset = None   # raw HuggingFace dataset, populated by load()
        self.questions = []   # extracted Question objects, populated by load()

    def load(self, max_samples: int = 100, show_progress: bool = True):
        """Load the Natural Questions validation set and extract up to ``max_samples`` questions."""
        if show_progress:
            print(f"\nπŸ“₯ Loading Natural Questions...")
            print("=" * 80)
        # Load from HuggingFace (cached on disk after the first call)
        self.dataset = load_dataset(
            "google-research-datasets/natural_questions",
            split="validation",
            cache_dir=str(self.cache_dir),
        )
        if show_progress:
            print(f"βœ… Loaded {len(self.dataset)} questions from dataset")
        # Extract questions into self.questions
        self._extract_questions(max_samples, show_progress)

    def _extract_questions(self, max_samples: int, show_progress: bool):
        """Walk the first ``max_samples`` items and build Question objects.

        Items without a question text or without any extractable short
        answer are skipped, so fewer than ``max_samples`` may be kept.
        """
        if show_progress:
            print(f"\nπŸ”„ Extracting questions...")
        # Debug first item so schema problems are visible before the full pass
        if show_progress:
            print(f"\nπŸ” Inspecting first item structure:")
            self._debug_first_item()
        num_to_process = min(max_samples, len(self.dataset))
        questions = []
        for idx in tqdm(range(num_to_process), desc="Processing", disable=not show_progress):
            item = self.dataset[idx]
            # Get question text; skip items with none
            question_text = self._extract_question(item)
            if not question_text:
                continue
            # Get answers
            answers = self._extract_answers(item)
            if not answers:  # Skip questions without answers
                continue
            # Get context (first ~100 document tokens)
            context = self._extract_context(item)
            questions.append(Question(
                id=str(item.get('id', f'nq_{idx}')),
                question=question_text,
                answers=answers,
                context=context,
                has_answer=True
            ))
        self.questions = questions
        if show_progress:
            print(f"\nβœ… Successfully extracted {len(questions)} questions with answers")
            if len(questions) == 0:
                print(f"❌ ERROR: Extracted 0 questions from {num_to_process} items")
                print(f" This means the answer extraction is failing.")
                print(f" Check the debug output above to see the data structure.")

    def _debug_first_item(self):
        """Print the structure of dataset item 0 (question/annotations/document)
        and dry-run the extraction helpers on it. Diagnostic output only; no state changes."""
        if len(self.dataset) == 0:
            print(" ❌ Dataset is empty!")
            return
        item = self.dataset[0]
        # Question field: may be a dict with 'text' or a plain string
        print(f"\n πŸ“ Question:")
        if 'question' in item:
            q = item['question']
            print(f" Type: {type(q)}")
            if isinstance(q, dict):
                print(f" Keys: {list(q.keys())}")
                print(f" Text: {q.get('text', 'N/A')[:100]}")
            else:
                print(f" Value: {str(q)[:100]}")
        # Annotations field: may be a list of dicts or a single dict
        print(f"\n πŸ“‹ Annotations:")
        if 'annotations' in item:
            anns = item['annotations']
            print(f" Type: {type(anns)}")
            if isinstance(anns, list):
                print(f" Count: {len(anns)}")
                if len(anns) > 0:
                    ann = anns[0]
                    print(f" First annotation keys: {list(ann.keys())}")
                    if 'short_answers' in ann:
                        short_ans = ann['short_answers']
                        print(f" Short answers count: {len(short_ans) if isinstance(short_ans, list) else 'Not a list'}")
                        if isinstance(short_ans, list) and len(short_ans) > 0:
                            print(f" First short answer: {short_ans[0]}")
            elif isinstance(anns, dict):
                print(f" It's a dict with keys: {list(anns.keys())}")
                # If it's a dict, check if it has the fields we need
                if 'short_answers' in anns:
                    short_ans = anns['short_answers']
                    print(f" Short answers type: {type(short_ans)}")
                    print(f" Short answers: {short_ans}")
            else:
                print(f" Unexpected type: {type(anns)}")
                print(f" Value: {anns}")
        else:
            print(f" ❌ No 'annotations' field found")
        # Document tokens: dict with a 'token' list, or a bare list
        print(f"\n πŸ“„ Document:")
        if 'document' in item:
            doc = item['document']
            print(f" Type: {type(doc)}")
            if isinstance(doc, dict):
                print(f" Keys: {list(doc.keys())}")
                if 'tokens' in doc:
                    tokens = doc['tokens']
                    print(f" Tokens type: {type(tokens)}")
                    if isinstance(tokens, dict):
                        print(f" Tokens keys: {list(tokens.keys())}")
                        if 'token' in tokens:
                            token_list = tokens['token']
                            print(f" Token list length: {len(token_list) if isinstance(token_list, list) else 'Not a list'}")
                            if isinstance(token_list, list) and len(token_list) > 0:
                                print(f" First 10 tokens: {token_list[:10]}")
                    elif isinstance(tokens, list):
                        print(f" Tokens list length: {len(tokens)}")
                        print(f" First 10 tokens: {tokens[:10]}")
        # Dry-run the real extraction helpers on item 0
        print(f"\n πŸ§ͺ Testing extraction methods:")
        question = self._extract_question(item)
        print(f" Question extracted: '{question[:50]}...' " if question else " ❌ Failed to extract question")
        answers = self._extract_answers(item)
        print(f" Answers extracted: {answers}" if answers else " ❌ Failed to extract answers")
        if not answers and 'annotations' in item:
            # Extraction failed: dump the annotation internals to diagnose why
            print(f"\n πŸ” Deep dive into annotations:")
            anns = item['annotations']
            # Convert to list format for uniform processing
            ann_list = [anns] if isinstance(anns, dict) else (anns if isinstance(anns, list) else [])
            for i, ann in enumerate(ann_list[:2]):  # Check first 2 annotations
                print(f"\n Annotation {i}:")
                if isinstance(ann, dict):
                    print(f" Keys: {list(ann.keys())}")
                    if 'short_answers' in ann:
                        sa = ann['short_answers']
                        print(f" Short answers type: {type(sa)}")
                        if isinstance(sa, list):
                            print(f" Short answers count: {len(sa)}")
                            if len(sa) > 0:
                                first_sa = sa[0]
                                print(f" First short answer: {first_sa}")
                                if isinstance(first_sa, dict):
                                    start = first_sa.get('start_token')
                                    end = first_sa.get('end_token')
                                    print(f" Start: {start}, End: {end}")
                                    if start is not None and end is not None:
                                        reconstructed = self._get_text_from_tokens(item, start, end)
                                        print(f" Reconstructed: '{reconstructed}'")

    def _extract_question(self, item: dict) -> str:
        """Return the question text from ``item``, or '' when absent/unrecognized."""
        if 'question' not in item:
            return ""
        q = item['question']
        if isinstance(q, dict):
            # Dict form: {'text': ...}
            return q.get('text', '')
        elif isinstance(q, str):
            return q
        return ""

    def _extract_answers(self, item: dict) -> List[str]:
        """Return the deduplicated short-answer strings from ``item``'s annotations.

        Tries two strategies per answer span: a direct 'text' field, then
        reconstruction from document tokens via start_token/end_token.
        Returns an empty list when nothing can be extracted.
        NOTE: deduplication via set() loses answer order — acceptable here
        since callers only test membership/length.
        """
        answers = []
        if 'annotations' not in item:
            return answers
        annotations = item['annotations']
        # Handle both list and dict formats
        annotation_list = []
        if isinstance(annotations, list):
            annotation_list = annotations
        elif isinstance(annotations, dict):
            # If it's a single dict, treat it as a list of one
            annotation_list = [annotations]
        else:
            return answers
        for ann in annotation_list:
            if not isinstance(ann, dict):
                continue
            # Get short answers
            short_answers = ann.get('short_answers', [])
            if not isinstance(short_answers, list) or len(short_answers) == 0:
                continue
            # Extract each answer
            for ans_span in short_answers:
                if not isinstance(ans_span, dict):
                    continue
                # Method 1: Try to get text directly (if available)
                text_field = ans_span.get('text', [])
                if isinstance(text_field, list) and len(text_field) > 0:
                    # Text is provided directly
                    for text in text_field:
                        if text and str(text).strip():
                            answers.append(str(text).strip())
                    continue
                # Method 2: Reconstruct from tokens
                start = ans_span.get('start_token')
                end = ans_span.get('end_token')
                # Handle list format for start/end tokens
                if isinstance(start, list):
                    start = start[0] if len(start) > 0 else None
                if isinstance(end, list):
                    end = end[0] if len(end) > 0 else None
                if start is None or end is None:
                    continue
                # Reconstruct from tokens
                answer_text = self._get_text_from_tokens(item, start, end)
                if answer_text:
                    answers.append(answer_text)
        # Remove duplicates
        return list(set(a.strip() for a in answers if a and a.strip()))

    def _get_text_from_tokens(self, item: dict, start_token: int, end_token: int) -> str:
        """Join document tokens in [start_token, end_token] into a string; '' on any failure.

        NOTE(review): end_token is treated as INCLUSIVE here (slice end+1);
        the official NQ format documents exclusive end offsets — confirm
        against the loaded data, answers may carry one extra trailing token.
        """
        if 'document' not in item:
            return ""
        doc = item['document']
        if not isinstance(doc, dict) or 'tokens' not in doc:
            return ""
        tokens = doc['tokens']
        # Get token list (columnar dict form or plain list form)
        token_list = None
        if isinstance(tokens, dict) and 'token' in tokens:
            token_list = tokens['token']
        elif isinstance(tokens, list):
            token_list = tokens
        if not token_list or not isinstance(token_list, list):
            return ""
        # Check bounds
        if start_token < 0 or end_token >= len(token_list) or start_token > end_token:
            return ""
        # Extract and join tokens
        answer_tokens = token_list[start_token:end_token+1]
        return " ".join(str(t) for t in answer_tokens).strip()

    def _extract_context(self, item: dict) -> str:
        """Return the first 100 document tokens joined as context text; '' on failure.

        NOTE(review): duplicates the token-list resolution in
        _get_text_from_tokens — candidate for a shared helper.
        """
        if 'document' not in item:
            return ""
        doc = item['document']
        if not isinstance(doc, dict) or 'tokens' not in doc:
            return ""
        tokens = doc['tokens']
        # Get token list (columnar dict form or plain list form)
        token_list = None
        if isinstance(tokens, dict) and 'token' in tokens:
            token_list = tokens['token']
        elif isinstance(tokens, list):
            token_list = tokens
        if not token_list or not isinstance(token_list, list):
            return ""
        # Take first 100 tokens as context
        context_tokens = token_list[:100]
        return " ".join(str(t) for t in context_tokens)

    def get_questions(self) -> List[Question]:
        """Return the questions extracted by the last load() (empty before load)."""
        return self.questions

    def save_json(self, filepath: str):
        """Write the extracted questions to ``filepath`` as a JSON array.

        Creates parent directories as needed; no-op (with a warning) when
        no questions are loaded. ``has_answer`` is intentionally omitted.
        """
        if not self.questions:
            print(f"⚠️ No questions to save")
            return
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        data = [
            {
                'id': q.id,
                'question': q.question,
                'answers': q.answers,
                'context': q.context,
            }
            for q in self.questions
        ]
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        print(f"βœ… Saved {len(data)} questions to {filepath}")

    def get_stats(self) -> Dict:
        """Return summary statistics: total, with_answers, avg_question_len (words), avg_answers."""
        if not self.questions:
            # Zeroed stats so callers can index keys without checking emptiness
            return {
                'total': 0,
                'with_answers': 0,
                'avg_question_len': 0,
                'avg_answers': 0
            }
        return {
            'total': len(self.questions),
            'with_answers': sum(1 for q in self.questions if q.answers),
            'avg_question_len': sum(len(q.question.split()) for q in self.questions) / len(self.questions),
            'avg_answers': sum(len(q.answers) for q in self.questions) / len(self.questions)
        }
# =============================================================================
# WIKIPEDIA CORPUS (Full Dataset from HuggingFace)
# =============================================================================
@dataclass
class WikiPassage:
    """A single ~100-word Wikipedia passage chunk.

    Attributes:
        id: Chunk identifier, formatted ``<article_id>_chunk_<n>``.
        title: Title of the source article.
        text: Passage text (space-joined words from the article).
        url: Source article URL; may be empty.
    """
    id: str
    title: str
    text: str
    url: str = ""
class WikiDPRCorpus:
    """Load full Wikipedia corpus from HuggingFace wikimedia/wikipedia.

    Articles are split into ~100-word WikiPassage chunks (wiki_dpr format).
    """

    def __init__(self, cache_dir: str = "./data/datasets"):
        """Create the loader; ``cache_dir`` is where HuggingFace caches the download."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.corpus = None    # raw HuggingFace dataset, set by load()
        self.passages = []    # WikiPassage chunks, set by load()

    def load(
        self,
        language: str = "20231101.simple",  # or "20231101.en" for full English
        max_passages: int = 10000,
        show_progress: bool = True
    ):
        """
        Load full Wikipedia corpus from HuggingFace

        Args:
            language: Wikipedia dump date.language (e.g., "20231101.simple", "20231101.en")
            max_passages: Maximum number of passages to extract
            show_progress: Show progress
        """
        if show_progress:
            print(f"\nπŸ“₯ Loading Wikipedia corpus ({language})...")
            print("=" * 80)
            print(f"⚠️ Note: Full English Wikipedia is ~20GB, Simple is ~200MB")
        # Download (or reuse the cached copy of) the requested dump
        self.corpus = load_dataset(
            "wikimedia/wikipedia",
            language,
            split="train",
            cache_dir=str(self.cache_dir),
        )
        if show_progress:
            print(f"βœ… Loaded {len(self.corpus)} Wikipedia articles")
        self._extract_passages(max_passages, show_progress)

    def _extract_passages(self, max_passages: int, show_progress: bool):
        """
        Extract passages by chunking Wikipedia articles into ~100 word chunks
        (matching wiki_dpr format)
        """
        if show_progress:
            print(f"\nπŸ”„ Chunking articles into ~100 word passages...")
        chunk_size = 100
        collected = []
        consumed = 0  # number of articles read from the corpus so far
        bar = tqdm(
            total=max_passages,
            desc="Extracting passages",
            disable=not show_progress
        )
        # Keep reading articles until the quota is filled or the corpus ends
        while consumed < len(self.corpus) and len(collected) < max_passages:
            record = self.corpus[consumed]
            consumed += 1
            article_id = record.get('id', f'wiki_{consumed}')
            title = record.get('title', '')
            body = record.get('text', '')
            url = record.get('url', '')
            # Ignore empty or trivially short articles
            if not body or len(body.strip()) < 100:
                continue
            words = body.split()
            # One chunk per chunk_size-word window over the article
            for chunk_no, offset in enumerate(range(0, len(words), chunk_size)):
                if len(collected) >= max_passages:
                    break
                piece = words[offset:offset + chunk_size]
                # Drop tail fragments too short to be useful passages
                if len(piece) < 20:
                    continue
                collected.append(WikiPassage(
                    id=f'{article_id}_chunk_{chunk_no}',
                    title=title,
                    text=' '.join(piece),
                    url=url
                ))
                bar.update(1)
        bar.close()
        self.passages = collected
        if show_progress:
            print(f"βœ… Extracted {len(collected)} passages from {consumed} articles")
            print(f" Each passage is ~100 words (wiki_dpr format)")

    def get_passages(self) -> List[WikiPassage]:
        """Return the passages collected by the last load() (empty before load)."""
        return self.passages

    def save_json(self, filepath: str):
        """Write the collected passages to ``filepath`` as a JSON array;
        creates parent directories, warns and returns when nothing is loaded."""
        if not self.passages:
            print(f"⚠️ No passages to save")
            return
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        records = [
            {'id': p.id, 'title': p.title, 'text': p.text, 'url': p.url}
            for p in self.passages
        ]
        with open(filepath, 'w', encoding='utf-8') as handle:
            json.dump(records, handle, indent=2, ensure_ascii=False)
        print(f"βœ… Saved {len(records)} passages to {filepath}")

    def get_stats(self) -> Dict:
        """Return {} before load, else total_passages / avg_passage_length (words) / unique_titles."""
        if not self.passages:
            return {}
        count = len(self.passages)
        word_total = sum(len(p.text.split()) for p in self.passages)
        titles = {p.title for p in self.passages}
        return {
            "total_passages": count,
            "avg_passage_length": word_total / count,
            "unique_titles": len(titles),
        }
if __name__ == "__main__":
    print("πŸ“š Dataset Loader Test")
    print("=" * 80)

    # Test 1: Natural Questions
    print("\n[Test 1] Natural Questions")
    nq = NaturalQuestionsDataset()
    nq.load(max_samples=50, show_progress=True)
    questions = nq.get_questions()
    if len(questions) > 0:
        print(f"\nπŸ“‹ Sample Questions (first 3):")
        for i, q in enumerate(questions[:3], 1):
            print(f"\n{i}. Q: {q.question}")
            print(f" A: {q.answers}")
    stats = nq.get_stats()
    print(f"\nπŸ“Š Stats: {stats['total']} questions")
    nq.save_json("./data/datasets/nq_sample.json")

    # Test 2: Wikipedia Corpus
    print("\n" + "=" * 80)
    print("\n[Test 2] Wikipedia Corpus")
    wiki = WikiDPRCorpus()
    # Use Simple English (~200MB) for the smoke test; switch to
    # "20231101.en" (~20GB download) only for a full-scale run.
    wiki.load(
        language="20231101.simple",
        max_passages=10000,
        show_progress=True
    )
    passages = wiki.get_passages()
    print(f"\nπŸ“‹ Sample Passages (first 3):")
    for i, p in enumerate(passages[:3], 1):
        print(f"\n{i}. Title: {p.title}")
        print(f" Text: {p.text[:100]}...")
        print(f" URL: {p.url}")
    stats = wiki.get_stats()
    print(f"\nπŸ“Š Stats:")
    for key, value in stats.items():
        print(f" {key}: {value}")
    # Filename now matches max_passages (was wiki_passages_1000.json for 10000)
    wiki.save_json("./data/datasets/wiki_passages_10000.json")
    print("\n" + "=" * 80)
    print("βœ… Test complete!")