# reading-steiner / prepare_data.py (by OmAlve)
# Copied from IndexLM-0.6B at commit bfd6805.
"""
Prepare IndexLM training data from HotpotQA and MSMARCO.
Pipeline:
1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts)
2. Convert context into indexed HTML-like blocks: [i] <tag>content</tag>
3. The target is index intervals of blocks containing supporting facts
4. Also create main-content extraction examples (all content blocks are "main content",
but we inject noise blocks like nav/ads to train the model to filter them)
5. Format as conversational messages for SFT
"""
import json
import random
import re
from datasets import load_dataset, Dataset
from collections import defaultdict
# Seed the module-level RNG once at import time. Every random.* call in this
# script (noise injection, the probabilistic ME sampling, random.shuffle) draws
# from it, so reruns over the same dataset rows are repeatable.
random.seed(42)
# Noise blocks to inject (simulating real web page clutter).
# Each entry is a complete HTML-like snippet that is prefixed with "[i] " when
# inserted; noise blocks are never recorded as relevant or main-content indices.
NOISE_BLOCKS = [
    '<nav>Home | About | Contact | Privacy Policy</nav>',
    '<div class="ad">Advertisement - Continue Reading Below</div>',
    '<div class="sidebar">Related Articles: Top 10 Facts You Didn\'t Know</div>',
    '<footer>© 2024 All Rights Reserved | Terms of Service</footer>',
    '<div class="cookie-banner">This site uses cookies. Accept | Decline</div>',
    '<div class="social">Share on: Twitter | Facebook | LinkedIn</div>',
    '<nav class="breadcrumb">Home > Category > Subcategory > Article</nav>',
    '<div class="newsletter">Subscribe to our newsletter for updates</div>',
    '<div class="popup">Sign up for free access to premium content</div>',
    '<aside>Trending: Latest news and popular stories</aside>',
    '<div class="comments">Comments (0) - Be the first to comment</div>',
    '<div class="author">Written by Staff Reporter | Updated: Jan 2024</div>',
    '<div class="pagination">Previous | 1 | 2 | 3 | Next</div>',
    '<div class="search">Search this site...</div>',
    '<div class="menu">Categories: Science, Tech, Health, Sports</div>',
]
# System prompt for query-relevant extraction (QE) examples.
# NOTE(review): the example output below is written without spaces, but
# indices_to_intervals() emits json.dumps' default spacing ("[[2, 4], [7, 7]]"),
# so training targets differ cosmetically from the prompt's example -- confirm
# this is acceptable for whatever parses the model output.
SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query.
Each block is formatted as: [i] <tag>content</tag>
Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive).
If no relevant content exists, output 'NA'.
Example output: [[2,4],[7,7],[10,12]]"""
# System prompt for main-content extraction (ME) examples. The same spacing
# mismatch between the example output and the json.dumps targets applies here.
SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements).
Each block is formatted as: [i] <tag>content</tag>
Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive).
If no main content exists, output 'NA'.
Example output: [[1,3],[5,8],[11,15]]"""
def indices_to_intervals(indices):
    """Collapse block indices into inclusive ``[start, end]`` runs.

    Duplicates are dropped and the input does not need to be sorted.

    Args:
        indices: collection of integer block indices.

    Returns:
        ``"NA"`` when *indices* is empty; otherwise a JSON string such as
        ``"[[1, 3], [7, 7]]"`` listing the maximal runs of consecutive indices.
    """
    if not indices:
        return "NA"
    ordered = sorted(set(indices))
    runs = [[ordered[0], ordered[0]]]
    for value in ordered[1:]:
        current = runs[-1]
        if value == current[1] + 1:
            # Extends the current run.
            current[1] = value
        else:
            # Gap found: start a new run.
            runs.append([value, value])
    return json.dumps(runs)
def _sample_noise_run(probability=0.5, max_blocks=3):
    """With the given probability return 1..max_blocks random noise snippets, else []."""
    if random.random() >= probability:
        return []
    return [random.choice(NOISE_BLOCKS) for _ in range(random.randint(1, max_blocks))]


