# NOTE(review): removed Hugging Face page chrome ("Spaces / Sleeping") that was
# captured by the web extraction — it is not part of the module source.
| """ | |
| utils/dataset_loader.py - Load Natural Questions + Wikipedia Dataset | |
| ====================================================================== | |
| Clean implementation for loading: | |
| - Natural Questions: Q&A dataset with answers extracted from Wikipedia | |
| - Wikipedia: Standard Wikipedia dataset chunked into ~100 word passages | |
| """ | |
| from typing import List, Dict, Optional | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import json | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
@dataclass
class Question:
    """A question with its reference answers and supporting context.

    The ``@dataclass`` decorator is required: instances are built with
    keyword arguments (e.g. ``Question(id=..., question=..., answers=...)``),
    which fails with a plain class body of bare annotations.
    """
    id: str                  # unique question identifier
    question: str            # question text
    answers: List[str]       # acceptable answer strings
    context: str = ""        # supporting passage text (empty when unavailable)
    has_answer: bool = True  # whether at least one answer was extracted
| # ============================================================================= | |
| # NATURAL QUESTIONS DATASET | |
| # ============================================================================= | |
class NaturalQuestionsDataset:
    """Load the Natural Questions validation split and extract Q&A pairs.

    Workflow: ``load()`` fetches the HuggingFace dataset, then extraction
    pulls the question text, short answers (taken directly from a 'text'
    field or reconstructed from document tokens), and a short context
    passage. Items without at least one extractable answer are skipped.
    """

    def __init__(self, cache_dir: str = "./data/datasets"):
        """Create the loader; ``cache_dir`` is created if missing."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.dataset = None  # raw HuggingFace dataset, populated by load()
        self.questions: List["Question"] = []  # extracted Question objects

    def load(self, max_samples: int = 100, show_progress: bool = True):
        """Load the Natural Questions validation set and extract questions.

        Args:
            max_samples: Maximum number of dataset items to process.
            show_progress: Print status messages and show a progress bar.
        """
        if show_progress:
            print("\nπ₯ Loading Natural Questions...")
            print("=" * 80)
        # Download (or reuse the cached copy of) the validation split.
        self.dataset = load_dataset(
            "google-research-datasets/natural_questions",
            split="validation",
            cache_dir=str(self.cache_dir),
        )
        if show_progress:
            print(f"β Loaded {len(self.dataset)} questions from dataset")
        self._extract_questions(max_samples, show_progress)

    def _extract_questions(self, max_samples: int, show_progress: bool):
        """Populate ``self.questions`` from the first min(max_samples, len) items."""
        if show_progress:
            print("\nπ Extracting questions...")
            # Dump the first item's structure so extraction failures are debuggable.
            print("\nπ Inspecting first item structure:")
            self._debug_first_item()
        num_to_process = min(max_samples, len(self.dataset))
        questions = []
        for idx in tqdm(range(num_to_process), desc="Processing", disable=not show_progress):
            item = self.dataset[idx]
            question_text = self._extract_question(item)
            if not question_text:
                continue  # unusable without a question string
            answers = self._extract_answers(item)
            if not answers:
                continue  # skip questions without extractable answers
            questions.append(Question(
                id=str(item.get('id', f'nq_{idx}')),
                question=question_text,
                answers=answers,
                context=self._extract_context(item),
                has_answer=True,
            ))
        self.questions = questions
        if show_progress:
            print(f"\nβ Successfully extracted {len(questions)} questions with answers")
            if len(questions) == 0:
                print(f"β ERROR: Extracted 0 questions from {num_to_process} items")
                print(" This means the answer extraction is failing.")
                print(" Check the debug output above to see the data structure.")

    def _debug_first_item(self):
        """Print the structure of the first dataset item to aid debugging."""
        if len(self.dataset) == 0:
            print(" β Dataset is empty!")
            return
        item = self.dataset[0]
        # -- question field --
        print("\n π Question:")
        if 'question' in item:
            q = item['question']
            print(f" Type: {type(q)}")
            if isinstance(q, dict):
                print(f" Keys: {list(q.keys())}")
                print(f" Text: {q.get('text', 'N/A')[:100]}")
            else:
                print(f" Value: {str(q)[:100]}")
        # -- annotations field --
        print("\n π Annotations:")
        if 'annotations' in item:
            anns = item['annotations']
            print(f" Type: {type(anns)}")
            if isinstance(anns, list):
                print(f" Count: {len(anns)}")
                if len(anns) > 0:
                    ann = anns[0]
                    print(f" First annotation keys: {list(ann.keys())}")
                    if 'short_answers' in ann:
                        short_ans = ann['short_answers']
                        print(f" Short answers count: {len(short_ans) if isinstance(short_ans, list) else 'Not a list'}")
                        if isinstance(short_ans, list) and len(short_ans) > 0:
                            print(f" First short answer: {short_ans[0]}")
            elif isinstance(anns, dict):
                print(f" It's a dict with keys: {list(anns.keys())}")
                if 'short_answers' in anns:
                    short_ans = anns['short_answers']
                    print(f" Short answers type: {type(short_ans)}")
                    print(f" Short answers: {short_ans}")
            else:
                print(f" Unexpected type: {type(anns)}")
                print(f" Value: {anns}")
        else:
            print(" β No 'annotations' field found")
        # -- document field --
        print("\n π Document:")
        if 'document' in item:
            doc = item['document']
            print(f" Type: {type(doc)}")
            if isinstance(doc, dict):
                print(f" Keys: {list(doc.keys())}")
                if 'tokens' in doc:
                    tokens = doc['tokens']
                    print(f" Tokens type: {type(tokens)}")
                    if isinstance(tokens, dict):
                        print(f" Tokens keys: {list(tokens.keys())}")
                        if 'token' in tokens:
                            token_list = tokens['token']
                            print(f" Token list length: {len(token_list) if isinstance(token_list, list) else 'Not a list'}")
                            if isinstance(token_list, list) and len(token_list) > 0:
                                print(f" First 10 tokens: {token_list[:10]}")
                    elif isinstance(tokens, list):
                        print(f" Tokens list length: {len(tokens)}")
                        print(f" First 10 tokens: {tokens[:10]}")
        # -- dry-run the extraction helpers on the first item --
        print("\n π§ͺ Testing extraction methods:")
        question = self._extract_question(item)
        print(f" Question extracted: '{question[:50]}...' " if question else " β Failed to extract question")
        answers = self._extract_answers(item)
        print(f" Answers extracted: {answers}" if answers else " β Failed to extract answers")
        if not answers and 'annotations' in item:
            print("\n π Deep dive into annotations:")
            anns = item['annotations']
            # Normalize to a list so dict- and list-shaped annotations print alike.
            ann_list = [anns] if isinstance(anns, dict) else (anns if isinstance(anns, list) else [])
            for i, ann in enumerate(ann_list[:2]):  # inspect at most 2 annotations
                print(f"\n Annotation {i}:")
                if isinstance(ann, dict):
                    print(f" Keys: {list(ann.keys())}")
                    if 'short_answers' in ann:
                        sa = ann['short_answers']
                        print(f" Short answers type: {type(sa)}")
                        if isinstance(sa, list):
                            print(f" Short answers count: {len(sa)}")
                            if len(sa) > 0:
                                first_sa = sa[0]
                                print(f" First short answer: {first_sa}")
                                if isinstance(first_sa, dict):
                                    start = first_sa.get('start_token')
                                    end = first_sa.get('end_token')
                                    print(f" Start: {start}, End: {end}")
                                    if start is not None and end is not None:
                                        reconstructed = self._get_text_from_tokens(item, start, end)
                                        print(f" Reconstructed: '{reconstructed}'")

    def _extract_question(self, item: dict) -> str:
        """Return the question text, or '' when absent/unrecognized."""
        if 'question' not in item:
            return ""
        q = item['question']
        if isinstance(q, dict):
            return q.get('text', '')
        elif isinstance(q, str):
            return q
        return ""

    def _extract_answers(self, item: dict) -> List[str]:
        """Return all short-answer strings for ``item`` (deduplicated, ordered).

        Answers come from a span's 'text' field when present, otherwise they
        are reconstructed from document tokens via start/end indices.
        """
        answers: List[str] = []
        if 'annotations' not in item:
            return answers
        annotations = item['annotations']
        # Normalize to a list: the HF loader may yield a single dict.
        if isinstance(annotations, list):
            annotation_list = annotations
        elif isinstance(annotations, dict):
            annotation_list = [annotations]
        else:
            return answers
        for ann in annotation_list:
            if not isinstance(ann, dict):
                continue
            short_answers = ann.get('short_answers', [])
            if not isinstance(short_answers, list) or len(short_answers) == 0:
                continue
            for ans_span in short_answers:
                if not isinstance(ans_span, dict):
                    continue
                # Method 1: answer text provided directly.
                text_field = ans_span.get('text', [])
                if isinstance(text_field, list) and len(text_field) > 0:
                    for text in text_field:
                        if text and str(text).strip():
                            answers.append(str(text).strip())
                    continue
                # Method 2: reconstruct from token indices.
                start = ans_span.get('start_token')
                end = ans_span.get('end_token')
                # Indices may themselves arrive wrapped in lists.
                if isinstance(start, list):
                    start = start[0] if len(start) > 0 else None
                if isinstance(end, list):
                    end = end[0] if len(end) > 0 else None
                if start is None or end is None:
                    continue
                answer_text = self._get_text_from_tokens(item, start, end)
                if answer_text:
                    answers.append(answer_text)
        # Deduplicate preserving first-seen order (the original set() round-trip
        # made the ordering nondeterministic between runs).
        return list(dict.fromkeys(a.strip() for a in answers if a and a.strip()))

    @staticmethod
    def _token_list(item: dict) -> Optional[list]:
        """Return the document token list for ``item``, or None when unavailable.

        Handles both the columnar dict form ({'token': [...]}) and a plain
        list of tokens; shared by answer and context extraction.
        """
        doc = item.get('document')
        if not isinstance(doc, dict) or 'tokens' not in doc:
            return None
        tokens = doc['tokens']
        if isinstance(tokens, dict) and 'token' in tokens:
            tokens = tokens['token']
        if isinstance(tokens, list) and tokens:
            return tokens
        return None

    def _get_text_from_tokens(self, item: dict, start_token: int, end_token: int) -> str:
        """Reconstruct text from document tokens in [start_token, end_token].

        NOTE(review): the range is treated as inclusive of ``end_token``; the
        official NQ format documents an exclusive end index — confirm against
        the actual loaded data before changing this.
        """
        token_list = self._token_list(item)
        if token_list is None:
            return ""
        if start_token < 0 or end_token >= len(token_list) or start_token > end_token:
            return ""
        return " ".join(str(t) for t in token_list[start_token:end_token + 1]).strip()

    def _extract_context(self, item: dict) -> str:
        """Return the first 100 document tokens joined as a context string."""
        token_list = self._token_list(item)
        if token_list is None:
            return ""
        return " ".join(str(t) for t in token_list[:100])

    def get_questions(self) -> List["Question"]:
        """Return the questions extracted by the last ``load()`` call."""
        return self.questions

    def save_json(self, filepath: str):
        """Serialize the extracted questions to ``filepath`` as a JSON array.

        Parent directories are created as needed; a no-op when nothing
        has been extracted yet.
        """
        if not self.questions:
            print("β οΈ No questions to save")
            return
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        data = [
            {
                'id': q.id,
                'question': q.question,
                'answers': q.answers,
                'context': q.context,
            }
            for q in self.questions
        ]
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        print(f"β Saved {len(data)} questions to {filepath}")

    def get_stats(self) -> Dict:
        """Return summary statistics over the extracted questions.

        Keys: total, with_answers, avg_question_len (words), avg_answers.
        All zeros when no questions are loaded.
        """
        if not self.questions:
            return {
                'total': 0,
                'with_answers': 0,
                'avg_question_len': 0,
                'avg_answers': 0
            }
        return {
            'total': len(self.questions),
            'with_answers': sum(1 for q in self.questions if q.answers),
            'avg_question_len': sum(len(q.question.split()) for q in self.questions) / len(self.questions),
            'avg_answers': sum(len(q.answers) for q in self.questions) / len(self.questions)
        }
| # ============================================================================= | |
| # WIKIPEDIA CORPUS (Full Dataset from HuggingFace) | |
| # ============================================================================= | |
@dataclass
class WikiPassage:
    """A single ~100-word Wikipedia passage.

    The ``@dataclass`` decorator is required: instances are built with
    keyword arguments (``WikiPassage(id=..., title=..., text=...)``),
    which fails with a plain class body of bare annotations.
    """
    id: str        # "<article_id>_chunk_<n>" identifier
    title: str     # source article title
    text: str      # passage text (~100 words)
    url: str = ""  # source article URL (empty when unavailable)
class WikiDPRCorpus:
    """Load a Wikipedia dump from HuggingFace (wikimedia/wikipedia) and chunk
    articles into ~100-word passages (matching the wiki_dpr passage format)."""

    def __init__(self, cache_dir: str = "./data/datasets"):
        """Create the loader; ``cache_dir`` is created if missing."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.corpus = None  # raw HF dataset of full articles, set by load()
        self.passages: List["WikiPassage"] = []  # chunked passages

    def load(
        self,
        language: str = "20231101.simple",  # or "20231101.en" for full English
        max_passages: int = 10000,
        show_progress: bool = True,
        chunk_size: int = 100,
    ):
        """
        Load a Wikipedia corpus from HuggingFace and chunk it into passages.

        Args:
            language: Wikipedia dump "date.language" config
                (e.g. "20231101.simple", "20231101.en").
            max_passages: Maximum number of passages to extract.
            show_progress: Print status and show a progress bar.
            chunk_size: Target words per passage (default 100, the wiki_dpr
                size; parameterized for other retrieval granularities).
        """
        if show_progress:
            print(f"\nπ₯ Loading Wikipedia corpus ({language})...")
            print("=" * 80)
            print("β οΈ Note: Full English Wikipedia is ~20GB, Simple is ~200MB")
        # Download (or reuse the cached copy of) the article dump.
        self.corpus = load_dataset(
            "wikimedia/wikipedia",
            language,
            split="train",
            cache_dir=str(self.cache_dir),
        )
        if show_progress:
            print(f"β Loaded {len(self.corpus)} Wikipedia articles")
        self._extract_passages(max_passages, show_progress, chunk_size)

    def _extract_passages(self, max_passages: int, show_progress: bool, chunk_size: int = 100):
        """Chunk articles into ~chunk_size-word passages until max_passages
        passages are collected or the corpus is exhausted."""
        if show_progress:
            print("\nπ Chunking articles into ~100 word passages...")
        passages: List["WikiPassage"] = []
        article_idx = 0
        progress = tqdm(
            total=max_passages,
            desc="Extracting passages",
            disable=not show_progress,
        )
        try:
            while len(passages) < max_passages and article_idx < len(self.corpus):
                item = self.corpus[article_idx]
                article_idx += 1
                article_id = item.get('id', f'wiki_{article_idx}')
                title = item.get('title', '')
                text = item.get('text', '')
                url = item.get('url', '')
                # Skip empty or very short articles (< 100 characters).
                if not text or len(text.strip()) < 100:
                    continue
                words = text.split()
                for word_offset in range(0, len(words), chunk_size):
                    if len(passages) >= max_passages:
                        break
                    chunk_words = words[word_offset:word_offset + chunk_size]
                    # Drop tiny (trailing) chunks.
                    if len(chunk_words) < 20:
                        continue
                    passages.append(WikiPassage(
                        id=f'{article_id}_chunk_{word_offset // chunk_size}',
                        title=title,
                        text=' '.join(chunk_words),
                        url=url,
                    ))
                    progress.update(1)
        finally:
            # Always close the bar, even if the HF dataset raises mid-iteration.
            progress.close()
        self.passages = passages
        if show_progress:
            print(f"β Extracted {len(passages)} passages from {article_idx} articles")
            print(" Each passage is ~100 words (wiki_dpr format)")

    def get_passages(self) -> List["WikiPassage"]:
        """Return the passages extracted by the last ``load()`` call."""
        return self.passages

    def save_json(self, filepath: str):
        """Serialize the extracted passages to ``filepath`` as a JSON array.

        Parent directories are created as needed; a no-op when nothing
        has been extracted yet.
        """
        if not self.passages:
            print("β οΈ No passages to save")
            return
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        data = [
            {
                'id': p.id,
                'title': p.title,
                'text': p.text,
                'url': p.url,
            }
            for p in self.passages
        ]
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        print(f"β Saved {len(data)} passages to {filepath}")

    def get_stats(self) -> Dict:
        """Return summary statistics over the extracted passages.

        Keys: total_passages, avg_passage_length (words), unique_titles.
        Empty dict when no passages are loaded.
        """
        if not self.passages:
            return {}
        return {
            "total_passages": len(self.passages),
            "avg_passage_length": sum(len(p.text.split()) for p in self.passages) / len(self.passages),
            "unique_titles": len(set(p.title for p in self.passages)),
        }
