File size: 6,883 Bytes
da27cbf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | """
Download and prepare QA training data for H4 RAG.
Uses a simple extractive QA format:
- Input: [context] | [question] |
- Target: [answer]
Data sources (in order of preference):
1. SQuAD-style QA pairs generated from the sample documents
2. Downloaded SQuAD 2.0 dev set (small, freely available)
For CPU training with 2-minute budget, we need small data that
trains fast. The sample doc QA pairs are ideal for proving the
pipeline works; SQuAD provides real benchmark numbers.
"""
import json
import os
import random
import sys
from typing import Dict, List, Optional, Tuple
# Put the parent of this file's directory at the front of the import search
# path (presumably so sibling project packages resolve when this file is run
# directly as a script -- confirm against the repo layout).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
def generate_sample_qa() -> List[Dict]:
    """
    Build the hand-written QA set that accompanies the sample documents.

    Each entry is a dict with the keys "context", "question" and "answer"
    (in that order). Every answer is a span to be extracted from its
    context. A fresh list of fresh dicts is returned on every call.
    """
    # Rows are (context, question, answer), grouped by source document.
    golden_ratio_rows = [
        ("The golden ratio, often denoted by the Greek letter phi, is a special number approximately equal to 1.618.",
         "What is the golden ratio approximately equal to?",
         "1.618"),
        ("Two quantities are in the golden ratio if their ratio is the same as the ratio of their sum to the larger of the two quantities.",
         "When are two quantities in the golden ratio?",
         "if their ratio is the same as the ratio of their sum to the larger"),
        ("The golden ratio is closely related to the Fibonacci sequence. As Fibonacci numbers increase, the ratio of consecutive Fibonacci numbers approaches the golden ratio.",
         "How is the golden ratio related to Fibonacci numbers?",
         "the ratio of consecutive Fibonacci numbers approaches the golden ratio"),
        ("The golden ratio appears in the geometry of pentagons and in the arrangement of leaves and petals in many plants.",
         "Where does the golden ratio appear in nature?",
         "in the arrangement of leaves and petals in many plants"),
    ]
    polytope_rows = [
        ("The 600-cell is a regular 4-polytope with 120 vertices, 720 edges, 1200 triangular faces, and 600 tetrahedral cells.",
         "How many vertices does the 600-cell have?",
         "120"),
        ("The 600-cell has the H4 symmetry group, which contains 14400 elements. This is the largest finite reflection group in four dimensions.",
         "How many elements does the H4 symmetry group contain?",
         "14400"),
        ("The 600-cell is dual to the 120-cell, which has 600 vertices.",
         "What is the 600-cell dual to?",
         "the 120-cell"),
        ("A polytope is a geometric object with flat sides in any number of dimensions.",
         "What is a polytope?",
         "a geometric object with flat sides in any number of dimensions"),
    ]
    e8_lattice_rows = [
        ("The E8 lattice is the densest sphere packing in eight dimensions. This was proven by Maryna Viazovska in 2016.",
         "Who proved E8 is the densest sphere packing?",
         "Maryna Viazovska"),
        ("The E8 lattice has a kissing number of 240, meaning each sphere touches exactly 240 others.",
         "What is the kissing number of E8?",
         "240"),
        ("The Coxeter element of E8 has eigenvalues that include cosine of pi over five, which equals phi over two.",
         "What eigenvalue connects E8 to the golden ratio?",
         "cosine of pi over five, which equals phi over two"),
        ("When the 240 roots of E8 are projected along these eigenspaces, they map to the vertices of H4 polytopes.",
         "What happens when E8 roots are projected along the eigenspaces?",
         "they map to the vertices of H4 polytopes"),
    ]

    # Document order: golden_ratio.txt, polytopes.txt, e8_lattice.txt.
    all_rows = golden_ratio_rows + polytope_rows + e8_lattice_rows
    return [
        {"context": context, "question": question, "answer": answer}
        for context, question, answer in all_rows
    ]
