"""
utils/dataset_loader.py - Load Natural Questions + Wikipedia Dataset
======================================================================

Clean implementation for loading:
- Natural Questions: Q&A dataset with answers extracted from Wikipedia
- Wikipedia: Standard Wikipedia dataset chunked into ~100 word passages
"""

from typing import Any, List, Dict, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path
import json

# The heavy third-party dependencies are only needed by the ``load()`` methods.
# Guarding the imports keeps the pure parsing helpers importable (and testable)
# on machines where they are not installed; ``load()`` reports the problem.
try:
    from datasets import load_dataset
except ImportError:  # pragma: no cover - depends on the environment
    load_dataset = None

try:
    from tqdm import tqdm
except ImportError:  # pragma: no cover - depends on the environment

    class tqdm:  # noqa: N801 - deliberately mirrors the real tqdm name
        """Minimal no-op stand-in for ``tqdm`` when it is not installed."""

        def __init__(self, iterable=None, **_kwargs):
            self._iterable = iterable if iterable is not None else ()

        def __iter__(self):
            return iter(self._iterable)

        def update(self, n=1):
            """Accept progress updates and ignore them."""

        def close(self):
            """Nothing to release."""


def _require_load_dataset():
    """Return ``datasets.load_dataset`` or raise a helpful ``ImportError``."""
    if load_dataset is None:
        raise ImportError(
            "The 'datasets' package is required to download data; "
            "install it with `pip install datasets`."
        )
    return load_dataset


@dataclass
class Question:
    """Question with answers and context"""

    id: str
    question: str
    answers: List[str]
    context: str = ""
    has_answer: bool = True


# =============================================================================
# NATURAL QUESTIONS DATASET
# =============================================================================

