# IndexLM-0.6B: Index-based Web Content Extraction

## Project Handoff Document

**Paper**: [An Index-based Approach for Efficient and Effective Web Content Extraction](https://arxiv.org/abs/2512.06641)

**Goal**: Fine-tune a SOTA web content extraction model that runs fast on CPU

**Status**: Dataset prepared & pushed ✅ | Training script ready ✅ | Training NOT yet run ❌

---
## 1. What This Is

The paper introduces **IndexLM**, a model that extracts relevant content from web pages by predicting **index intervals** instead of generating full text. This makes it:

- **10–50× faster** than generative extraction (ReaderLM-v2, Firecrawl, etc.)
- **SOTA on RAG QA** benchmarks (HotpotQA, NQ, TriviaQA, MuSiQue, MultiHopRAG)
- **Tiny**: even the 0.6B version beats every baseline in the table below on average RAG QA F1 and ME F1 (Firecrawl Extract is slightly ahead on QE F1, 29.48 vs. 28.64)

The original IndexLM weights are **not publicly released**. This project replicates the approach.
### How It Works

1. HTML is cleaned and split into indexed blocks: `[1] <h1>Title</h1>`, `[2] <p>Content...</p>`, etc.
2. The model receives these blocks plus a query.
3. It outputs index intervals like `[[2,4],[7,7],[10,12]]`, identifying which blocks are relevant.
4. The selected blocks are reassembled into clean HTML/Markdown.

Two tasks:

- **Query-relevant extraction (QE)**: Extract blocks relevant to a specific query
- **Main content extraction (ME)**: Extract main content, filtering out nav/ads/sidebars
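Steps 3 and 4 can be sketched as follows. This is a minimal illustration, not the paper's code. It assumes the model output is either `NA` or a JSON list of inclusive `[start, end]` pairs, as in the example above. The helper names `parse_intervals_to_indices` and `reassemble_blocks` are hypothetical, and the sketch only reassembles HTML. Markdown conversion is left out.

```python
import json
import re

# Matches one indexed block line such as "[3] <p>Content...</p>".
BLOCK_RE = re.compile(r"^\[(\d+)\]\s(.*)$")


def parse_intervals_to_indices(output: str) -> list[int]:
    """Expand a model output like '[[2,4],[7,7]]' into sorted block indices.

    'NA', an empty string, or anything that is not a list of [start, end]
    pairs yields an empty list.
    """
    output = output.strip()
    if not output or output.upper() == "NA":
        return []
    try:
        intervals = json.loads(output)
        indices = set()
        for start, end in intervals:
            indices.update(range(int(start), int(end) + 1))  # inclusive end
    except (json.JSONDecodeError, TypeError, ValueError):
        return []
    return sorted(indices)


def reassemble_blocks(block_text: str, output: str) -> str:
    """Keep only the blocks whose index falls inside a predicted interval."""
    blocks = {}
    for line in block_text.splitlines():
        match = BLOCK_RE.match(line)
        if match:
            blocks[int(match.group(1))] = match.group(2)
    keep = parse_intervals_to_indices(output)
    return "\n".join(blocks[i] for i in keep if i in blocks)


page = "\n".join([
    "[1] <nav>Home | About</nav>",
    "[2] <h1>Python Programming</h1>",
    "[3] <p>Python is a programming language.</p>",
    "[4] <footer>Copyright 2024</footer>",
])
print(reassemble_blocks(page, "[[2, 3]]"))
```

Only the interval list crosses the model boundary. Everything else is cheap string work, which is where the latency advantage over generative extraction comes from.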
### Paper Results (Table 2 & 3)

| Model | Params | Avg RAG QA F1 | ME F1 | QE F1 | Latency (ME) |
|-------|--------|---------------|-------|-------|--------------|
| **IndexLM-0.6B** | 0.6B | 54.70 | 83.38 | 28.64 | **0.35s** |
| **IndexLM-4B** | 4B | 55.41 | 87.40 | 31.69 | 0.81s |
| ReaderLM-v2 | 1.5B | 46.84 | 68.89 | 13.31 | 11.76s |
| HtmlRAG | - | 47.00 | 48.65 | 8.83 | 7.12s |
| Firecrawl Extract | API | 52.72 | - | 29.48 | 11.33s |
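The speedups behind the latency claim can be checked directly from the table. The sketch below only divides the latencies listed above. For the 0.6B model, the ratios against the three baselines land between roughly 20× and 34×.

```python
# Mean ME latency in seconds, copied from the table above.
LATENCY_S = {
    "IndexLM-0.6B": 0.35,
    "IndexLM-4B": 0.81,
    "ReaderLM-v2": 11.76,
    "HtmlRAG": 7.12,
    "Firecrawl Extract": 11.33,
}


def speedup(model: str, baseline: str) -> float:
    """How many times faster `model` is than `baseline` (ratio of latencies)."""
    return LATENCY_S[baseline] / LATENCY_S[model]


for baseline in ("ReaderLM-v2", "HtmlRAG", "Firecrawl Extract"):
    print(f"IndexLM-0.6B vs {baseline}: {speedup('IndexLM-0.6B', baseline):.1f}x")
```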
---

## 2. What's Been Done

### ✅ Dataset Created & Pushed (v2, Multi-domain)

**Hub**: [`OmAlve/indexlm-training-data`](https://huggingface.co/datasets/OmAlve/indexlm-training-data)

| Split | Rows |
|-------|------|
| train | 21,098 |
| eval | 500 |
**Domain Composition (avoids Wikipedia-only bias):**

| Source | Count | % | Domain |
|--------|-------|---|--------|
| MultiHopRAG | 7,165 | 33.2% | News (Mashable, CNBC, AP, etc.) |
| HotpotQA | 6,479 | 30.0% | Wikipedia |
| HtmlRAG-train | 2,692 | 12.5% | **Real Bing-scraped web HTML** (diverse) |
| MS MARCO | 4,844 | 22.4% | Diverse web (Bing search results) |
| NA (mismatched) | 418 | 1.9% | Cross-domain |
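A quick consistency check on the composition table: the per-source counts add up to 21,598, which equals the train and eval splits combined. The percentages are therefore shares of the full dataset. The sketch recomputes them from the counts above.

```python
# Per-source row counts from the composition table above.
SOURCE_COUNTS = {
    "MultiHopRAG": 7_165,
    "HotpotQA": 6_479,
    "HtmlRAG-train": 2_692,
    "MS MARCO": 4_844,
    "NA (mismatched)": 418,
}
TRAIN_ROWS, EVAL_ROWS = 21_098, 500  # from the split table


def shares(counts: dict[str, int]) -> dict[str, float]:
    """Percentage of the total contributed by each source."""
    total = sum(counts.values())
    return {name: 100 * n / total for name, n in counts.items()}


total = sum(SOURCE_COUNTS.values())
print(total, total == TRAIN_ROWS + EVAL_ROWS)
for name, pct in shares(SOURCE_COUNTS).items():
    print(f"{name}: {pct:.1f}%")
```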
**Task Type Composition:**

- `query_relevant`: ~78% (query-specific extraction)
- `main_content`: ~20% (main content vs. noise such as nav/ads/cookies)
- `query_relevant_na`: ~2% (no relevant content exists)

**Key improvement over v1**: Real web HTML from Bing search results (via HtmlRAG-train) + news articles + MS MARCO diverse web QA, not just Wikipedia.
**Format**: Conversational `messages` column (SFTTrainer-native):

```json
{
  "messages": [
    {"role": "system", "content": "You are IndexLM, a web content extraction model..."},
    {"role": "user", "content": "URL: ...\nQuery: ...\n\nBlocks:\n[1] <h2>Title</h2>\n[2] <p>Content</p>\n...\n\nOutput the index intervals of blocks relevant to the query."},
    {"role": "assistant", "content": "[[2, 4], [7, 7]]"}
  ]
}
```
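Before training, it can help to verify that every row matches this layout. The sketch below is a hypothetical validator, not part of the repo. It assumes targets are either `NA` or sorted, 1-based, non-overlapping `[start, end]` pairs, which is what `indices_to_intervals` in Section 5.1 emits.

```python
import json

EXPECTED_ROLES = ["system", "user", "assistant"]


def validate_example(example: dict) -> bool:
    """Check one row has the system/user/assistant layout and a well-formed target."""
    messages = example.get("messages", [])
    if [m.get("role") for m in messages] != EXPECTED_ROLES:
        return False
    target = messages[-1]["content"].strip()
    if target == "NA":
        return True
    try:
        intervals = json.loads(target)
    except json.JSONDecodeError:
        return False
    if not isinstance(intervals, list) or not intervals:
        return False
    prev_end = 0
    for pair in intervals:
        if not (isinstance(pair, list) and len(pair) == 2
                and all(isinstance(x, int) for x in pair)):
            return False
        start, end = pair
        # Intervals must be 1-based, inclusive, sorted and non-overlapping.
        if start < 1 or end < start or start <= prev_end:
            return False
        prev_end = end
    return True


row = {"messages": [
    {"role": "system", "content": "You are IndexLM, a web content extraction model."},
    {"role": "user", "content": "Query: What is Python?\n\nBlocks:\n[1] <h2>Python</h2>"},
    {"role": "assistant", "content": "[[2, 4], [7, 7]]"},
]}
print(validate_example(row))
```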
**Token length stats** (Qwen3-0.6B tokenizer, measured before the length filter):

- Min: 316, Max: 4,105, Mean: 1,944, Median: 2,019
- 43 examples filtered out (>4096 tokens)
**Data pipeline** (from `prepare_data_v2.py`):

1. **HtmlRAG-train** (5,880 raw examples): Real Bing-scraped HTML from 5 QA datasets (NQ, ASQA, TriviaQA, MuSiQue, HotpotQA). Segments HTML by block-level tags and matches relevant blocks to ground-truth answers using trigram/substring matching.
2. **MultiHopRAG** (8,521 examples): News articles from Mashable, CNBC, AP, etc. Converts article body + evidence annotations to indexed blocks. Injects realistic noise blocks.
3. **HotpotQA** (6,486 examples, no longer the majority source): Wikipedia context with supporting facts → index intervals. Noise injected.
4. **MS MARCO** (4,844 examples): Diverse web QA from Bing search. Passages from real web pages across numeric, entity, description, location, and person query types.
5. **NA examples** (500): Mismatched query-page pairs from different sources.
6. Filters to ≤4096 tokens, shuffles, and splits into train/eval.
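The answer-to-block matching in step 1 is only named, not specified. One plausible reading is sketched below as an assumption, not the actual `prepare_data_v2.py` logic. A block counts as relevant if the normalized answer appears in it as a whole-word substring. Failing that, at least half of the answer's word trigrams must also occur in the block. The function names and the `0.5` threshold are hypothetical.

```python
import re


def normalize(text: str) -> str:
    """Lowercase, strip HTML tags and punctuation, collapse whitespace."""
    text = re.sub(r"<[^>]+>", " ", text.lower())
    text = re.sub(r"[^\w\s]", " ", text)
    return " ".join(text.split())


def word_trigrams(text: str) -> set[tuple[str, ...]]:
    """All consecutive three-word windows of an already-normalized string."""
    words = text.split()
    return {tuple(words[i:i + 3]) for i in range(len(words) - 2)}


def block_matches_answer(block_html: str, answer: str, threshold: float = 0.5) -> bool:
    """Hypothetical matcher: whole-word substring hit, else word-trigram overlap >= threshold."""
    block, ans = normalize(block_html), normalize(answer)
    if not ans:
        return False
    if f" {ans} " in f" {block} ":
        return True
    ans_grams = word_trigrams(ans)
    if not ans_grams:  # answers shorter than three words need an exact hit
        return False
    overlap = len(ans_grams & word_trigrams(block)) / len(ans_grams)
    return overlap >= threshold


block = "<p>The Eiffel Tower is located in Paris, France.</p>"
print(block_matches_answer(block, "Paris"))                     # substring hit
print(block_matches_answer(block, "tower is located in Rome"))  # 2 of 3 trigrams overlap
```

The trigram fallback lets longer, paraphrased answers match without accepting every block that shares one common word.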
### ✅ Training Script Ready

**File**: `train_indexlm.py` (see Section 5 below)

Key settings:

- **Base model**: `Qwen/Qwen3-0.6B` (751M params, bf16, GQA, 32K context)
- **Method**: SFT via TRL `SFTTrainer` + `SFTConfig`
- **Output**: `OmAlve/IndexLM-0.6B` on Hub
- **Hyperparameters**: lr=2e-5, epochs=3, batch=4, grad_accum=4 (effective BS=16 on a single GPU), max_length=4096, cosine LR schedule, warmup=5%
- `push_to_hub=True`, `hub_model_id="OmAlve/IndexLM-0.6B"`
- Trackio monitoring included
- Flash Attention 2 for training speed
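These settings imply a schedule that can be estimated in advance. The sketch below assumes a single GPU and a train split of 21,098 rows. The Trainer's exact step count can differ by one, depending on how the installed version rounds the last partial accumulation window.

```python
import math

TRAIN_ROWS = 21_098   # train split size from the dataset table
PER_DEVICE_BATCH = 4
GRAD_ACCUM = 4
NUM_GPUS = 1          # assumption: single-GPU run
EPOCHS = 3
WARMUP_RATIO = 0.05
EVAL_EVERY = 500      # eval_steps / save_steps in train_indexlm.py

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_GPUS
steps_per_epoch = math.ceil(TRAIN_ROWS / effective_batch)  # optimizer steps
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)
num_evals = total_steps // EVAL_EVERY  # evaluations/checkpoints during training

print(effective_batch, steps_per_epoch, total_steps, warmup_steps, num_evals)
```

With roughly 4,000 optimizer steps, `save_total_limit=3` keeps only the last few of about seven checkpoints. `load_best_model_at_end` still restores the best one by `eval_loss`.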
### ✅ Evaluation Script Ready

**File**: `eval_indexlm.py` (see Section 5 below)

Evaluates:

- QE F1/Precision/Recall on eval split
- ME F1/Precision/Recall on eval split
- CPU inference speed benchmark

### ❌ Training Not Yet Run

The training job on HF Jobs hit a credits issue (402 Payment Required). You need to run `train_indexlm.py` on a GPU.

---
## 3. How to Train

### Option A: HF Jobs (if you have credits)

```bash
# Dependencies
pip install "transformers>=4.51.0" "trl>=1.2.0" torch datasets accelerate trackio
pip install flash-attn --no-build-isolation
```

Recommended hardware: **a10g-large** ($2/hr) or **t4-small** ($0.60/hr). The model has only 0.6B params. The T4 is a pre-Ampere (Turing) GPU. FlashAttention 2 does not support it, and it has no native bf16. On a T4, use `attn_implementation="sdpa"`, load the model in float32, and set `fp16=True` instead of `bf16=True`.

Estimated time: **2-4 hours** on A10G, **4-6 hours** on T4.

Set the job timeout to at least **6h**.
### Option B: Any GPU machine

```bash
pip install "transformers>=4.51.0" "trl>=1.2.0" torch datasets accelerate trackio
pip install flash-attn --no-build-isolation  # optional, speeds up training
python train_indexlm.py
```

**VRAM**: ~8-10 GB with gradient checkpointing + bf16 at batch_size=4. Fits on a T4 (16 GB), any A-series card, etc.
### Option C: Without Flash Attention

If `flash-attn` fails to install, or the GPU is pre-Ampere (e.g. T4), change this line in `train_indexlm.py`:

```python
# FROM:
attn_implementation="flash_attention_2",
# TO:
attn_implementation="sdpa",
```
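Instead of editing the line by hand, the script could pick the attention kernel and precision at startup. The sketch below is a hypothetical helper (`pick_training_backend` is not in `train_indexlm.py`). It assumes FlashAttention 2 and native bf16 both require compute capability 8.0 or newer (Ampere and later). It falls back to `sdpa` when `flash_attn` is not installed, when there is no CUDA device, or when torch itself is missing.

```python
import importlib.util


def pick_training_backend() -> dict:
    """Choose attention kernel and mixed-precision flags for the local GPU."""
    try:
        import torch
    except ImportError:
        return {"attn_implementation": "sdpa", "bf16": False, "fp16": False}
    if not torch.cuda.is_available():
        return {"attn_implementation": "sdpa", "bf16": False, "fp16": False}
    major, _minor = torch.cuda.get_device_capability()
    ampere_or_newer = major >= 8  # assumption: FA2 + native bf16 need sm80+
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    return {
        "attn_implementation": "flash_attention_2" if (ampere_or_newer and has_flash_attn) else "sdpa",
        "bf16": ampere_or_newer,
        "fp16": not ampere_or_newer,  # e.g. T4: fp16 autocast over float32 weights
    }


print(pick_training_backend())
```

Pass `attn_implementation` to `from_pretrained` and the `bf16`/`fp16` flags to `SFTConfig`. When `fp16` is chosen, load the model with `torch_dtype=torch.float32`. The Trainer's fp16 mixed precision expects float32 master weights.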
---

## 4. How to Deploy on CPU

After training, the model at `OmAlve/IndexLM-0.6B` can be loaded for CPU inference:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "OmAlve/IndexLM-0.6B",
    torch_dtype=torch.float32,
    attn_implementation="sdpa",
)
tokenizer = AutoTokenizer.from_pretrained("OmAlve/IndexLM-0.6B")
model.eval()

# Example: extract relevant content from a web page.
# Use the same full system prompt as in training (SYSTEM_PROMPT_QE in Section 5.1).
messages = [
    {"role": "system", "content": "You are IndexLM, a web content extraction model..."},
    {"role": "user", "content": "URL: ...\nQuery: What is Python?\n\nBlocks:\n[1] <nav>Home</nav>\n[2] <h1>Python Programming</h1>\n[3] <p>Python is a programming language...</p>\n[4] <footer>Copyright 2024</footer>\n\nOutput the index intervals of blocks relevant to the query."}
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
# Decode only the newly generated tokens
response = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(response)  # expected form once trained: [[2, 3]]
```
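**For even faster CPU inference**: quantize to INT8/INT4 or export to ONNX. Note that `bitsandbytes` is primarily built for CUDA GPUs, so check its CPU support for your version before relying on it. PyTorch dynamic INT8 quantization and ONNX export are CPU-oriented alternatives.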
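As a rough guide to what quantization saves, weight storage scales with bytes per parameter. This back-of-envelope sketch uses the ~751M parameter count quoted in the training settings. It ignores activations, the KV cache and per-group quantization scales, so real footprints are somewhat larger.

```python
PARAMS = 751_000_000  # approx. parameter count quoted for Qwen3-0.6B above

BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "int8": 1, "int4": 0.5}


def weight_memory_gb(params: int, dtype: str) -> float:
    """Weight storage only; ignores activations, KV cache and quantization scales."""
    return params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"{dtype}: {weight_memory_gb(PARAMS, dtype):.2f} GB")
```

The float32 deployment above needs about 3 GB just for weights. INT8 cuts that by 4×, which also reduces memory bandwidth per generated token on CPU.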
---

## 5. All Scripts

### 5.1 Data Preparation (`prepare_data.py`, v1)

This script builds only HotpotQA-derived data (the v1, Wikipedia-only pipeline). The multi-domain v2 dataset described in Section 2 was built with `prepare_data_v2.py`.

```python
"""
Prepare IndexLM training data from HotpotQA (v1 pipeline; MS MARCO is not used here).

Pipeline:
1. Load HotpotQA (has context = list of (title, sentences) + supporting_facts)
2. Convert context into indexed HTML-like blocks: [i] <tag>content</tag>
3. The target is index intervals of blocks containing supporting facts
4. Also create main-content extraction examples (all content blocks are "main content",
   but we inject noise blocks like nav/ads to train the model to filter them)
5. Format as conversational messages for SFT
"""
import json
import random
import re
from datasets import load_dataset, Dataset
from collections import defaultdict

random.seed(42)

# Noise blocks to inject (simulating real web page clutter)
NOISE_BLOCKS = [
    '<nav>Home | About | Contact | Privacy Policy</nav>',
    '<div class="ad">Advertisement - Continue Reading Below</div>',
    '<div class="sidebar">Related Articles: Top 10 Facts You Didn\'t Know</div>',
    '<footer>© 2024 All Rights Reserved | Terms of Service</footer>',
    '<div class="cookie-banner">This site uses cookies. Accept | Decline</div>',
    '<div class="social">Share on: Twitter | Facebook | LinkedIn</div>',
    '<nav class="breadcrumb">Home > Category > Subcategory > Article</nav>',
    '<div class="newsletter">Subscribe to our newsletter for updates</div>',
    '<div class="popup">Sign up for free access to premium content</div>',
    '<aside>Trending: Latest news and popular stories</aside>',
    '<div class="comments">Comments (0) - Be the first to comment</div>',
    '<div class="author">Written by Staff Reporter | Updated: Jan 2024</div>',
    '<div class="pagination">Previous | 1 | 2 | 3 | Next</div>',
    '<div class="search">Search this site...</div>',
    '<div class="menu">Categories: Science, Tech, Health, Sports</div>',
]

SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query.
Each block is formatted as: [i] <tag>content</tag>
Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive).
If no relevant content exists, output 'NA'.
Example output: [[2,4],[7,7],[10,12]]"""

SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements).
Each block is formatted as: [i] <tag>content</tag>
Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive).
If no main content exists, output 'NA'.
Example output: [[1,3],[5,8],[11,15]]"""


def indices_to_intervals(indices):
    """Convert a sorted list of indices to intervals [[start,end], ...]"""
    if not indices:
        return "NA"
    indices = sorted(set(indices))
    intervals = []
    start = indices[0]
    end = indices[0]
    for i in indices[1:]:
        if i == end + 1:
            end = i
        else:
            intervals.append([start, end])
            start = i
            end = i
    intervals.append([start, end])
    return json.dumps(intervals)
def create_indexed_blocks_from_hotpotqa(context, supporting_facts, inject_noise=True):
    """
    Convert HotpotQA context into indexed HTML blocks.
    context: {'title': [...], 'sentences': [[...], ...]}
    supporting_facts: {'title': [...], 'sent_id': [...]}
    Returns: (block_text, relevant_indices, all_content_indices)
    """
    titles = context['title']
    sentences_list = context['sentences']
    # Build supporting facts lookup
    sf_lookup = defaultdict(set)
    for title, sent_id in zip(supporting_facts['title'], supporting_facts['sent_id']):
        sf_lookup[title].add(sent_id)
    blocks = []
    relevant_indices = []
    content_indices = []  # All real content (non-noise)
    idx = 1
    for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)):
        # Title block
        blocks.append(f"[{idx}] <h2>{title}</h2>")
        content_indices.append(idx)
        if title in sf_lookup:
            # Title of a supporting document is relevant
            relevant_indices.append(idx)
        idx += 1
        # Sentence blocks
        for sent_idx, sentence in enumerate(sentences):
            sentence = sentence.strip()
            if not sentence:
                continue
            # Use <p> for regular text
            blocks.append(f"[{idx}] <p>{sentence}</p>")
            content_indices.append(idx)
            if title in sf_lookup and sent_idx in sf_lookup[title]:
                relevant_indices.append(idx)
            idx += 1
        # Inject noise between documents sometimes
        if inject_noise and random.random() < 0.4 and doc_idx < len(titles) - 1:
            noise = random.choice(NOISE_BLOCKS)
            blocks.append(f"[{idx}] {noise}")
            idx += 1
    # Sometimes add noise at start and end
    if inject_noise:
        prefix_noise = []
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                noise = random.choice(NOISE_BLOCKS)
                prefix_noise.append(noise)
        suffix_noise = []
        if random.random() < 0.5:
            for _ in range(random.randint(1, 3)):
                noise = random.choice(NOISE_BLOCKS)
                suffix_noise.append(noise)
        if prefix_noise or suffix_noise:
            # Reindex everything
            new_blocks = []
            new_relevant = []
            new_content = []
            new_idx = 1
            # Prefix noise
            for noise in prefix_noise:
                new_blocks.append(f"[{new_idx}] {noise}")
                new_idx += 1
            # Remap original blocks
            offset = len(prefix_noise)
            for b in blocks:
                old_idx = int(b.split(']')[0].replace('[', ''))
                new_b = f"[{old_idx + offset}] " + '] '.join(b.split('] ')[1:])
                new_blocks.append(new_b)
            new_relevant = [r + offset for r in relevant_indices]
            new_content = [c + offset for c in content_indices]
            # Suffix noise
            next_idx = len(new_blocks) + 1
            for noise in suffix_noise:
                new_blocks.append(f"[{next_idx}] {noise}")
                next_idx += 1
            blocks = new_blocks
            relevant_indices = new_relevant
            content_indices = new_content
    block_text = "\n".join(blocks)
    return block_text, relevant_indices, content_indices
def build_query_relevant_example(question, block_text, relevant_indices, url="https://en.wikipedia.org"):
    """Build a query-relevant extraction (QE) example."""
    intervals = indices_to_intervals(relevant_indices)
    user_content = f"URL: {url}\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_QE},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals}
    ]
    return messages


def build_main_content_example(block_text, content_indices, title="Wikipedia Article", url="https://en.wikipedia.org"):
    """Build a main content extraction (ME) example."""
    intervals = indices_to_intervals(content_indices)
    user_content = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_ME},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": intervals}
    ]
    return messages


def process_hotpotqa():
    """Process HotpotQA into IndexLM training data."""
    print("Loading HotpotQA...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")
    # Sample a manageable amount
    num_samples = min(15000, len(ds))
    ds = ds.shuffle(seed=42).select(range(num_samples))
    all_examples = []
    skipped = 0
    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f"Processing {i}/{num_samples}...")
        try:
            block_text, relevant_indices, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'], row['supporting_facts'], inject_noise=True
            )
            # Skip if too few relevant indices
            if len(relevant_indices) < 1:
                skipped += 1
                continue
            # Query-relevant extraction example
            qe_messages = build_query_relevant_example(
                row['question'], block_text, relevant_indices
            )
            all_examples.append({
                "messages": qe_messages,
                "task_type": "query_relevant",
                "source": "hotpotqa"
            })
            # Main content extraction example (50% of the time)
            if random.random() < 0.5:
                me_messages = build_main_content_example(
                    block_text, content_indices,
                    title=row['context']['title'][0] if row['context']['title'] else "Article"
                )
                all_examples.append({
                    "messages": me_messages,
                    "task_type": "main_content",
                    "source": "hotpotqa"
                })
        except Exception as e:
            skipped += 1
            if skipped < 5:
                print(f"Error on row {i}: {e}")
            continue
    print(f"Created {len(all_examples)} examples from HotpotQA ({skipped} skipped)")
    return all_examples
def create_synthetic_web_pages():
    """Create synthetic web page examples for main content extraction training."""
    print("Creating synthetic web page examples...")
    # Load a text dataset to get content
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=123).select(range(3000))
    examples = []
    for i, row in enumerate(ds):
        if i % 500 == 0:
            print(f"Synthetic page {i}/3000...")
        try:
            # Build a more realistic web page structure
            titles = row['context']['title']
            sentences_list = row['context']['sentences']
            if not titles or not sentences_list:
                continue
            blocks = []
            content_indices = []
            idx = 1
            # Header noise (nav, etc.)
            num_header_noise = random.randint(1, 4)
            for _ in range(num_header_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1
            # Page title
            main_title = titles[0]
            blocks.append(f"[{idx}] <h1>{main_title}</h1>")
            content_indices.append(idx)
            idx += 1
            # Main content (just first 1-3 documents)
            num_docs = min(random.randint(1, 3), len(titles))
            for doc_idx in range(num_docs):
                title = titles[doc_idx]
                sents = sentences_list[doc_idx]
                if doc_idx > 0:
                    blocks.append(f"[{idx}] <h2>{title}</h2>")
                    content_indices.append(idx)
                    idx += 1
                for sent in sents:
                    sent = sent.strip()
                    if not sent:
                        continue
                    blocks.append(f"[{idx}] <p>{sent}</p>")
                    content_indices.append(idx)
                    idx += 1
                # Occasional inline noise (after a document's sentences)
                if random.random() < 0.3:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                    idx += 1
            # Footer noise
            num_footer_noise = random.randint(1, 4)
            for _ in range(num_footer_noise):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS)}")
                idx += 1
            block_text = "\n".join(blocks)
            me_messages = build_main_content_example(
                block_text, content_indices,
                title=main_title,
                url=f"https://en.wikipedia.org/wiki/{main_title.replace(' ', '_')}"
            )
            examples.append({
                "messages": me_messages,
                "task_type": "main_content",
                "source": "synthetic"
            })
        except Exception:
            continue
    print(f"Created {len(examples)} synthetic web page examples")
    return examples
def create_na_examples():
    """Create examples where no relevant content exists (model should output 'NA')."""
    print("Creating NA examples...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="validation")
    ds = ds.shuffle(seed=456).select(range(1000))
    examples = []
    for i, row in enumerate(ds):
        try:
            # Use context from one question but query from another (mismatched)
            other_idx = (i + 500) % len(ds)
            other_question = ds[other_idx]['question']
            # Build blocks from the current context; no supporting facts are passed,
            # so no block is labeled relevant
            block_text, _, content_indices = create_indexed_blocks_from_hotpotqa(
                row['context'], {'title': [], 'sent_id': []}, inject_noise=True
            )
            user_content = f"URL: https://en.wikipedia.org\nQuery: {other_question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT_QE},
                {"role": "user", "content": user_content},
                {"role": "assistant", "content": "NA"}
            ]
            examples.append({
                "messages": messages,
                "task_type": "query_relevant_na",
                "source": "hotpotqa_mismatched"
            })
        except Exception:
            continue
    # Keep only a fraction (the paper mentions partial filtering of NA)
    random.shuffle(examples)
    examples = examples[:300]
    print(f"Created {len(examples)} NA examples")
    return examples
def main():
    # Build all training examples
    qe_examples = process_hotpotqa()
    me_examples = create_synthetic_web_pages()
    na_examples = create_na_examples()
    all_examples = qe_examples + me_examples + na_examples
    random.shuffle(all_examples)
    print(f"\nTotal examples: {len(all_examples)}")
    # Count by type
    type_counts = defaultdict(int)
    for ex in all_examples:
        type_counts[ex['task_type']] += 1
    for t, c in type_counts.items():
        print(f"  {t}: {c}")
    # Check lengths
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    lengths = []
    for ex in all_examples[:500]:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        lengths.append(len(tokens))
    print(f"\nToken length stats (sample of 500):")
    print(f"  Min: {min(lengths)}")
    print(f"  Max: {max(lengths)}")
    print(f"  Mean: {sum(lengths)/len(lengths):.0f}")
    print(f"  Median: {sorted(lengths)[len(lengths)//2]}")
    # Filter out examples that are too long (>4096 tokens for efficiency)
    MAX_LEN = 4096
    filtered = []
    too_long = 0
    for ex in all_examples:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        if len(tokens) <= MAX_LEN:
            filtered.append(ex)
        else:
            too_long += 1
    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset: {len(filtered)} examples")
    # Split into train/eval
    random.shuffle(filtered)
    eval_size = min(500, len(filtered) // 10)
    train_data = filtered[:-eval_size]
    eval_data = filtered[-eval_size:]
    print(f"Train: {len(train_data)}, Eval: {len(eval_data)}")
    # Create HF dataset with just messages column (for SFTTrainer)
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])
    # Save locally
    train_ds.save_to_disk("/app/indexlm_train")
    eval_ds.save_to_disk("/app/indexlm_eval")
    # Also push to HF Hub
    from datasets import DatasetDict
    import os
    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data", token=os.environ.get("HF_TOKEN"))
    print("\nDone! Dataset pushed to OmAlve/indexlm-training-data")


if __name__ == "__main__":
    main()
```
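The reindexing step in `create_indexed_blocks_from_hotpotqa` parses the `[i]` prefixes back out of already-rendered strings. A simpler design is to keep blocks as unnumbered `(html, kind)` pairs and assign indices only when rendering. Prepending or appending noise then needs no remapping. This is an alternative sketch, not what the script above does. The `render_blocks` helper and the `"noise"`/`"content"`/`"relevant"` kinds are hypothetical.

```python
def render_blocks(items: list[tuple[str, str]]) -> tuple[str, list[int], list[int]]:
    """Number blocks 1..n and collect relevant and content (non-noise) indices.

    Each item is (html, kind) with kind in {"noise", "content", "relevant"};
    "relevant" blocks also count as content.
    """
    lines, relevant, content = [], [], []
    for idx, (html, kind) in enumerate(items, start=1):
        lines.append(f"[{idx}] {html}")
        if kind in ("content", "relevant"):
            content.append(idx)
        if kind == "relevant":
            relevant.append(idx)
    return "\n".join(lines), relevant, content


body = [
    ("<h2>Python</h2>", "relevant"),
    ("<p>Python is a language.</p>", "relevant"),
    ("<p>It was released in 1991.</p>", "content"),
]
prefix = [("<nav>Home | About</nav>", "noise")]
suffix = [("<footer>Terms</footer>", "noise")]
text, relevant, content = render_blocks(prefix + body + suffix)
print(text)
print(relevant, content)  # [2, 3] [2, 3, 4]
```

Because numbering happens in one place, the labels cannot drift out of sync with the rendered text, whatever noise is inserted.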
### 5.2 Training Script (`train_indexlm.py`)

```python
"""
IndexLM Training Script - Fine-tune Qwen3-0.6B for Index-based Web Content Extraction
Based on: "An Index-based Approach for Efficient and Effective Web Content Extraction" (arxiv:2512.06641)
Base model: Qwen/Qwen3-0.6B (0.6B params, ideal for CPU deployment)
Training method: SFT with TRL SFTTrainer
Dataset: OmAlve/indexlm-training-data (21,098 train / 500 eval examples)
"""
import os
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig

# ============ Configuration ============
MODEL_ID = "Qwen/Qwen3-0.6B"
DATASET_ID = "OmAlve/indexlm-training-data"
OUTPUT_DIR = "./indexlm-0.6b"
HUB_MODEL_ID = "OmAlve/IndexLM-0.6B"

# Training hyperparameters (from paper: standard SFT)
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
BATCH_SIZE = 4
GRAD_ACCUM = 4  # Effective batch size = 16 (single GPU)
MAX_SEQ_LENGTH = 4096
WARMUP_RATIO = 0.05

# ============ Trackio ============
# Metrics are logged through the Trainer's built-in Trackio integration
# (report_to="trackio" in SFTConfig below), which creates the run itself.
# This needs a transformers release that ships the Trackio callback;
# set report_to="none" to disable tracking.
RUN_NAME = "indexlm-0.6b-training"

# ============ Load Dataset ============
print("Loading dataset...")
dataset = load_dataset(DATASET_ID)
train_dataset = dataset["train"]
eval_dataset = dataset["eval"]
print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

# ============ Load Model & Tokenizer ============
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Ensure padding token is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # Change to "sdpa" if flash-attn unavailable
)
print(f"Model loaded: {MODEL_ID}")
print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

# ============ Training Config ============
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.01,
    bf16=True,
    gradient_checkpointing=True,
    max_length=MAX_SEQ_LENGTH,
    # Logging
    logging_steps=10,
    logging_first_step=True,
    logging_strategy="steps",
    disable_tqdm=True,
    # Evaluation
    eval_strategy="steps",
    eval_steps=500,
    # Saving (must align with eval_steps for load_best_model_at_end)
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Hub push
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    # Performance
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    # Report (Trackio integration; previously "none", which logged nothing)
    report_to="trackio",
    run_name=RUN_NAME,
    # Seed
    seed=42,
)

# ============ Initialize Trainer ============
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

# ============ Train ============
print("Starting training...")
train_result = trainer.train()

# ============ Save Final Model ============
print("Saving final model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Push to Hub
print("Pushing to Hub...")
trainer.push_to_hub(commit_message="Final IndexLM-0.6B model")

# ============ Log Final Metrics ============
metrics = train_result.metrics
print(f"\nTraining complete!")
print(f"  Train loss: {metrics.get('train_loss', 'N/A')}")
print(f"  Train runtime: {metrics.get('train_runtime', 'N/A'):.0f}s")
print(f"  Train samples/sec: {metrics.get('train_samples_per_second', 'N/A'):.1f}")

# Final eval
eval_metrics = trainer.evaluate()
print(f"  Eval loss: {eval_metrics.get('eval_loss', 'N/A')}")
print(f"\nModel pushed to: https://huggingface.co/{HUB_MODEL_ID}")
```

### 5.3 Evaluation Script (`eval_indexlm.py`)
```python
"""
IndexLM Evaluation Script
Tests the trained model on:
1. Query-relevant extraction (QE) - F1/Precision/Recall
2. Main content extraction (ME) - F1/Precision/Recall
3. Inference speed on the selected device (GPU if available, else CPU)
"""
import json
import os
import time

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


def parse_intervals(text):
    """Parse an interval string like '[[1,3],[5,7]]' into a set of indices."""
    text = text.strip()
    if text.upper() == 'NA' or not text:
        return set()
    try:
        intervals = json.loads(text)
        indices = set()
        for start, end in intervals:
            indices.update(range(start, end + 1))  # Intervals are inclusive
        return indices
    except (json.JSONDecodeError, TypeError, ValueError):
        # Malformed output maps to the empty set, so it is scored like an 'NA' answer
        return set()


def compute_f1(pred_indices, gold_indices):
    """Compute F1, precision, recall between two sets of indices."""
    if not pred_indices and not gold_indices:
        return 1.0, 1.0, 1.0  # Both empty: correct NA prediction
    if not pred_indices or not gold_indices:
        return 0.0, 0.0, 0.0
    tp = len(pred_indices & gold_indices)
    precision = tp / len(pred_indices)
    recall = tp / len(gold_indices)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return f1, precision, recall


def generate_response(model, tokenizer, messages, device, max_new_tokens=128):
    """Generate the model's response, excluding the gold assistant turn."""
    text = tokenizer.apply_chat_template(
        messages[:-1],  # Exclude assistant message (ground truth)
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # Qwen3 chat-template flag: no <think> reasoning block
    )
    # No truncation here: right-side truncation would cut off the assistant
    # generation prompt at the end of the text (Qwen3 has a 32K context)
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy for deterministic eval
            temperature=1.0,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the new tokens
    new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()


def evaluate_model(model_id, device="cpu", num_samples=100):
    """Run full evaluation."""
    print(f"\n{'='*60}")
    print(f"Evaluating: {model_id}")
    print(f"Device: {device}")
    print(f"{'='*60}")
    # Load model
    print("Loading model...")
    dtype = torch.float32 if device == "cpu" else torch.bfloat16
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        attn_implementation="sdpa",
    ).to(device)
    model.eval()
    # Load eval dataset
    print("Loading eval dataset...")
    dataset = load_dataset("OmAlve/indexlm-training-data", split="eval")
    # Sample
    if len(dataset) > num_samples:
        dataset = dataset.shuffle(seed=42).select(range(num_samples))
    # Categorize examples
    qe_examples = []
    me_examples = []
    for row in dataset:
        msgs = row['messages']
        system_msg = msgs[0]['content'] if msgs[0]['role'] == 'system' else ''
        if 'query' in system_msg.lower() and 'relevant' in system_msg.lower():
            qe_examples.append(msgs)
        else:
            me_examples.append(msgs)
    print(f"QE examples: {len(qe_examples)}, ME examples: {len(me_examples)}")
    # Evaluate QE
    print("\n--- Query-Relevant Extraction (QE) ---")
    qe_metrics = evaluate_task(model, tokenizer, qe_examples[:50], device)
    # Evaluate ME
    print("\n--- Main Content Extraction (ME) ---")
    me_metrics = evaluate_task(model, tokenizer, me_examples[:50], device)
    # Speed test
    print("\n--- Inference Speed Test ---")
    speed_test(model, tokenizer, qe_examples[:20], device)
    return qe_metrics, me_metrics


def evaluate_task(model, tokenizer, examples, device):
    """Evaluate on a set of examples."""
    if not examples:
        print("No examples for this task.")
        return {}
    f1_scores = []
    precision_scores = []
    recall_scores = []
    exact_matches = 0
    for i, msgs in enumerate(examples):
        gold = msgs[-1]['content']
        gold_indices = parse_intervals(gold)
        pred = generate_response(model, tokenizer, msgs, device)
        pred_indices = parse_intervals(pred)
        f1, prec, rec = compute_f1(pred_indices, gold_indices)
        f1_scores.append(f1)
        precision_scores.append(prec)
        recall_scores.append(rec)
        if pred_indices == gold_indices:
            exact_matches += 1
        if i < 3:
            print(f" Example {i+1}:")
            print(f" Gold: {gold}")
            print(f" Pred: {pred}")
            print(f" F1: {f1:.3f}, P: {prec:.3f}, R: {rec:.3f}")
    avg_f1 = sum(f1_scores) / len(f1_scores) * 100
    avg_prec = sum(precision_scores) / len(precision_scores) * 100
    avg_rec = sum(recall_scores) / len(recall_scores) * 100
    em_rate = exact_matches / len(examples) * 100
    print(f"\n Results ({len(examples)} examples):")
    print(f" F1: {avg_f1:.2f}")
    print(f" Precision: {avg_prec:.2f}")
    print(f" Recall: {avg_rec:.2f}")
    print(f" Exact Match: {em_rate:.2f}%")
    return {"f1": avg_f1, "precision": avg_prec, "recall": avg_rec, "exact_match": em_rate}


def speed_test(model, tokenizer, examples, device):
    """Measure per-example generation latency (the first call includes warm-up)."""
    if not examples:
        return
    times = []
    for msgs in examples:
        start = time.perf_counter()  # Monotonic, high-resolution timer
        _ = generate_response(model, tokenizer, msgs, device)
        elapsed = time.perf_counter() - start
        times.append(elapsed)
    avg_time = sum(times) / len(times)
    print(f" Average inference time: {avg_time:.3f}s ({device})")
    print(f" Min: {min(times):.3f}s, Max: {max(times):.3f}s")
    print(f" Throughput: {1/avg_time:.1f} pages/sec")


if __name__ == "__main__":
    model_id = os.environ.get("MODEL_ID", "OmAlve/IndexLM-0.6B")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    evaluate_model(model_id, device=device, num_samples=100)
```
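
To make the metric concrete, here is a small worked example. It re-declares `parse_intervals` and `compute_f1` so it runs without loading a model (the logic mirrors the script above); the gold and predicted strings are made-up inputs, not results from the dataset.

```python
import json


def parse_intervals(text):
    """Parse '[[1,3],[5,7]]' into a set of block indices (inclusive intervals)."""
    text = text.strip()
    if not text or text.upper() == "NA":
        return set()
    try:
        indices = set()
        for start, end in json.loads(text):
            indices.update(range(start, end + 1))
        return indices
    except (json.JSONDecodeError, TypeError, ValueError):
        return set()  # malformed output is treated as an empty prediction


def compute_f1(pred_indices, gold_indices):
    """Set-based F1 / precision / recall over block indices."""
    if not pred_indices and not gold_indices:
        return 1.0, 1.0, 1.0
    if not pred_indices or not gold_indices:
        return 0.0, 0.0, 0.0
    tp = len(pred_indices & gold_indices)
    precision = tp / len(pred_indices)
    recall = tp / len(gold_indices)
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return f1, precision, recall


# Hypothetical gold label and model output
gold = parse_intervals("[[1,3],[5,7]]")  # {1, 2, 3, 5, 6, 7}
pred = parse_intervals("[[1,4]]")        # {1, 2, 3, 4}
f1, p, r = compute_f1(pred, gold)
print(f"P={p:.2f} R={r:.2f} F1={f1:.2f}")  # P=0.75 R=0.50 F1=0.60

# A correct "NA" answer scores 1.0; truncated JSON is parsed as empty
print(compute_f1(parse_intervals("NA"), set()))      # (1.0, 1.0, 1.0)
print(compute_f1(parse_intervals("[[2,"), {2, 3}))  # (0.0, 0.0, 0.0)
```

Because scoring is over sets of indices, predictions that cover the same blocks with a different interval split (for example `[[1,2],[3,3]]` versus `[[1,3]]`) receive identical F1 and also count as an exact match.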

---

## 6. Key Design Decisions & Rationale
### Why Qwen3-0.6B?
- The paper builds IndexLM on Qwen3-0.6B, 1.7B, and 4B. The 0.6B model performs **nearly as well** as the 4B on RAG QA (54.70 vs. 55.41 average F1).
- The 0.6B model is **~1.4 GB in bf16 and ~700 MB in INT4**, small enough to run fast on CPU.
- TRL's SFT documentation uses Qwen3-0.6B as its example model, so it is well supported by the training stack.
- Qwen3 uses grouped-query attention (GQA), which shrinks the KV cache and typically makes decoding faster than standard multi-head attention (MHA).
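
As a rough sanity check on these size figures, weight memory is approximately parameter count × bytes per parameter. The sketch below assumes the nominal 0.6B parameter count and ignores activations, the KV cache, and quantization metadata such as per-group scales; real checkpoints can be larger (for example, if embeddings are kept at higher precision), so treat the output as a lower bound rather than an exact file size.

```python
NOMINAL_PARAMS = 0.6e9  # assumption: nominal Qwen3-0.6B parameter count

BYTES_PER_PARAM = {
    "fp32": 4.0,
    "bf16": 2.0,
    "int8": 1.0,
    "int4": 0.5,
}


def weight_gb(num_params: float, dtype: str) -> float:
    """Approximate weight storage in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>5}: ~{weight_gb(NOMINAL_PARAMS, dtype):.2f} GB")
```

One practical consequence: the evaluation script loads the model in fp32 when running on CPU, which roughly doubles the bf16 weight footprint.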

### Why not ReaderLM-v2?
- ReaderLM-v2 performs generative HTML→Markdown conversion, a different task from index prediction.
- It is **33-70× slower** than IndexLM on the paper's benchmarks.
- Fine-tuning it for index prediction would work against its pretrained generative behavior.

### Dataset construction vs. the paper
The paper uses:
1. Google Search API crawls for real HTML from the web
2. DeepSeek V3 annotation with 5-run majority voting
3. Common Crawl WARC files

We approximate this with:
1. HotpotQA's structured context (title + sentences) converted to indexed HTML blocks
2. Programmatic labels from HotpotQA's `supporting_facts` ground truth (human-annotated, which we expect to be cleaner than LLM annotation)
3. Synthetic noise injection (nav bars, ads, cookie banners, etc.) to simulate real web clutter
4. Mismatched query-page pairs for NA examples

**Trade-off**: Our HTML blocks are simpler than real web HTML (no nested tables, CSS-in-JS, etc.). For production use, augmenting with real crawled HTML should improve robustness. Reproducing the paper's full pipeline would incur API costs (Google Search, DeepSeek V3).
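
A minimal sketch of this approximation follows, assuming the field layout of the Hugging Face `hotpotqa/hotpot_qa` dataset (`context` holds parallel `title`/`sentences` lists; `supporting_facts` holds parallel `title`/`sent_id` lists). The block markup (`[i] <p>...</p>`), the noise strings, the toy row, and the helper names `to_intervals`, `build_example`, and `build_na_example` are hypothetical illustrations, not the format produced by `prepare_data.py`. The label string follows the `[[start,end],...]` / `NA` format that `parse_intervals` in the evaluation script reads.

```python
import json
import random

# Hypothetical boilerplate snippets injected as synthetic web clutter
NOISE_BLOCKS = [
    "<nav>Home | News | Contact</nav>",
    "<div class='ad'>Sponsored: try our premium plan!</div>",
    "<div class='cookie'>We use cookies to improve your experience.</div>",
]


def to_intervals(indices):
    """Compress block indices into sorted, inclusive [start, end] runs."""
    intervals = []
    for i in sorted(indices):
        if intervals and i == intervals[-1][1] + 1:
            intervals[-1][1] = i  # extend the current run
        else:
            intervals.append([i, i])  # start a new run
    return intervals


def build_example(row, num_noise=2, rng=None):
    """Turn one HotpotQA-style row into (query, indexed page, label string)."""
    rng = rng or random.Random(0)  # pass a shared Random in real use
    support = set(zip(row["supporting_facts"]["title"],
                      row["supporting_facts"]["sent_id"]))
    blocks = []  # (html, is_supporting) pairs, in page order
    for title, sentences in zip(row["context"]["title"], row["context"]["sentences"]):
        blocks.append((f"<h2>{title}</h2>", False))
        for sent_id, sentence in enumerate(sentences):
            blocks.append((f"<p>{sentence.strip()}</p>", (title, sent_id) in support))
    # Noise is inserted before indexing, so gold indices shift accordingly
    for _ in range(num_noise):
        blocks.insert(rng.randrange(len(blocks) + 1), (rng.choice(NOISE_BLOCKS), False))
    page = "\n".join(f"[{i}] {html}" for i, (html, _) in enumerate(blocks))
    gold = [i for i, (_, is_sup) in enumerate(blocks) if is_sup]
    label = json.dumps(to_intervals(gold), separators=(",", ":")) if gold else "NA"
    return row["question"], page, label


def build_na_example(row, other_row, rng=None):
    """Pair this row's question with an unrelated row's page; the label is NA."""
    _, page, _ = build_example(other_row, rng=rng)
    return row["question"], page, "NA"


# Made-up toy row in the assumed HotpotQA layout
toy = {
    "question": "Which city is the Example Museum in?",
    "context": {
        "title": ["Example Museum", "Other Topic"],
        "sentences": [["The Example Museum is an art museum.", " It is in Paris."],
                      ["An unrelated sentence."]],
    },
    "supporting_facts": {"title": ["Example Museum", "Example Museum"], "sent_id": [0, 1]},
}
query, page, label = build_example(toy, num_noise=0)
print(page)
print(label)  # prints [[1,2]]
```

Inserting noise before assigning indices keeps labels consistent with what the model sees; the NA helper assumes that a page drawn from a different HotpotQA row contains nothing relevant to the question.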

### Hyperparameters
The paper (Section 3.3.2) states only that "the training process is a typical SFT process" on Qwen3, so we use standard SFT settings:
- lr=2e-5 (the TRL SFT default, a common choice for fine-tuning Qwen3)
- 3 epochs
- Effective batch size 16 (per-device batch 4 × 4 gradient-accumulation steps, single GPU)
- Cosine LR schedule with 5% warmup
- max_length=4096 (covers 99.8% of our examples, well within Qwen3's 32K context)
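
To make the schedule concrete, the sketch below computes the number of optimizer steps and the learning rate at a given step. It follows the linear-warmup then half-cosine-decay shape of `get_cosine_schedule_with_warmup` in Transformers (default half cycle, decaying to 0) and assumes warmup steps are rounded up as `ceil(0.05 × total_steps)`. `NUM_EXAMPLES` is a hypothetical placeholder, not the real size of `OmAlve/indexlm-training-data`, and the Trainer's exact steps-per-epoch rounding can differ slightly between versions.

```python
import math

# Values from the list above
PEAK_LR = 2e-5
EPOCHS = 3
PER_DEVICE_BATCH = 4
GRAD_ACCUM = 4
NUM_GPUS = 1           # assumption: single-GPU run
WARMUP_RATIO = 0.05
NUM_EXAMPLES = 10_000  # hypothetical placeholder, not the real dataset size


def total_optimizer_steps(num_examples: int) -> int:
    """Optimizer updates over all epochs (one update per effective batch)."""
    effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_GPUS  # 16 here
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return steps_per_epoch * EPOCHS


def lr_at(step: int, total_steps: int) -> float:
    """Linear warmup to PEAK_LR, then half-cosine decay to 0."""
    warmup_steps = math.ceil(WARMUP_RATIO * total_steps)
    if step < warmup_steps:
        return PEAK_LR * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return PEAK_LR * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


total = total_optimizer_steps(NUM_EXAMPLES)
for step in (0, math.ceil(WARMUP_RATIO * total), total // 2, total):
    print(f"step {step:5d}/{total}: lr = {lr_at(step, total):.2e}")
```

Because warmup is a ratio, adding GPUs (which raises the effective batch and lowers the step count) also shortens warmup in absolute steps.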

---

## 7. What's Left To Do
| Task | Status | Notes |
|------|--------|-------|
| Run `train_indexlm.py` | ❌ | Needs a GPU; a10g-large recommended (~$8 total) |
| Run `eval_indexlm.py` | ❌ | After training completes |
| ONNX export for CPU | ❌ | Optional: `optimum-cli export onnx --model OmAlve/IndexLM-0.6B indexlm-onnx/` |
| INT4 quantization | ❌ | Optional: `llama.cpp` (GGUF) for faster CPU inference; `bitsandbytes` 4-bit mainly targets CUDA GPUs |
| Real HTML augmentation | ❌ | Optional: crawl real web pages to augment training data |

---

## 8. Resources
| Resource | URL |
|----------|-----|
| Paper | https://arxiv.org/abs/2512.06641 |
| Training dataset | https://huggingface.co/datasets/OmAlve/indexlm-training-data |
| Base model | https://huggingface.co/Qwen/Qwen3-0.6B |
| Output model (after training) | https://huggingface.co/OmAlve/IndexLM-0.6B |
| TRL SFT docs | https://huggingface.co/docs/trl/sft_trainer |
| HotpotQA source | https://huggingface.co/datasets/hotpotqa/hotpot_qa |

---

## 9. Dependencies
```
transformers>=4.51.0
trl>=1.2.0
torch
datasets
accelerate
trackio
flash-attn  # optional, GPU training only
beautifulsoup4  # only for prepare_data.py
lxml  # only for prepare_data.py
```