def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True):
    """
    Convert HotpotQA context into indexed HTML blocks.

    Each document becomes an ``<h2>`` title block followed by one ``<p>`` block
    per non-empty sentence. With *inject_noise*, a noise block may follow any
    document except the last (40% chance each). Independent 50% chances also add
    1-3 noise blocks before the first block and after the last one.

    Args:
        context: ``{'title': [...], 'sentences': [[...], ...]}`` (parallel lists).
        supporting_facts: ``{'title': [...], 'sent_id': [...]}`` (parallel lists).
            ``sent_id`` indexes the original sentence list, so it stays valid even
            though blank sentences are skipped.
        inject_noise: whether to add noise blocks (draws from the global ``random`` RNG).

    Returns:
        ``(block_text, relevant_indices, content_indices)``:
        - *block_text* is the newline-joined ``[i] <tag>...</tag>`` lines, numbered from 1.
        - *relevant_indices* are the title and sentence blocks of supporting facts.
        - *content_indices* are every non-noise block.
        Both index lists are ascending.
    """
    # Build supporting facts lookup: title -> set of supporting sentence ids
    sf_lookup = defaultdict(set)
    for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']):
        sf_lookup[title].add(sent_id)

    # Collect unnumbered (html, is_content, is_relevant) entries and number them once
    # at the end, instead of re-parsing "[i] " prefixes to shift indices afterwards.
    entries = []
    titles = context['title']
    for doc_idx, (title, sentences) in enumerate(zip(titles, context['sentences'])):
        # `in` on the defaultdict does not insert, so this stays False for non-supporting docs
        is_supporting = title in sf_lookup
        # Title of a supporting document is relevant
        entries.append((f"<h2>{title}</h2>", True, is_supporting))
        for sent_idx, sentence in enumerate(sentences):
            sentence = sentence.strip()
            if not sentence:
                continue
            is_relevant = is_supporting and sent_idx in sf_lookup[title]
            entries.append((f"<p>{sentence}</p>", True, is_relevant))
        # Inject noise between documents sometimes. The RNG is drawn even for the
        # last document, which keeps the random sequence identical across versions.
        if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1:
            entries.append((random.choice(NOISE_BLOCKS), False, False))

    if inject_noise:
        # Prefix is sampled before suffix so the RNG draw order stays stable.
        prefix = [(noise, False, False) for noise in _sample_noise_run()]
        suffix = [(noise, False, False) for noise in _sample_noise_run()]
        entries = prefix + entries + suffix

    blocks = []
    relevant_indices = []
    content_indices = []  # All real content (non-noise)
    for idx, (html, is_content, is_relevant) in enumerate(entries, start=1):
        blocks.append(f"[{idx}] {html}")
        if is_content:
            content_indices.append(idx)
        if is_relevant:
            relevant_indices.append(idx)
    return "\n".join(blocks), relevant_indices, content_indices
def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"):
    """Assemble a chat-format query-relevant extraction (QE) example.

    Args:
        question: the user query shown to the model.
        block_text: newline-joined indexed blocks (``[i] <tag>...</tag>``).
        relevant_indices: block indices that answer *question*; an empty list
            produces the ``"NA"`` target.
        url: page URL placed in the user turn.

    Returns:
        A ``[system, user, assistant]`` list of ``{"role", "content"}`` dicts. The
        assistant turn holds the interval string from ``indices_to_intervals``.
    """
    answer = indices_to_intervals(relevant_indices)
    turns = (
        ("system", SYSTEM_PROMPT_QE),
        ("user", f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."),
        ("assistant", answer),
    )
    return [{"role": role, "content": content} for role, content in turns]
def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"):
    """Assemble a chat-format main-content extraction (ME) example.

    Args:
        block_text: newline-joined indexed blocks (``[i] <tag>...</tag>``).
        content_indices: indices of the non-noise (main content) blocks; an empty
            list produces the ``"NA"`` target.
        title: page title placed in the user turn.
        url: page URL placed in the user turn.

    Returns:
        A ``[system, user, assistant]`` list of ``{"role", "content"}`` dicts. The
        assistant turn holds the interval string from ``indices_to_intervals``.
    """
    answer = indices_to_intervals(content_indices)
    turns = (
        ("system", SYSTEM_PROMPT_ME),
        ("user", f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."),
        ("assistant", answer),
    )
    return [{"role": role, "content": content} for role, content in turns]
def process_hotpotqa(num_samples=15000, me_probability=0.5, max_logged_errors=5):
    """Build query-relevant (QE) and main-content (ME) examples from HotpotQA train.

    The function does the following:
    - Loads the ``distractor`` config of ``hotpotqa/hotpot_qa`` (train split).
    - Shuffles it with seed 42 and keeps up to *num_samples* rows.
    - Emits one QE example per row that has at least one supporting block.
    - With probability *me_probability*, also emits an ME example built from the
      same indexed blocks.
    Rows that raise are skipped. The first *max_logged_errors* exceptions are printed.

    Args:
        num_samples: maximum number of rows to process (capped at the split size).
        me_probability: chance of also emitting an ME example for a kept row.
        max_logged_errors: how many per-row exceptions to print.

    Returns:
        A list of dicts with ``messages``, ``task_type`` and ``source`` keys.
    """
    print("Loading HotpotQA...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")
    # Sample a manageable amount
    num_samples = min(num_samples, len(ds))
    ds = ds.shuffle(seed=42).select(range(num_samples))
    all_examples = []
    skipped = 0
    errors = 0  # tracked separately so "no supporting facts" skips don't mute error logs
    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f"Processing {i}/{num_samples}...")
        try:
            block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'], row['supporting_facts'], inject_noise=True
            )
            # A QE target needs at least one supporting block
            if not relevant_indices:
                skipped += 1
                continue
            # Build every example for this row before appending any, so a failure
            # part-way through cannot leave a half-processed row in the output.
            row_examples = [{
                "messages": build_query_relevant_example(row['question'], block_text, relevant_indices),
                "task_type": "query_relevant",
                "source": "hotpotqa",
            }]
            if random.random() < me_probability:
                titles = row['context']['title']
                row_examples.append({
                    "messages": build_main_content_example(
                        block_text, content_indices,
                        title=titles[0] if titles else "Article",
                    ),
                    "task_type": "main_content",
                    "source": "hotpotqa",
                })
        except Exception as e:
            # Best-effort over a large dataset: skip the row but surface the first few errors.
            skipped += 1
            errors += 1
            if errors <= max_logged_errors:
                print(f"Error on row {i}: {e}")
            continue
        all_examples.extend(row_examples)
    print(f"Created {len(all_examples)} examples from HotpotQA ({skipped} skipped)")
    return all_examples