class NaturalQuestionsDataset:
    """Load Natural Questions dataset"""

    # Number of leading document tokens kept as a question's context.
    CONTEXT_TOKENS = 100

    def __init__(self, cache_dir: str = "./data/datasets"):
        """Create the loader and make sure ``cache_dir`` exists."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.dataset = None
        self.questions = []

    def load(self, max_samples: int = 100, show_progress: bool = True):
        """Load Natural Questions validation set

        Args:
            max_samples: Maximum number of dataset rows to inspect.
            show_progress: Print status messages and show a progress bar.

        Raises:
            ImportError: If the ``datasets`` package is not installed.
        """
        if show_progress:
            print("\n📥 Loading Natural Questions...")
            print("=" * 80)

        # Load from HuggingFace
        self.dataset = _require_load_dataset()(
            "google-research-datasets/natural_questions",
            split="validation",
            cache_dir=str(self.cache_dir),
        )

        if show_progress:
            print(f"✅ Loaded {len(self.dataset)} questions from dataset")

        # Extract questions
        self._extract_questions(max_samples, show_progress)

    def _extract_questions(self, max_samples: int, show_progress: bool):
        """Extract questions from dataset

        Rows without question text or without any extractable answer are
        skipped. The result replaces ``self.questions``.
        """
        if show_progress:
            print("\n🔄 Extracting questions...")
            # Debug first item
            print("\n🔍 Inspecting first item structure:")
            self._debug_first_item()

        num_to_process = min(max_samples, len(self.dataset))
        questions = []

        for idx in tqdm(range(num_to_process), desc="Processing", disable=not show_progress):
            item = self.dataset[idx]

            question_text = self._extract_question(item)
            if not question_text:
                continue

            answers = self._extract_answers(item)
            if not answers:
                # Skip questions without answers
                continue

            questions.append(Question(
                id=str(item.get('id', f'nq_{idx}')),
                question=question_text,
                answers=answers,
                context=self._extract_context(item),
                has_answer=True,
            ))

        self.questions = questions

        if show_progress:
            print(f"\n✅ Successfully extracted {len(questions)} questions with answers")
            if not questions:
                print(f"❌ ERROR: Extracted 0 questions from {num_to_process} items")
                print(" This means the answer extraction is failing.")
                print(" Check the debug output above to see the data structure.")

    def _debug_first_item(self):
        """Debug first item to understand structure"""
        if len(self.dataset) == 0:
            print(" ❌ Dataset is empty!")
            return

        item = self.dataset[0]
        self._debug_question(item)
        self._debug_annotations(item)
        self._debug_document(item)
        self._debug_extraction(item)

    def _debug_question(self, item: dict):
        """Print the type and a preview of the item's ``question`` field."""
        print("\n 📝 Question:")
        if 'question' not in item:
            return
        q = item['question']
        print(f" Type: {type(q)}")
        if isinstance(q, dict):
            print(f" Keys: {list(q.keys())}")
            # str() so a non-string text value cannot break the preview slice
            print(f" Text: {str(q.get('text', 'N/A'))[:100]}")
        else:
            print(f" Value: {str(q)[:100]}")

    def _debug_annotations(self, item: dict):
        """Print the layout of the item's ``annotations`` field."""
        print("\n 📋 Annotations:")
        if 'annotations' not in item:
            print(" ❌ No 'annotations' field found")
            return

        anns = item['annotations']
        print(f" Type: {type(anns)}")
        if isinstance(anns, list):
            print(f" Count: {len(anns)}")
            if anns and isinstance(anns[0], dict):
                ann = anns[0]
                print(f" First annotation keys: {list(ann.keys())}")
                if 'short_answers' in ann:
                    short_ans = ann['short_answers']
                    print(f" Short answers count: {len(short_ans) if isinstance(short_ans, list) else 'Not a list'}")
                    if isinstance(short_ans, list) and len(short_ans) > 0:
                        print(f" First short answer: {short_ans[0]}")
        elif isinstance(anns, dict):
            print(f" It's a dict with keys: {list(anns.keys())}")
            # If it's a dict, check if it has the fields we need
            if 'short_answers' in anns:
                short_ans = anns['short_answers']
                print(f" Short answers type: {type(short_ans)}")
                print(f" Short answers: {short_ans}")
        else:
            print(f" Unexpected type: {type(anns)}")
            print(f" Value: {anns}")

    def _debug_document(self, item: dict):
        """Print the layout of the item's ``document`` tokens."""
        print("\n 📄 Document:")
        if 'document' not in item:
            return
        doc = item['document']
        print(f" Type: {type(doc)}")
        if not isinstance(doc, dict):
            return
        print(f" Keys: {list(doc.keys())}")
        if 'tokens' not in doc:
            return

        tokens = doc['tokens']
        print(f" Tokens type: {type(tokens)}")
        if isinstance(tokens, dict):
            print(f" Tokens keys: {list(tokens.keys())}")
            if 'token' in tokens:
                token_list = tokens['token']
                print(f" Token list length: {len(token_list) if isinstance(token_list, list) else 'Not a list'}")
                if isinstance(token_list, list) and len(token_list) > 0:
                    print(f" First 10 tokens: {token_list[:10]}")
        elif isinstance(tokens, list):
            print(f" Tokens list length: {len(tokens)}")
            print(f" First 10 tokens: {tokens[:10]}")

    def _debug_extraction(self, item: dict):
        """Run the extraction helpers on ``item`` and print what they return."""
        print("\n 🧪 Testing extraction methods:")
        question = self._extract_question(item)
        print(f" Question extracted: '{question[:50]}...' " if question else " ❌ Failed to extract question")

        answers = self._extract_answers(item)
        print(f" Answers extracted: {answers}" if answers else " ❌ Failed to extract answers")

        if answers or 'annotations' not in item:
            return

        print("\n 🔍 Deep dive into annotations:")
        anns = item['annotations']
        # Convert to list format for uniform processing
        ann_list = [anns] if isinstance(anns, dict) else (anns if isinstance(anns, list) else [])

        for i, ann in enumerate(ann_list[:2]):  # Check first 2 annotations
            print(f"\n Annotation {i}:")
            if not isinstance(ann, dict):
                continue
            print(f" Keys: {list(ann.keys())}")
            if 'short_answers' not in ann:
                continue
            sa = ann['short_answers']
            print(f" Short answers type: {type(sa)}")
            if not isinstance(sa, list):
                continue
            print(f" Short answers count: {len(sa)}")
            if not sa:
                continue
            first_sa = sa[0]
            print(f" First short answer: {first_sa}")
            if not isinstance(first_sa, dict):
                continue
            print(f" Start: {first_sa.get('start_token')}, End: {first_sa.get('end_token')}")
            # Offsets may be scalars or parallel lists; normalise before use.
            spans = self._token_spans(first_sa)
            if spans:
                start, end = spans[0]
                reconstructed = self._get_text_from_tokens(item, start, end)
                print(f" Reconstructed: '{reconstructed}'")

    def _extract_question(self, item: dict) -> str:
        """Extract question text

        Accepts either ``{'question': {'text': ...}}`` or a plain string
        under ``'question'``; anything else yields ``""``.
        """
        if 'question' not in item:
            return ""
        q = item['question']
        if isinstance(q, dict):
            return q.get('text', '')
        if isinstance(q, str):
            return q
        return ""

    @staticmethod
    def _token_spans(ans_span: dict) -> List[Tuple[Any, Any]]:
        """Return every ``(start_token, end_token)`` pair stored in a span.

        The offsets may be scalars or parallel lists (one entry per span);
        both layouts are normalised to a list of pairs.
        """
        start = ans_span.get('start_token')
        end = ans_span.get('end_token')
        if start is None or end is None:
            return []
        starts = start if isinstance(start, list) else [start]
        ends = end if isinstance(end, list) else [end]
        return list(zip(starts, ends))

    def _extract_answers(self, item: dict) -> List[str]:
        """Extract answers from annotations

        For each short-answer span the directly provided ``text`` is used
        when available; otherwise the answer is rebuilt from the document
        tokens for every ``(start_token, end_token)`` pair.

        Returns:
            Unique, stripped answers in first-seen order (deterministic
            across runs, unlike a ``set`` round-trip).
        """
        if 'annotations' not in item:
            return []

        annotations = item['annotations']
        # Handle both list and dict formats
        if isinstance(annotations, list):
            annotation_list = annotations
        elif isinstance(annotations, dict):
            # If it's a single dict, treat it as a list of one
            annotation_list = [annotations]
        else:
            return []

        answers = []
        for ann in annotation_list:
            if not isinstance(ann, dict):
                continue
            short_answers = ann.get('short_answers', [])
            if not isinstance(short_answers, list):
                continue

            for ans_span in short_answers:
                if not isinstance(ans_span, dict):
                    continue

                # Method 1: Try to get text directly (if available)
                raw_text = ans_span.get('text', [])
                if isinstance(raw_text, str):
                    raw_text = [raw_text]
                direct = []
                if isinstance(raw_text, list):
                    direct = [str(t).strip() for t in raw_text if t and str(t).strip()]
                if direct:
                    answers.extend(direct)
                    continue

                # Method 2: Reconstruct from tokens (all spans, not just the first)
                for start, end in self._token_spans(ans_span):
                    answer_text = self._get_text_from_tokens(item, start, end)
                    if answer_text:
                        answers.append(answer_text)

        # Remove duplicates while keeping first-seen order
        return list(dict.fromkeys(a.strip() for a in answers if a and a.strip()))

    @staticmethod
    def _get_token_list(item: dict) -> List[Any]:
        """Return the document's token list, or ``[]`` when unavailable.

        Supports ``document.tokens`` being either ``{'token': [...]}`` or a
        plain list.
        """
        if 'document' not in item:
            return []
        doc = item['document']
        if not isinstance(doc, dict) or 'tokens' not in doc:
            return []

        tokens = doc['tokens']
        if isinstance(tokens, dict) and 'token' in tokens:
            token_list = tokens['token']
        elif isinstance(tokens, list):
            token_list = tokens
        else:
            return []
        return token_list if isinstance(token_list, list) else []

    def _get_text_from_tokens(self, item: dict, start_token: int, end_token: int) -> str:
        """Get text from document tokens using start/end indices

        Natural Questions spans are half-open: ``start_token`` is inclusive
        and ``end_token`` is exclusive.

        Returns:
            The space-joined tokens of the span, or ``""`` when the document
            has no tokens or the offsets are not a valid, non-empty span.
        """
        token_list = self._get_token_list(item)
        if not token_list:
            return ""

        # Offsets that are not plain integers (None, lists, ...) cannot be used.
        for offset in (start_token, end_token):
            if not isinstance(offset, int) or isinstance(offset, bool):
                return ""

        # Check bounds (end is exclusive, so it may equal len(token_list))
        if start_token < 0 or end_token > len(token_list) or start_token >= end_token:
            return ""

        return " ".join(str(t) for t in token_list[start_token:end_token]).strip()

    def _extract_context(self, item: dict) -> str:
        """Extract context from document

        Returns the first ``CONTEXT_TOKENS`` document tokens joined by
        spaces, or ``""`` when the document has no tokens.
        """
        token_list = self._get_token_list(item)
        if not token_list:
            return ""
        return " ".join(str(t) for t in token_list[:self.CONTEXT_TOKENS])

    def get_questions(self) -> List[Question]:
        """Get loaded questions"""
        return self.questions

    def save_json(self, filepath: str):
        """Save questions to JSON

        Does nothing (besides a warning) when no questions are loaded.
        Parent directories of ``filepath`` are created as needed.
        """
        if not self.questions:
            print("⚠️ No questions to save")
            return

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        data = [
            {
                'id': q.id,
                'question': q.question,
                'answers': q.answers,
                'context': q.context,
            }
            for q in self.questions
        ]

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        print(f"✅ Saved {len(data)} questions to {filepath}")

    def get_stats(self) -> Dict:
        """Get dataset statistics

        Returns:
            Dict with ``total``, ``with_answers``, ``avg_question_len``
            (in whitespace-separated words) and ``avg_answers``; all zero
            when nothing is loaded.
        """
        if not self.questions:
            return {
                'total': 0,
                'with_answers': 0,
                'avg_question_len': 0,
                'avg_answers': 0
            }

        total = len(self.questions)
        return {
            'total': total,
            'with_answers': sum(1 for q in self.questions if q.answers),
            'avg_question_len': sum(len(q.question.split()) for q in self.questions) / total,
            'avg_answers': sum(len(q.answers) for q in self.questions) / total
        }