def prepare_training_data(
    qa_pairs: List[Dict],
    val_fraction: float = 0.2,
    seed: int = 42,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Shuffle QA pairs deterministically and split them into train/val lists.

    Args:
        qa_pairs: QA dicts to split. The input list is not modified.
        val_fraction: Fraction of the pairs to place in the validation set.
        seed: Seed for the shuffle. The default of 42 gives the same split
            as earlier versions that seeded the global RNG with 42.

    Returns:
        A ``(train, val)`` tuple of lists.

    The validation set always receives at least one pair when any exist,
    even if ``val_fraction`` rounds down to zero. A single-pair input
    therefore yields an empty training list, and an empty input yields two
    empty lists.
    """
    # Use a private generator rather than random.seed(): reseeding the
    # module-level RNG here would silently reset any seeding the caller did
    # for training. Random(seed).shuffle yields the same permutation as
    # random.seed(seed) followed by random.shuffle.
    rng = random.Random(seed)
    pairs = list(qa_pairs)
    rng.shuffle(pairs)
    n_val = max(1, int(len(pairs) * val_fraction))
    return pairs[n_val:], pairs[:n_val]
def format_qa_for_training(qa_pair: Dict, sep: str = " | ") -> Tuple[str, str]:
    """
    Turn one QA dict into an (input, target) string pair.

    The input is the context, the separator, the question, and a trailing
    separator; the target is the answer unchanged:

        input  = <context><sep><question><sep>
        target = <answer>
    """
    context = qa_pair['context']
    question = qa_pair['question']
    # The trailing separator marks where the answer is expected to begin.
    prompt = context + sep + question + sep
    return prompt, qa_pair['answer']
def download_squad_dev(
    cache_dir: Optional[str] = None,
    max_context_chars: Optional[int] = 500,
) -> List[Dict]:
    """
    Load the answerable questions of the SQuAD 2.0 dev set, downloading it
    on first use.

    Args:
        cache_dir: Directory holding the cached ``squad_dev.json``. Defaults
            to ``../../data`` relative to this file. Created if missing.
        max_context_chars: Contexts are cut to this many leading characters.
            Pass ``None`` to keep full contexts.

    Returns:
        A list of dicts with the keys "context", "question" and "answer".
        Questions flagged ``is_impossible`` or lacking answers are skipped;
        only the first listed answer is used. Returns an empty list when
        the download fails or the cached file cannot be parsed.

    Progress and failure messages are printed to stdout.
    """
    import urllib.request

    if cache_dir is None:
        cache_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'data')
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, 'squad_dev.json')

    if not os.path.exists(cache_path):
        url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"
        # Download next to the final path and rename on success, so that an
        # interrupted transfer never leaves a truncated file at cache_path
        # (which would be mistaken for a valid cache on the next call).
        partial_path = cache_path + '.part'
        print("Downloading SQuAD 2.0 dev set...")
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, cache_path)
            print(f"Saved to {cache_path}")
        except Exception as e:
            # Deliberately broad: this is a best-effort fetch and callers
            # fall back to the sample QA pairs when it returns [].
            print(f"Download failed: {e}")
            try:
                os.remove(partial_path)
            except OSError:
                pass  # nothing was written, or cleanup is not possible
            return []

    try:
        with open(cache_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError. Drop the unusable cache so the next call re-downloads.
        print(f"Cached SQuAD file is corrupt ({e}); removing {cache_path}")
        os.remove(cache_path)
        return []

    # NOTE(review): truncating the context can cut off the answer span, so
    # some returned pairs may have an answer that is absent from "context".
    qa_pairs = []
    for article in data['data']:
        for paragraph in article['paragraphs']:
            context = paragraph['context'][:max_context_chars]
            for qa in paragraph['qas']:
                if qa.get('is_impossible', False) or not qa['answers']:
                    continue
                qa_pairs.append({
                    'context': context,
                    'question': qa['question'],
                    'answer': qa['answers'][0]['text'],
                })
    return qa_pairs
if __name__ == '__main__':
    # Smoke run: build the sample set, show a few formatted examples, then
    # try to fetch SQuAD and report how many answerable questions it has.
    print("Generating sample QA pairs...")
    sample_pairs = generate_sample_qa()
    train_split, val_split = prepare_training_data(sample_pairs)
    print(f"Sample QA: {len(train_split)} train, {len(val_split)} val")

    # Preview the first three pairs in their original (unshuffled) order.
    for example in sample_pairs[:3]:
        model_input, model_target = format_qa_for_training(example)
        print(f"\nInput: {model_input[:80]}...")
        print(f"Target: {model_target}")

    print("\nAttempting SQuAD download...")
    squad_pairs = download_squad_dev()
    if not squad_pairs:
        print("SQuAD not available (offline?). Using sample QA only.")
    else:
        print(f"SQuAD 2.0 dev: {len(squad_pairs)} answerable questions")
|