"""
Prepare DIVERSE IndexLM training data from multiple sources:
1. HtmlRAG-train (real Bing-scraped web HTML) — diverse domains
2. MultiHopRAG (news domain) — technology, business, sports, entertainment
3. HotpotQA (Wikipedia) — structured QA with supporting facts
4. MS MARCO (diverse web QA) — real Bing queries with selected passages
This avoids the Wikipedia-only bias of the original dataset.
Output: Conversational messages for SFT with TRL SFTTrainer
Format: system + user (indexed HTML blocks + query) → assistant (index intervals)
"""
import json
import random
import re
import os
from datasets import load_dataset, Dataset, DatasetDict
from collections import defaultdict
from bs4 import BeautifulSoup
import html as html_lib
random.seed(42)
# ============ System Prompts ============
SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query.
Each block is formatted as: [i] content
Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive).
If no relevant content exists, output 'NA'.
Example output: [[2,4],[7,7],[10,12]]"""
SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements).
Each block is formatted as: [i] content
Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive).
If no main content exists, output 'NA'.
Example output: [[1,3],[5,8],[11,15]]"""
# ============ Noise blocks for injection ============
NOISE_BLOCKS_REALISTIC = [
    '<nav>Home | About | Contact | Privacy Policy | Terms of Service</nav>',
    '<div class="cookie-banner">This website uses cookies to improve your experience. By continuing to use this site, you consent to our use of cookies. Accept | Manage Preferences</div>',
    '<div class="newsletter">Subscribe to our newsletter for the latest updates delivered to your inbox weekly.</div>',
    '<div class="ad-slot">Advertisement</div>',
    '<div class="paywall">Already a subscriber? Log in for full access. Not a member? Subscribe now starting at $4.99/month.</div>',
    '<div class="video-embed">Watch: Video player requires JavaScript to be enabled. [Video placeholder]</div>',
    '<div class="ticker">BREAKING: Markets rally on latest economic data | Sports: Championship results | Weather: Storm warning issued</div>',
    '<div class="consent">We value your privacy. We and our partners use tracking technologies to improve your browsing experience, serve personalized content, and analyze traffic.</div>',
    '<div class="app-banner">Download our app for a better reading experience! Available on iOS and Android.</div>',
    '<div class="share-bar">Share on Twitter | Share on Facebook | Share via email</div>',
    '<div class="print-promo">This article is available in print edition. Subscribe for home delivery.</div>',
    '<footer>Copyright © 2024. All rights reserved.</footer>',
]

def split_html_into_blocks(html_content):
    """Split raw HTML into block-level chunks on tags like <div>, <p>, <li>, <td>, etc."""
    blocks = []
# Split at positions where block-level tags start
block_tag_pattern = r'(<(?:div|p|h[1-6]|li|ul|ol|table|tr|td|th|article|section|header|footer|nav|aside|main|blockquote|pre|form|figure|figcaption|details|summary|option|title|button|label|select|textarea|hgroup|dl|dd|dt|caption|thead|tbody|tfoot)\b[^>]*>)'
    # Also handle HtmlRAG numbered tags like <div1>, <p2>, etc.
block_tag_pattern_numbered = r'(<(?:div|p|h|li|ul|ol|table|tr|td|th|article|section|header|footer|nav|aside|main|blockquote|pre|form|figure|option|title|button|hgroup)\d*[^>]*>)'
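    # Note: the plain pattern only matches standard tags such as '<div class="x">'
    # (the \b rejects counter digits); the numbered variant (bare 'h', optional
    # trailing digits, no \b) also matches HtmlRAG's counter-suffixed tags such as
    # '<div12>' or '<h20>', which is why it is the one used for splitting below.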
# Split content by block-level tags
parts = re.split(block_tag_pattern_numbered, html_content)
current_block = ''
for part in parts:
part = part.strip()
if not part:
continue
# Check if this part is a block-level opening tag
if re.match(block_tag_pattern_numbered, part):
# Save previous block if it has content
if current_block.strip():
blocks.append(current_block.strip())
current_block = part
else:
current_block += ' ' + part
# Don't forget the last block
if current_block.strip():
blocks.append(current_block.strip())
# If tag-based splitting yields too few blocks, fall back to line-based
if len(blocks) < 5:
blocks = []
lines = html_content.split('\n')
for line in lines:
line = line.strip()
if line and len(line) > 5:
blocks.append(line)
# If still too few, split by multiple tags on same line
if len(blocks) < 5:
new_blocks = []
for block in blocks:
# Try splitting long blocks by inner tags
if len(block) > 200:
                inner_parts = re.split(r'(</(?:div|p|h[1-6]|li|td|th|article|section)\d*>)', block)
current = ''
for ip in inner_parts:
current += ip
                    if re.match(r'</(?:div|p|h[1-6]|li|td|th|article|section)\d*>', ip):
if current.strip():
new_blocks.append(current.strip())
current = ''
if current.strip():
new_blocks.append(current.strip())
else:
new_blocks.append(block)
if len(new_blocks) > len(blocks):
blocks = new_blocks
# Filter: extract text and remove blocks with no meaningful content
def extract_text_simple(s):
clean = re.sub(r'<[^>]+>', ' ', s)
return re.sub(r'\s+', ' ', clean).strip()
blocks = [b for b in blocks if len(extract_text_simple(b)) > 5]
return blocks
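# --- Shared helpers ---------------------------------------------------------
# Minimal sketches of the two helpers used by all source builders below; the
# interval string format follows the system-prompt examples, and the text
# extraction mirrors extract_text_simple() with entity unescaping added.
def indices_to_intervals(indices):
    """Collapse block indices into inclusive [start, end] intervals rendered as
    a string, e.g. [3, 4, 5, 9] -> '[[3,5],[9,9]]'; empty input -> 'NA'."""
    if not indices:
        return "NA"
    indices = sorted(set(indices))
    intervals = []
    start = prev = indices[0]
    for i in indices[1:]:
        if i == prev + 1:
            prev = i
        else:
            intervals.append([start, prev])
            start = prev = i
    intervals.append([start, prev])
    return json.dumps(intervals, separators=(',', ':'))

def extract_text_content(block):
    """Extract visible text from an HTML block: strip tags, unescape entities,
    collapse whitespace."""
    text = BeautifulSoup(block, 'html.parser').get_text(separator=' ')
    text = html_lib.unescape(text)  # defensive: handles double-escaped entities
    return re.sub(r'\s+', ' ', text).strip()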
def classify_block_as_noise(block_text):
"""Heuristic: classify if a block is likely noise (nav, ad, etc.)."""
text_lower = block_text.lower()
noise_indicators = [
'cookie', 'privacy policy', 'terms of service', 'advertisement',
'subscribe', 'newsletter', 'sign up', 'log in', 'login',
'copyright ©', 'all rights reserved', 'skip to', 'accessibility',
'share on twitter', 'share on facebook', 'social media',
'related articles', 'you may also like', 'trending now',
'app download', 'sponsored content', 'affiliate',
]
    nav_patterns = ['<nav', '<header', '<footer', '<aside']
    if any(p in text_lower for p in nav_patterns):
        return True
    return any(ind in text_lower for ind in noise_indicators)
# ============================================================
# SOURCE 1: HtmlRAG-train (real web HTML)
# ============================================================
def process_htmlrag_example(row):
    """Convert one HtmlRAG-train row into IndexLM training examples.

    NOTE: the field names read from `row` below are assumptions about the
    HtmlRAG-train schema; adjust them if the actual column names differ.
    """
    question = row.get('question', '')
    html_content = row.get('html', '')
    relevant_text = row.get('answer', '')
    if not question or not html_content or not relevant_text:
        return None
    relevant_words = set(relevant_text.lower().split())
    blocks = split_html_into_blocks(html_content)
    indexed_blocks = []
    content_indices = []
    relevant_indices = []
    for idx, block in enumerate(blocks, 1):
        # Recover the tag name; the h[1-6] alternative preserves heading levels,
        # and HtmlRAG counter digits stay out of the capture (div1 -> div, h20 -> h2)
        tag_match = re.match(r'<(h[1-6]|[a-z]+)', block)
        if tag_match:
            tag = tag_match.group(1)
            if not re.fullmatch(r'h[1-6]', tag):
                tag = re.sub(r'\d+$', '', tag)  # defensive counter strip
            if not tag:
                tag = 'div'
        else:
            tag = 'p'
        text = extract_text_content(block)
        if not text or len(text) < 3:
            continue
        indexed_blocks.append(f"[{idx}] <{tag}>{text}</{tag}>")
# Check if this block is noise
is_noise = classify_block_as_noise(block)
if not is_noise:
content_indices.append(idx)
# Check relevance by substring matching with assistant output
# Use the full relevant text as a search target
text_lower = text.lower()
relevant_lower = relevant_text.lower()
# Method 1: Check if significant portions of relevant text appear in block
# Split relevant text into 3-word ngrams and check for matches
rel_words_list = relevant_lower.split()
matched = False
# Check 3-gram overlap
for i in range(len(rel_words_list) - 2):
trigram = ' '.join(rel_words_list[i:i+3])
if trigram in text_lower:
matched = True
break
# Also check: does the block text appear as a substring in the relevant text?
if not matched and len(text) > 15:
# Check if meaningful portion of block appears in relevant output
block_sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 10]
for sent in block_sentences:
if sent.lower() in relevant_lower:
matched = True
break
# Also check word overlap with a more lenient threshold
if not matched:
block_words = set(text_lower.split())
if relevant_words and block_words:
overlap_count = len(block_words & relevant_words)
                # At least 2 content words overlap (excluding stopwords)
stopwords = {'the','a','an','is','are','was','were','in','on','at','to','for','of','and','or','but','with','by','from','as','it','this','that','be','has','have','had','do','does','did','not','no'}
content_overlap = len((block_words - stopwords) & (relevant_words - stopwords))
if content_overlap >= 2:
matched = True
if matched:
relevant_indices.append(idx)
if not indexed_blocks or len(indexed_blocks) < 3:
return None
block_text = "\n".join(indexed_blocks)
results = []
# Query-relevant extraction example
if relevant_indices:
intervals = indices_to_intervals(relevant_indices)
user_msg = f"URL: https://example.com\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
results.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_QE},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": intervals}
],
"task_type": "query_relevant",
"source": "htmlrag"
})
# Main content extraction example (30% of the time to balance)
if content_indices and random.random() < 0.3:
intervals = indices_to_intervals(content_indices)
user_msg = f"URL: https://example.com\nTitle: Web Page\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
results.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_ME},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": intervals}
],
"task_type": "main_content",
"source": "htmlrag"
})
return results
def load_htmlrag_data():
"""Load and convert HtmlRAG-train data."""
print("Loading HtmlRAG-train (real web HTML)...")
# Use 4k and 8k token variants - good balance of context
files = [
'nq-4k.jsonl', 'nq-8k.jsonl',
'asqa-4k.jsonl', 'asqa-8k.jsonl',
'trivia-qa-4k.jsonl', 'trivia-qa-8k.jsonl',
'musique-4k.jsonl', 'musique-8k.jsonl',
'hotpot-qa-4k.jsonl', 'hotpot-qa-8k.jsonl',
]
all_examples = []
for file in files:
print(f" Processing {file}...")
try:
ds = load_dataset('zstanjj/HtmlRAG-train', data_files=file, split='train')
count = 0
for row in ds:
results = process_htmlrag_example(row)
if results:
all_examples.extend(results)
count += len(results)
print(f" Got {count} examples from {file}")
except Exception as e:
print(f" Error loading {file}: {e}")
print(f" Total HtmlRAG examples: {len(all_examples)}")
return all_examples
# ============================================================
# SOURCE 2: MultiHopRAG (News domain)
# ============================================================
def process_multihoprag():
"""Convert MultiHopRAG news articles into IndexLM format."""
print("Loading MultiHopRAG (news domain)...")
corpus = load_dataset("yixuantt/MultiHopRAG", name="corpus", split="train")
queries = load_dataset("yixuantt/MultiHopRAG", name="MultiHopRAG", split="train")
# Build URL->article lookup
url_to_article = {}
for article in corpus:
url_to_article[article['url']] = article
all_examples = []
for q_row in queries:
query = q_row['query']
evidence_list = q_row['evidence_list']
for evidence in evidence_list:
url = evidence.get('url', '')
fact = evidence.get('fact', '')
if url not in url_to_article or not fact:
continue
article = url_to_article[url]
title = article.get('title', 'News Article')
body = article.get('body', '')
source = article.get('source', 'Unknown')
category = article.get('category', 'general')
if not body or len(body) < 100:
continue
# Split article body into paragraphs
paragraphs = [p.strip() for p in body.split('\n') if p.strip() and len(p.strip()) > 20]
if not paragraphs:
continue
# Build indexed blocks with realistic web structure
blocks = []
content_indices = []
relevant_indices = []
idx = 1
# Add realistic header noise
num_header = random.randint(1, 3)
for _ in range(num_header):
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
# Article title
blocks.append(f"[{idx}] {title} ")
content_indices.append(idx)
idx += 1
# Author/date line
author = article.get('author', 'Staff Writer')
published = article.get('published_at', '2024-01-01')
blocks.append(f"[{idx}] By {author} | {source} | {published} | Category: {category}
")
content_indices.append(idx)
idx += 1
# Article paragraphs
fact_words = set(fact.lower().split())
for para in paragraphs:
# Determine tag
if len(para) < 60 and not para.endswith('.'):
tag = 'h2'
else:
tag = 'p'
blocks.append(f"[{idx}] <{tag}>{para}{tag}>")
content_indices.append(idx)
# Check if paragraph contains the evidence fact
para_words = set(para.lower().split())
overlap = len(para_words & fact_words)
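                # Relevant if the paragraph shares many words with the fact in
                # absolute terms (>5), or more than 30% of the fact's words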
if overlap > 5 or (fact_words and overlap / len(fact_words) > 0.3):
relevant_indices.append(idx)
idx += 1
# Occasional mid-article noise
if random.random() < 0.15:
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
# Footer noise
num_footer = random.randint(1, 4)
for _ in range(num_footer):
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
block_text = "\n".join(blocks)
# Query-relevant extraction
if relevant_indices:
intervals = indices_to_intervals(relevant_indices)
user_msg = f"URL: {url}\nQuery: {query}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
all_examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_QE},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": intervals}
],
"task_type": "query_relevant",
"source": "multihoprag_news"
})
# Main content extraction
if content_indices and random.random() < 0.4:
intervals = indices_to_intervals(content_indices)
user_msg = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
all_examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_ME},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": intervals}
],
"task_type": "main_content",
"source": "multihoprag_news"
})
print(f" Total MultiHopRAG examples: {len(all_examples)}")
return all_examples
# ============================================================
# SOURCE 3: HotpotQA (Wikipedia - but balanced as minority)
# ============================================================
def process_hotpotqa():
"""Process HotpotQA — kept but as a smaller proportion."""
print("Loading HotpotQA (Wikipedia domain)...")
ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")
# Reduced from 15K to 5K — wiki should be minority source
num_samples = min(5000, len(ds))
ds = ds.shuffle(seed=42).select(range(num_samples))
all_examples = []
skipped = 0
for i, row in enumerate(ds):
if i % 1000 == 0:
print(f" Processing {i}/{num_samples}...")
try:
titles = row['context']['title']
sentences_list = row['context']['sentences']
sf = row['supporting_facts']
sf_lookup = defaultdict(set)
for title, sent_id in zip(sf['title'], sf['sent_id']):
sf_lookup[title].add(sent_id)
blocks = []
relevant_indices = []
content_indices = []
idx = 1
# Header noise
if random.random() < 0.6:
for _ in range(random.randint(1, 3)):
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)):
blocks.append(f"[{idx}] {title} ")
content_indices.append(idx)
if title in sf_lookup:
relevant_indices.append(idx)
idx += 1
for sent_idx, sentence in enumerate(sentences):
sentence = sentence.strip()
if not sentence:
continue
blocks.append(f"[{idx}] {sentence}
")
content_indices.append(idx)
if title in sf_lookup and sent_idx in sf_lookup[title]:
relevant_indices.append(idx)
idx += 1
if random.random() < 0.3 and doc_idx < len(titles) - 1:
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
# Footer noise
if random.random() < 0.6:
for _ in range(random.randint(1, 3)):
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
if len(relevant_indices) < 1:
skipped += 1
continue
block_text = "\n".join(blocks)
# QE example
intervals = indices_to_intervals(relevant_indices)
user_msg = f"URL: https://en.wikipedia.org\nQuery: {row['question']}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
all_examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_QE},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": intervals}
],
"task_type": "query_relevant",
"source": "hotpotqa_wiki"
})
# ME example (less frequent - wiki is minority)
if random.random() < 0.3:
intervals = indices_to_intervals(content_indices)
user_msg = f"URL: https://en.wikipedia.org\nTitle: {titles[0]}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
all_examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_ME},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": intervals}
],
"task_type": "main_content",
"source": "hotpotqa_wiki"
})
except Exception as e:
skipped += 1
continue
print(f" Total HotpotQA examples: {len(all_examples)} ({skipped} skipped)")
return all_examples
# ============================================================
# SOURCE 4: MS MARCO (Diverse web QA)
# ============================================================
def process_msmarco():
"""Process MS MARCO for diverse web domain QA examples."""
print("Loading MS MARCO (diverse web QA)...")
try:
ds = load_dataset("microsoft/ms_marco", "v1.1", split="train")
# Sample a manageable subset
num_samples = min(5000, len(ds))
ds = ds.shuffle(seed=99).select(range(num_samples))
except Exception as e:
print(f" Could not load MS MARCO: {e}")
return []
all_examples = []
for i, row in enumerate(ds):
if i % 1000 == 0:
print(f" Processing {i}/{num_samples}...")
try:
query = row['query']
passages = row['passages']
if not passages or not passages.get('passage_text'):
continue
passage_texts = passages['passage_text']
is_selected = passages.get('is_selected', [0] * len(passage_texts))
if not any(is_selected):
continue
# Build blocks from passages (these are real web snippets from Bing)
blocks = []
relevant_indices = []
content_indices = []
idx = 1
# Header noise
if random.random() < 0.5:
for _ in range(random.randint(1, 2)):
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
for p_idx, (text, selected) in enumerate(zip(passage_texts, is_selected)):
text = text.strip()
if not text:
continue
# Simulate different content types
if p_idx == 0 and random.random() < 0.3:
tag = 'h1'
elif len(text) < 80:
tag = random.choice(['h2', 'h3', 'strong'])
else:
tag = 'p'
blocks.append(f"[{idx}] <{tag}>{text}{tag}>")
content_indices.append(idx)
if selected:
relevant_indices.append(idx)
idx += 1
# Between-passage noise
if random.random() < 0.2:
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
# Footer noise
if random.random() < 0.5:
for _ in range(random.randint(1, 2)):
blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
idx += 1
if not relevant_indices or len(blocks) < 3:
continue
block_text = "\n".join(blocks)
# QE example
intervals = indices_to_intervals(relevant_indices)
query_type = row.get('query_type', 'general')
user_msg = f"URL: https://www.bing.com/search\nQuery: {query}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
all_examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_QE},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": intervals}
],
"task_type": "query_relevant",
"source": f"msmarco_{query_type}"
})
except Exception as e:
continue
print(f" Total MS MARCO examples: {len(all_examples)}")
return all_examples
# ============================================================
# NA Examples (no relevant content)
# ============================================================
def create_na_examples(all_examples):
"""Create NA examples by mismatching queries with pages."""
print("Creating NA examples (mismatched query-page pairs)...")
# Get QE examples
qe_examples = [e for e in all_examples if e['task_type'] == 'query_relevant']
    if len(qe_examples) <= 100:
print(" Too few QE examples for NA generation")
return []
na_examples = []
for i in range(min(500, len(qe_examples) // 5)):
# Pick two random QE examples
idx_a = random.randint(0, len(qe_examples) - 1)
idx_b = (idx_a + random.randint(100, len(qe_examples) - 1)) % len(qe_examples)
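        # (the >=100 offset keeps example B far from A in source order, so an
        # accidental topical match between A's query and B's page is unlikely)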
# Use query from A, blocks from B
msgs_a = qe_examples[idx_a]['messages']
msgs_b = qe_examples[idx_b]['messages']
# Extract query from A
user_a = msgs_a[1]['content']
query_match = re.search(r'Query: (.+?)(\n|$)', user_a)
if not query_match:
continue
query = query_match.group(1).strip()
# Extract blocks from B
user_b = msgs_b[1]['content']
blocks_match = re.search(r'Blocks:\n(.+?)(\n\nOutput)', user_b, re.DOTALL)
if not blocks_match:
continue
blocks = blocks_match.group(1)
user_msg = f"URL: https://example.com\nQuery: {query}\n\nBlocks:\n{blocks}\n\nOutput the index intervals of blocks relevant to the query."
na_examples.append({
"messages": [
{"role": "system", "content": SYSTEM_PROMPT_QE},
{"role": "user", "content": user_msg},
{"role": "assistant", "content": "NA"}
],
"task_type": "query_relevant_na",
"source": "mismatched"
})
print(f" Created {len(na_examples)} NA examples")
return na_examples
# ============================================================
# Main Pipeline
# ============================================================
def main():
print("=" * 60)
print("Building DIVERSE IndexLM Training Data")
print("=" * 60)
# Collect from all sources
htmlrag_examples = load_htmlrag_data() # Real web HTML (primary)
multihoprag_examples = process_multihoprag() # News domain
hotpotqa_examples = process_hotpotqa() # Wikipedia (minority)
msmarco_examples = process_msmarco() # Diverse web QA
# Combine
all_examples = htmlrag_examples + multihoprag_examples + hotpotqa_examples + msmarco_examples
# Add NA examples
na_examples = create_na_examples(all_examples)
all_examples.extend(na_examples)
random.shuffle(all_examples)
# Print composition
print(f"\n{'='*60}")
print(f"Total examples: {len(all_examples)}")
source_counts = defaultdict(int)
type_counts = defaultdict(int)
for ex in all_examples:
source_counts[ex.get('source', 'unknown')] += 1
type_counts[ex['task_type']] += 1
print("\nBy source:")
for s, c in sorted(source_counts.items(), key=lambda x: -x[1]):
pct = c / len(all_examples) * 100
print(f" {s}: {c} ({pct:.1f}%)")
print("\nBy task type:")
for t, c in sorted(type_counts.items(), key=lambda x: -x[1]):
pct = c / len(all_examples) * 100
print(f" {t}: {c} ({pct:.1f}%)")
# Check token lengths
print("\nChecking token lengths...")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
lengths = []
for ex in random.sample(all_examples, min(500, len(all_examples))):
text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
tokens = tokenizer.encode(text)
lengths.append(len(tokens))
print(f"Token length stats (sample of {len(lengths)}):")
print(f" Min: {min(lengths)}, Max: {max(lengths)}")
print(f" Mean: {sum(lengths)/len(lengths):.0f}, Median: {sorted(lengths)[len(lengths)//2]}")
# Filter by length
MAX_LEN = 4096
filtered = []
too_long = 0
for ex in all_examples:
text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
tokens = tokenizer.encode(text)
if len(tokens) <= MAX_LEN:
filtered.append(ex)
else:
too_long += 1
print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
print(f"Final dataset size: {len(filtered)}")
# Final composition
final_source_counts = defaultdict(int)
for ex in filtered:
final_source_counts[ex.get('source', 'unknown')] += 1
print("\nFinal composition by source:")
for s, c in sorted(final_source_counts.items(), key=lambda x: -x[1]):
pct = c / len(filtered) * 100
print(f" {s}: {c} ({pct:.1f}%)")
# Split
random.shuffle(filtered)
    eval_size = max(1, min(500, len(filtered) // 10))
train_data = filtered[:-eval_size]
eval_data = filtered[-eval_size:]
print(f"\nTrain: {len(train_data)}, Eval: {len(eval_data)}")
# Create HF datasets
train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])
# Save locally
train_ds.save_to_disk("/app/indexlm_train_v2")
eval_ds.save_to_disk("/app/indexlm_eval_v2")
# Push to Hub
ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
ds_dict.push_to_hub("OmAlve/indexlm-training-data", token=os.environ.get("HF_TOKEN"))
print(f"\n{'='*60}")
print("Done! Dataset pushed to OmAlve/indexlm-training-data")
print(f"{'='*60}")
if __name__ == "__main__":
main()