# =============================================================================
# WIKIPEDIA CORPUS (Full Dataset from HuggingFace)
# =============================================================================

@dataclass
class WikiPassage:
    """Wikipedia passage"""

    id: str
    title: str
    text: str
    url: str = ""


class WikiDPRCorpus:
    """Load full Wikipedia corpus from HuggingFace wikimedia/wikipedia"""

    def __init__(self, cache_dir: str = "./data/datasets"):
        """Create the loader and make sure ``cache_dir`` exists."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.corpus = None
        self.passages = []

    def load(
        self,
        language: str = "20231101.simple",  # or "20231101.en" for full English
        max_passages: int = 10000,
        show_progress: bool = True,
        chunk_size: int = 100,
        min_chunk_words: int = 20,
    ):
        """
        Load full Wikipedia corpus from HuggingFace

        Args:
            language: Wikipedia dump date.language (e.g., "20231101.simple", "20231101.en")
            max_passages: Maximum number of passages to extract
            show_progress: Show progress
            chunk_size: Number of words per passage
            min_chunk_words: Chunks with fewer words than this are dropped

        Raises:
            ImportError: If the ``datasets`` package is not installed.
            ValueError: If ``chunk_size`` is not positive.
        """
        if show_progress:
            print(f"\n📥 Loading Wikipedia corpus ({language})...")
            print("=" * 80)
            print("⚠️ Note: Full English Wikipedia is ~20GB, Simple is ~200MB")

        # Load full Wikipedia dataset
        self.corpus = _require_load_dataset()(
            "wikimedia/wikipedia",
            language,
            split="train",
            cache_dir=str(self.cache_dir),
        )

        if show_progress:
            print(f"✅ Loaded {len(self.corpus)} Wikipedia articles")

        # Extract passages
        self._extract_passages(max_passages, show_progress, chunk_size, min_chunk_words)

    def _extract_passages(
        self,
        max_passages: int,
        show_progress: bool,
        chunk_size: int = 100,
        min_chunk_words: int = 20,
    ):
        """
        Extract passages by chunking Wikipedia articles into ~100 word chunks
        (matching wiki_dpr format)

        Articles whose stripped text is shorter than 100 characters are
        skipped, as are chunks with fewer than ``min_chunk_words`` words.
        The result replaces ``self.passages``.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        if show_progress:
            print(f"\n🔄 Chunking articles into ~{chunk_size} word passages...")

        passages = []
        article_idx = 0
        iterator = tqdm(
            total=max_passages,
            desc="Extracting passages",
            disable=not show_progress
        )

        try:
            while len(passages) < max_passages and article_idx < len(self.corpus):
                item = self.corpus[article_idx]
                article_idx += 1

                # Get article data
                article_id = item.get('id', f'wiki_{article_idx}')
                title = item.get('title', '')
                text = item.get('text', '')
                url = item.get('url', '')

                # Skip empty or very short articles
                if not text or len(text.strip()) < 100:
                    continue

                # Split into fixed-size word chunks (wiki_dpr format)
                words = text.split()
                for chunk_start in range(0, len(words), chunk_size):
                    if len(passages) >= max_passages:
                        break

                    chunk_words = words[chunk_start:chunk_start + chunk_size]
                    # Skip very short chunks (typically an article's tail)
                    if len(chunk_words) < min_chunk_words:
                        continue

                    passages.append(WikiPassage(
                        id=f'{article_id}_chunk_{chunk_start // chunk_size}',
                        title=title,
                        text=' '.join(chunk_words),
                        url=url
                    ))
                    iterator.update(1)
        finally:
            # Always release the progress bar, even if a row is malformed.
            iterator.close()

        self.passages = passages

        if show_progress:
            print(f"✅ Extracted {len(passages)} passages from {article_idx} articles")
            print(f" Each passage is ~{chunk_size} words (wiki_dpr format)")

    def get_passages(self) -> List[WikiPassage]:
        """Get loaded passages"""
        return self.passages

    def save_json(self, filepath: str):
        """Save passages to JSON

        Does nothing (besides a warning) when no passages are loaded.
        Parent directories of ``filepath`` are created as needed.
        """
        if not self.passages:
            print("⚠️ No passages to save")
            return

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        data = [
            {
                'id': p.id,
                'title': p.title,
                'text': p.text,
                'url': p.url,
            }
            for p in self.passages
        ]

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        print(f"✅ Saved {len(data)} passages to {filepath}")

    def get_stats(self) -> Dict:
        """Get corpus statistics

        Returns:
            ``{}`` when nothing is loaded, otherwise ``total_passages``,
            ``avg_passage_length`` (in words) and ``unique_titles``.
        """
        if not self.passages:
            return {}

        return {
            "total_passages": len(self.passages),
            "avg_passage_length": sum(len(p.text.split()) for p in self.passages) / len(self.passages),
            "unique_titles": len(set(p.title for p in self.passages)),
        }