def _build_synthetic_page_example(row):
    """Build one synthetic-page ME example from a HotpotQA row, or None if it has no context."""
    titles = row['context']['title']
    sentences_list = row['context']['sentences']
    if not titles or not sentences_list:
        return None
    blocks = []
    content_indices = []

    def add_block(html, is_content=False):
        # Blocks are numbered consecutively from 1 in insertion order.
        idx = len(blocks) + 1
        blocks.append(f"[{idx}] {html}")
        if is_content:
            content_indices.append(idx)

    # Header noise (nav, etc.)
    for _ in range(random.randint(1, 4)):
        add_block(random.choice(NOISE_BLOCKS))
    # Page title
    main_title = titles[0]
    add_block(f"<h1>{main_title}</h1>", is_content=True)
    # Main content (just first 1-3 documents); document 0 uses the <h1> as its heading
    num_docs = min(random.randint(1, 3), len(titles))
    for doc_idx in range(num_docs):
        sents = sentences_list[doc_idx]
        if doc_idx > 0:
            add_block(f"<h2>{titles[doc_idx]}</h2>", is_content=True)
        for sent in sents:
            sent = sent.strip()
            if not sent:
                continue
            add_block(f"<p>{sent}</p>", is_content=True)
            # Occasional inline noise
            if random.random() < 0.3:
                add_block(random.choice(NOISE_BLOCKS))
    # Footer noise
    for _ in range(random.randint(1, 4)):
        add_block(random.choice(NOISE_BLOCKS))
    me_messages = build_main_content_example(
        "\n".join(blocks), content_indices,
        title=main_title,
        url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}",
    )
    return {
        "messages": me_messages,
        "task_type": "main_content",
        "source": "synthetic",
    }


def create_synthetic_web_pages(num_pages=3000, max_logged_errors=5):
    """Create synthetic web page examples for main content extraction training.

    The function does the following:
    - Takes up to *num_pages* rows from a seed-123 shuffle of the HotpotQA
      validation split.
    - Lays out each row as a page: 1-4 header noise blocks, an ``<h1>`` title, and
      the first 1-3 documents with occasional inline noise, then 1-4 footer noise blocks.
    - Turns each page into an ME example.
    Rows that raise are skipped. The first *max_logged_errors* exceptions are printed.

    Args:
        num_pages: maximum number of rows to turn into pages (capped at the split size).
        max_logged_errors: how many per-row exceptions to print.

    Returns:
        A list of dicts with ``messages``, ``task_type`` and ``source`` keys.
    """
    print("Creating synthetic web page examples...")
    # Load a text dataset to get content
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    num_pages = min(num_pages, len(ds))
    ds = ds.shuffle(seed=123).select(range(num_pages))
    examples = []
    errors = 0
    for i, row in enumerate(ds):
        if i % 500 == 0:
            print(f"Synthetic page {i}/{num_pages}...")
        try:
            example = _build_synthetic_page_example(row)
        except Exception as e:
            # Best-effort: skip malformed rows, but make the failures visible.
            errors += 1
            if errors <= max_logged_errors:
                print(f"Error on synthetic page {i}: {e}")
            continue
        if example is not None:
            examples.append(example)
    print(f"Created {len(examples)} synthetic web page examples")
    return examples