if __name__ == "__main__":
    # Smoke test: exercise both loaders end-to-end.
    # WARNING: downloads datasets from HuggingFace (network + disk heavy).
    print("π Dataset Loader Test")
    print("=" * 80)

    # --- Test 1: Natural Questions ---
    print("\n[Test 1] Natural Questions")
    nq_loader = NaturalQuestionsDataset()
    nq_loader.load(max_samples=50, show_progress=True)
    extracted = nq_loader.get_questions()
    if len(extracted) > 0:
        print(f"\nπ Sample Questions (first 3):")
        for rank, sample_q in enumerate(extracted[:3], 1):
            print(f"\n{rank}. Q: {sample_q.question}")
            print(f" A: {sample_q.answers}")
        nq_stats = nq_loader.get_stats()
        print(f"\nπ Stats: {nq_stats['total']} questions")
        nq_loader.save_json("./data/datasets/nq_sample.json")

    # --- Test 2: Wikipedia corpus ---
    print("\n" + "=" * 80)
    print("\n[Test 2] Wikipedia Corpus")
    wiki_corpus = WikiDPRCorpus()
    wiki_corpus.load(
        language="20231101.en",  # Full English Wikipedia, change to "20231101.simple" for Simple English
        max_passages=10000,
        show_progress=True,
    )
    extracted_passages = wiki_corpus.get_passages()
    print(f"\nπ Sample Passages (first 3):")
    for rank, sample_p in enumerate(extracted_passages[:3], 1):
        print(f"\n{rank}. Title: {sample_p.title}")
        print(f" Text: {sample_p.text[:100]}...")
        print(f" URL: {sample_p.url}")
    wiki_stats = wiki_corpus.get_stats()
    print(f"\nπ Stats:")
    for stat_name, stat_value in wiki_stats.items():
        print(f" {stat_name}: {stat_value}")
    # NOTE(review): filename says 1000 but max_passages above is 10000 — confirm intended name.
    wiki_corpus.save_json("./data/datasets/wiki_passages_1000.json")

    print("\n" + "=" * 80)
    print("β Test complete!")