def main():
    """Smoke-test both loaders (downloads data from HuggingFace)."""
    print("📚 Dataset Loader Test")
    print("=" * 80)

    # Test 1: Natural Questions
    print("\n[Test 1] Natural Questions")
    nq = NaturalQuestionsDataset()
    nq.load(max_samples=50, show_progress=True)

    questions = nq.get_questions()
    if len(questions) > 0:
        print("\n📋 Sample Questions (first 3):")
        for i, q in enumerate(questions[:3], 1):
            print(f"\n{i}. Q: {q.question}")
            print(f" A: {q.answers}")

        stats = nq.get_stats()
        print(f"\n📊 Stats: {stats['total']} questions")

        nq.save_json("./data/datasets/nq_sample.json")

    # Test 2: Wikipedia Corpus
    print("\n" + "=" * 80)
    print("\n[Test 2] Wikipedia Corpus")
    wiki = WikiDPRCorpus()
    wiki.load(
        language="20231101.en",  # Full English Wikipedia, change to "20231101.simple" for Simple English
        max_passages=10000,
        show_progress=True
    )

    passages = wiki.get_passages()
    print("\n📋 Sample Passages (first 3):")
    for i, p in enumerate(passages[:3], 1):
        print(f"\n{i}. Title: {p.title}")
        print(f" Text: {p.text[:100]}...")
        print(f" URL: {p.url}")

    stats = wiki.get_stats()
    print("\n📊 Stats:")
    for key, value in stats.items():
        print(f" {key}: {value}")

    wiki.save_json("./data/datasets/wiki_passages_1000.json")

    print("\n" + "=" * 80)
    print("✅ Test complete!")


if __name__ == "__main__":
    main()