def create_na_examples(num_candidates=1000, num_keep=300, max_logged_errors=5):
    """Create examples where no relevant content exists (model should output 'NA').

    The function does the following:
    - Takes up to *num_candidates* rows from a seed-456 shuffle of the HotpotQA
      validation split.
    - Builds each row's context into noisy indexed blocks.
    - Pairs the blocks with the question of the row half-way around the sample.
    - Emits a QE example whose target is 'NA'.
    The examples are then shuffled and only *num_keep* of them are kept.

    Args:
        num_candidates: how many rows to draw from the shuffled split.
        num_keep: how many NA examples to keep after shuffling.
        max_logged_errors: how many per-row exceptions to print.

    Returns:
        A list of dicts with ``messages``, ``task_type`` and ``source`` keys.
    """
    print("Creating NA examples...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=456).select(range(min(num_candidates, len(ds))))
    # Read the question column once rather than fetching a whole row per lookup.
    questions = list(ds['question'])
    n = len(questions)
    examples = []
    errors = 0
    for i, row in enumerate(ds):
        # Use context from one question but query from another (mismatched)
        other_idx = (i + n // 2) % n
        if other_idx == i:
            continue  # fewer than two rows: no distinct question to pair with
        try:
            # Empty supporting facts: only the block text is needed here
            block_text, _, _ = create_indexed_blocks_from_hotpotqa(
                row['context'], {'title': [], 'sent_id': []}, inject_noise=True
            )
        except Exception as e:
            errors += 1
            if errors <= max_logged_errors:
                print(f"Error on NA row {i}: {e}")
            continue
        # NOTE: pairs are not checked for topical overlap, so an occasional 'NA'
        # target may be wrong if the other question is answerable from this context.
        # An empty relevant-index list makes the builder emit the 'NA' target.
        examples.append({
            "messages": build_query_relevant_example(questions[other_idx], block_text, []),
            "task_type": "query_relevant_na",
            "source": "hotpotqa_mismatched",
        })
    # Keep only a fraction (the paper mentions partial filtering of NA)
    random.shuffle(examples)
    examples = examples[:num_keep]
    print(f"Created {len(examples)} NA examples")
    return examples
def main():
    """Build the full IndexLM SFT dataset, then filter, split, save and publish it.

    Steps:
    - Gather QE/ME examples from HotpotQA, synthetic ME pages and mismatched NA
      examples, then shuffle them.
    - Report per-task counts and token-length stats using the Qwen/Qwen3-0.6B chat
      template.
    - Drop examples longer than MAX_LEN tokens and split off an eval set.
    - Save both splits under /app and push them to the Hub as
      ``OmAlve/indexlm-training-data``.
    """
    # Build all training examples
    qe_examples = process_hotpotqa()
    me_examples = create_synthetic_web_pages()
    na_examples = create_na_examples()
    all_examples = qe_examples + me_examples + na_examples
    random.shuffle(all_examples)
    print(f"\nTotal examples: {len(all_examples)}")
    # Count by type
    type_counts = defaultdict(int)
    for ex in all_examples:
        type_counts[ex['task_type']] += 1
    for t, c in type_counts.items():
        print(f" {t}: {c}")

    # Tokenize every example exactly once; the counts feed both the stats and the filter.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    token_counts = []
    for ex in all_examples:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        token_counts.append(len(tokenizer.encode(text)))

    # Check lengths (min/max would raise on an empty list, so guard it)
    lengths = token_counts[:500]
    print(f"\nToken length stats (sample of 500):")
    if lengths:
        print(f" Min: {min(lengths)}")
        print(f" Max: {max(lengths)}")
        print(f" Mean: {sum(lengths)/len(lengths):.0f}")
        print(f" Median: {sorted(lengths)[len(lengths)//2]}")

    # Filter out examples that are too long (>4096 tokens for efficiency)
    MAX_LEN = 4096
    filtered = [ex for ex, count in zip(all_examples, token_counts) if count <= MAX_LEN]
    too_long = len(all_examples) - len(filtered)
    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset: {len(filtered)} examples")

    # Split into train/eval. Slice at an explicit split point: with eval_size == 0,
    # filtered[:-0] would be empty and filtered[-0:] would be the whole dataset.
    random.shuffle(filtered)
    eval_size = min(500, len(filtered) // 10)
    split_at = len(filtered) - eval_size
    train_data = filtered[:split_at]
    eval_data = filtered[split_at:]
    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")

    # Create HF dataset with just messages column (for SFTTrainer)
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])
    # Save locally
    train_ds.save_to_disk("/app/indexlm_train")
    eval_ds.save_to_disk("/app/indexlm_eval")

    # Also push to HF Hub
    from huggingface_hub import login
    import os
    login(token=os.environ.get("HF_TOKEN"))
    from datasets import DatasetDict
    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data")
    print("\nDone! Dataset pushed to OmAlve/indexlm-training-data")

    # Print sample
    print("\n=== Sample QE example ===")
    for ex in train_data[:3]:
        if ex.get("task_type", "") == "query_relevant":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break
    print("\n=== Sample ME example ===")
    for ex in train_data[:10]:
        if ex.get("task_type", "") == "main_content":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break


if __name__ == "__main__":
    main()