"""
Prepare IndexLM training data from HotpotQA and MSMARCO.
Pipeline:
1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts)
2. Convert context into indexed HTML-like blocks, each prefixed with its 1-based index in square brackets ("[i] ...")
for regular text blocks.append(f"[{idx}]
{sentence}
") content_indices.append(idx) if title in sf_lookup and sent_idx in sf_lookup[title]: relevant_indices.append(idx) idx += 1 # Inject noise between documents sometimes if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1: noise = random.choice(NOISE_BLOCKS) blocks.append(f"[{idx}] {noise}") idx += 1 # Sometimes add noise at start and end if inject_noise: prefix_noise = [] if random.random() < 0.5: for _ in range(random.randint(1, 3)): noise = random.choice(NOISE_BLOCKS) prefix_noise.append(noise) suffix_noise = [] if random.random() < 0.5: for _ in range(random.randint(1, 3)): noise = random.choice(NOISE_BLOCKS) suffix_noise.append(noise) if prefix_noise or suffix_noise: # Reindex everything new_blocks = [] new_relevant = [] new_content = [] new_idx = 1 # Prefix noise for noise in prefix_noise: new_blocks.append(f"[{new_idx}] {noise}") new_idx += 1 # Remap original blocks offset = len(prefix_noise) for b in blocks: old_idx = int(b.split(']')[0].replace('[', '')) new_b = f"[{old_idx + offset}] " + '] '.join(b.split('] ')[1:]) new_blocks.append(new_b) new_relevant = [r + offset for r in relevant_indices] new_content = [c + offset for c in content_indices] # Suffix noise next_idx = len(new_blocks) + 1 for noise in suffix_noise: new_blocks.append(f"[{next_idx}] {noise}") next_idx += 1 blocks = new_blocks relevant_indices = new_relevant content_indices = new_content block_text = "\n".join(blocks) return block_text, relevant_indices, content_indices def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"): """Build a query-relevant extraction (QE) example.""" intervals = indices_to_intervals(relevant_indices) user_content = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." 
def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"):
    """Build a query-relevant extraction (QE) example.

    Args:
        question: Query text placed in the user turn.
        block_text: Indexed blocks, already joined into a single string.
        relevant_indices: Block indices that answer the query; converted to
            the assistant target via ``indices_to_intervals``.
        url: URL shown in the user turn.

    Returns:
        A list of three chat messages (system, user, assistant).
    """
    intervals = indices_to_intervals(relevant_indices)
    user_content = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
    return [
        {"role": "system", "content": SYSTEM_PROMPT_QE},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals},
    ]


def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"):
    """Build a main content extraction (ME) example.

    Args:
        block_text: Indexed blocks, already joined into a single string.
        content_indices: Indices of the blocks that are main content;
            converted to the assistant target via ``indices_to_intervals``.
        title: Page title shown in the user turn.
        url: URL shown in the user turn.

    Returns:
        A list of three chat messages (system, user, assistant).
    """
    intervals = indices_to_intervals(content_indices)
    user_content = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
    return [
        {"role": "system", "content": SYSTEM_PROMPT_ME},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals},
    ]


def process_hotpotqa():
    """Process HotpotQA into IndexLM training data.

    Loads the "distractor" train split, shuffles it with a fixed seed and
    keeps at most 15000 rows. Every row that has at least one relevant block
    yields a query-relevant example; roughly half of those rows additionally
    yield a main-content example built from the same blocks.

    Returns:
        A list of dicts with keys ``messages``, ``task_type`` and ``source``.
    """
    print("Loading HotpotQA...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")

    # Sample a manageable amount
    num_samples = min(15000, len(ds))
    ds = ds.shuffle(seed=42).select(range(num_samples))

    all_examples = []
    skipped = 0  # rows that produced no example, for any reason
    errors = 0   # subset of `skipped` caused by an exception

    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f"Processing {i}/{num_samples}...")

        try:
            block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'],
                row['supporting_facts'],
                inject_noise=True,
            )

            # Skip if there is nothing for the model to point at
            if not relevant_indices:
                skipped += 1
                continue

            # Query-relevant extraction example
            qe_messages = build_query_relevant_example(
                row['question'],
                block_text,
                relevant_indices,
            )
            all_examples.append({
                "messages": qe_messages,
                "task_type": "query_relevant",
                "source": "hotpotqa",
            })

            # Main content extraction example (50% of the time)
            if random.random() < 0.5:
                titles = row['context']['title']
                me_messages = build_main_content_example(
                    block_text,
                    content_indices,
                    title=titles[0] if titles else "Article",
                )
                all_examples.append({
                    "messages": me_messages,
                    "task_type": "main_content",
                    "source": "hotpotqa",
                })
        except Exception as e:
            # Best-effort: one malformed row must not abort the whole run.
            skipped += 1
            errors += 1
            # Report only the first few errors. The limit is tracked on its
            # own counter so that rows skipped for having no relevant blocks
            # cannot silence error reporting.
            if errors <= 5:
                print(f"Error on row {i}: {e}")
            continue

    print(f"Created {len(all_examples)} examples from HotpotQA ({skipped} skipped)")
    return all_examples
{len(all_examples)} examples from HotpotQA ({skipped} skipped)") return all_examples def create_synthetic_web_pages(): """Create synthetic web page examples for main content extraction training.""" print("Creating synthetic web page examples...") # Load a text dataset to get content ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation") ds = ds.shuffle(seed=123).select(range(3000)) examples = [] for i, row in enumerate(ds): if i % 500 == 0: print(f"Synthetic page {i}/3000...") try: # Build a more realistic web page structure titles = row['context']['title'] sentences_list = row['context']['sentences'] if not titles or not sentences_list: continue blocks = [] content_indices = [] idx = 1 # Header noise (nav, etc.) num_header_noise = random.randint(1, 4) for _ in range(num_header_noise): blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") idx += 1 # Page title main_title = titles[0] blocks.append(f"[{idx}]{sent}
") content_indices.append(idx) idx += 1 # Occasional inline noise if random.random() < 0.3: blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") idx += 1 # Footer noise num_footer_noise = random.randint(1, 4) for _ in range(num_footer_noise): blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}") idx += 1 block_text = "\n".join(blocks) me_messages = build_main_content_example( block_text, content_indices, title=main_title, url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}" ) examples.append({ "messages": me_messages, "task_type": "main_content", "source": "synthetic" }) except Exception as e: continue print(f"Created {len(examples)} synthetic web page examples") return examples def create_na_examples(): """Create examples where no relevant content exists (model should output 'NA').""" print("Creating NA examples...") ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation") ds = ds.shuffle(seed=456).select(range(1000)) examples = [] for i, row in enumerate(ds): try: # Use context from one question but query from another (mismatched) other_idx = (i + 500) % len(ds) other_question = ds[other_idx]['question'] # Build blocks from current context but keep only non-supporting content block_text, _, content_indices = create_indexed_blocks_from_hotpotqa( row['context'], {'title': [], 'sent_id': []}, inject_noise=True ) # The query doesn't match this content → expected output: NA # But actually some content might still be tangentially relevant, # so we'll be conservative and only do this for clearly mismatched pairs user_content = f"URL: https://en.wikipedia.org\nQuery: {other_question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query." 
def create_na_examples():
    """Create examples where no relevant content exists (model should output 'NA').

    Takes 1000 shuffled rows of the HotpotQA "distractor" validation split and
    pairs the blocks built from each row's context with the question of the
    row 500 positions further on (wrapping around). The assistant target is
    the literal string "NA". At most 300 examples are kept after shuffling.

    Returns:
        A list of dicts with keys ``messages``, ``task_type`` and ``source``.
    """
    print("Creating NA examples...")

    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=456).select(range(1000))

    examples = []
    failed = 0

    for i, row in enumerate(ds):
        try:
            # Use context from one question but query from another (mismatched)
            other_idx = (i + 500) % len(ds)
            other_question = ds[other_idx]['question']

            # Build blocks from the current context with an empty
            # supporting-facts table, so no block is marked relevant.
            block_text, _, _ = create_indexed_blocks_from_hotpotqa(
                row['context'],
                {'title': [], 'sent_id': []},
                inject_noise=True,
            )
        except Exception as e:
            # Best-effort: skip rows that cannot be converted, but do not
            # swallow KeyboardInterrupt/SystemExit and do not fail silently.
            failed += 1
            if failed <= 5:
                print(f"Error on NA row {i}: {e}")
            continue

        # The query comes from a different row, so the expected output is NA.
        # NOTE: no relevance check is performed here; a mismatched question
        # could still be tangentially related to this context.
        user_content = f"URL: https://en.wikipedia.org\nQuery: {other_question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT_QE},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": "NA"},
        ]
        examples.append({
            "messages": messages,
            "task_type": "query_relevant_na",
            "source": "hotpotqa_mismatched",
        })

    # Keep only a fraction (the paper mentions partial filtering of NA)
    random.shuffle(examples)
    examples = examples[:300]

    print(f"Created {len(examples)} NA examples")
    return examples


def _count_tokens(tokenizer, messages):
    """Return the token count of `messages` rendered with the chat template."""
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return len(tokenizer.encode(text))


def main():
    """Build, filter, split, save and publish the IndexLM training data.

    Steps: build the three example sets, shuffle them together, print
    per-task counts and token-length statistics, drop examples longer than
    4096 tokens, split off an eval set (at most 500 examples, at most 10%),
    save both splits to disk, push them to the Hugging Face Hub and print a
    couple of samples.
    """
    # Build all training examples
    qe_examples = process_hotpotqa()
    me_examples = create_synthetic_web_pages()
    na_examples = create_na_examples()

    all_examples = qe_examples + me_examples + na_examples
    random.shuffle(all_examples)

    print(f"\nTotal examples: {len(all_examples)}")

    # Count by type
    type_counts = defaultdict(int)
    for ex in all_examples:
        type_counts[ex['task_type']] += 1
    for t, c in type_counts.items():
        print(f" {t}: {c}")

    # Check lengths
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

    # Tokenize every example exactly once; the statistics and the length
    # filter below both reuse these counts.
    token_lengths = [_count_tokens(tokenizer, ex['messages']) for ex in all_examples]

    lengths = token_lengths[:500]
    print(f"\nToken length stats (sample of 500):")
    if lengths:  # min()/mean would raise on an empty sample
        print(f" Min: {min(lengths)}")
        print(f" Max: {max(lengths)}")
        print(f" Mean: {sum(lengths)/len(lengths):.0f}")
        print(f" Median: {sorted(lengths)[len(lengths)//2]}")

    # Filter out examples that are too long (>4096 tokens for efficiency)
    MAX_LEN = 4096
    filtered = [
        ex for ex, n_tokens in zip(all_examples, token_lengths)
        if n_tokens <= MAX_LEN
    ]
    too_long = len(all_examples) - len(filtered)

    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset: {len(filtered)} examples")

    # Split into train/eval
    random.shuffle(filtered)
    eval_size = min(500, len(filtered) // 10)
    # Use an explicit split index: with eval_size == 0, `filtered[:-0]` is
    # empty and `filtered[-0:]` is the whole list, which would put every
    # example into the eval set.
    split_at = len(filtered) - eval_size
    train_data = filtered[:split_at]
    eval_data = filtered[split_at:]

    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")

    # Create HF dataset with just messages column (for SFTTrainer)
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])

    # Save locally
    train_ds.save_to_disk("/app/indexlm_train")
    eval_ds.save_to_disk("/app/indexlm_eval")

    # Also push to HF Hub
    from huggingface_hub import login
    import os
    # NOTE(review): if HF_TOKEN is unset this passes token=None; confirm how
    # `login` behaves in that case for non-interactive runs.
    login(token=os.environ.get("HF_TOKEN"))

    from datasets import DatasetDict
    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data")

    print("\nDone! Dataset pushed to OmAlve/indexlm-training-data")

    # Print sample: first query-relevant example among the first 3 train rows
    print("\n=== Sample QE example ===")
    for ex in train_data[:3]:
        if ex.get("task_type", "") == "query_relevant":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break

    # First main-content example among the first 10 train rows
    print("\n=== Sample ME example ===")
    for ex in train_data[:10]:
        if ex.get("task_type", "") == "main_content":
            for m in ex["messages"]:
                print(f"\n[{m['role']}]: {m['content'][:200]}...")
            break


if __name__ == "__main__":
    main()