Tags: Text Generation, Transformers, TensorBoard, Safetensors, qwen3, Generated from Trainer, trl, sft, trackio, conversational, text-generation-inference
Instructions for using OmAlve/reading-steiner with libraries, inference providers, notebooks, and local apps.
- Libraries
- Transformers
How to use OmAlve/reading-steiner with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="OmAlve/reading-steiner")
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)
```

```python
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("OmAlve/reading-steiner")
model = AutoModelForCausalLM.from_pretrained("OmAlve/reading-steiner")
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))
```

- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use OmAlve/reading-steiner with vLLM:
Install from pip and serve model
```bash
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "OmAlve/reading-steiner"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "OmAlve/reading-steiner",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```
- SGLang
How to use OmAlve/reading-steiner with SGLang:
Install from pip and serve model
```bash
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "OmAlve/reading-steiner" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "OmAlve/reading-steiner",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```

Use Docker images
```bash
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "OmAlve/reading-steiner" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "OmAlve/reading-steiner",
    "messages": [
      {
        "role": "user",
        "content": "What is the capital of France?"
      }
    ]
  }'
```

- Docker Model Runner
How to use OmAlve/reading-steiner with Docker Model Runner:
```bash
docker model run hf.co/OmAlve/reading-steiner
```
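Once one of the servers above is running, it can also be called from Python rather than curl. Below is a minimal sketch using the `openai` client against the local vLLM endpoint (the SGLang server works the same way on port 30000); the base URL and placeholder API key are assumptions about a local setup, not part of the official snippets.

```python
# Minimal sketch: query a locally served OmAlve/reading-steiner through the
# OpenAI-compatible API exposed by vLLM (port 8000) or SGLang (port 30000).
# The base_url and dummy api_key below are assumptions about a local setup.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="OmAlve/reading-steiner",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
print(response.choices[0].message.content)
```

The data preparation script below builds the SFT dataset (pushed to OmAlve/indexlm-training-data) from which the model's conversational training examples are drawn: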
| """ | |
| Prepare DIVERSE IndexLM training data from multiple sources: | |
| 1. HtmlRAG-train (real Bing-scraped web HTML) — diverse domains | |
| 2. MultiHopRAG (news domain) — technology, business, sports, entertainment | |
| 3. HotpotQA (Wikipedia) — structured QA with supporting facts | |
| This avoids the Wikipedia-only bias of the original dataset. | |
| Output: Conversational messages for SFT with TRL SFTTrainer | |
| Format: system + user (indexed HTML blocks + query) → assistant (index intervals) | |
| """ | |
| import json | |
| import random | |
| import re | |
| import os | |
| from datasets import load_dataset, Dataset, DatasetDict | |
| from collections import defaultdict | |
| from bs4 import BeautifulSoup | |
| import html as html_lib | |
| random.seed(42) | |
| # ============ System Prompts ============ | |
| SYSTEM_PROMPT_QE = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks and a user query, identify which blocks contain content relevant to the query. | |
| Each block is formatted as: [i] <tag>content</tag> | |
| Output the indices of relevant blocks as a Python list of [start, end] intervals (inclusive). | |
| If no relevant content exists, output 'NA'. | |
| Example output: [[2,4],[7,7],[10,12]]""" | |
| SYSTEM_PROMPT_ME = """You are IndexLM, a web content extraction model. Given a webpage split into indexed blocks, identify which blocks contain the main content of the page (filtering out navigation, advertisements, sidebars, and other non-content elements). | |
| Each block is formatted as: [i] <tag>content</tag> | |
| Output the indices of main content blocks as a Python list of [start, end] intervals (inclusive). | |
| If no main content exists, output 'NA'. | |
| Example output: [[1,3],[5,8],[11,15]]""" | |
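# For orientation (illustrative, not part of the original script): a user turn built
# later in this file looks like
#   URL: https://example.com
#   Query: Who wrote the novel?
#
#   Blocks:
#   [1] <nav>Home | About | Contact | Privacy Policy | Terms of Service</nav>
#   [2] <p>The novel was written by ...</p>
#
#   Output the index intervals of blocks relevant to the query.
# and the assistant target is an interval list such as [[2, 2]], or 'NA'.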
# ============ Noise blocks for injection ============
NOISE_BLOCKS_REALISTIC = [
    '<nav>Home | About | Contact | Privacy Policy | Terms of Service</nav>',
    '<div class="ad">Advertisement - Continue Reading Below</div>',
    '<div class="sidebar">Related Articles: Top 10 Facts You Didn\'t Know</div>',
    '<footer>© 2024 All Rights Reserved | Terms of Service | Cookie Policy</footer>',
    '<div class="cookie-banner">This website uses cookies to improve your experience. By continuing to use this site, you consent to our use of cookies. Accept | Manage Preferences</div>',
    '<div class="social-share">Share: <a>Twitter</a> | <a>Facebook</a> | <a>LinkedIn</a> | <a>Reddit</a> | <a>Email</a></div>',
    '<nav class="breadcrumb">Home > Category > Subcategory > Current Article</nav>',
    '<div class="newsletter-signup">Subscribe to our newsletter for the latest updates delivered to your inbox weekly.</div>',
    '<div class="popup-overlay">Sign up for free access to premium content! Enter your email below.</div>',
    '<aside class="trending">Trending Now: Latest breaking news and popular stories from around the web</aside>',
    '<div class="comments-section">Comments (0) — Be the first to comment! Please read our community guidelines before posting.</div>',
    '<div class="author-bio">Written by Staff Reporter | Updated: January 15, 2024 | 5 min read</div>',
    '<div class="pagination">← Previous Article | Page 1 of 3 | Next Article →</div>',
    '<div class="search-bar"><form>Search this site... <button>Go</button></form></div>',
    '<div class="category-menu">Categories: Science | Technology | Health | Business | Sports | Entertainment | Politics</div>',
    '<div class="login-prompt">Already a subscriber? Log in for full access. Not a member? Subscribe now starting at $4.99/month.</div>',
    '<div class="related-articles"><h3>You May Also Like</h3><ul><li>10 Things You Didn\'t Know About...</li><li>Breaking: Latest Update on...</li></ul></div>',
    '<div class="video-embed">Watch: Video player requires JavaScript to be enabled. [Video placeholder]</div>',
    '<div class="breaking-news-ticker">BREAKING: Markets rally on latest economic data | Sports: Championship results | Weather: Storm warning issued</div>',
    '<div class="accessibility">Skip to main content | Skip to navigation | Accessibility statement</div>',
    '<div class="gdpr-notice">We value your privacy. We and our partners use tracking technologies to improve your browsing experience, serve personalized content, and analyze traffic.</div>',
    '<div class="app-download">Download our app for a better reading experience! Available on iOS and Android.</div>',
    '<script>/* Google Analytics tracking code */</script>',
    '<div class="print-notice">This article is available in print edition. Subscribe for home delivery.</div>',
    '<div class="sponsored">Sponsored Content | Advertiser Disclosure: Some links on this page are affiliate links.</div>',
    '<div class="feedback">Was this article helpful? Yes | No | Send Feedback</div>',
    '<div class="language-selector">Language: English | Español | Français | Deutsch | 日本語 | 中文</div>',
    '<div class="site-footer"><ul><li>About Us</li><li>Careers</li><li>Advertise</li><li>Press</li><li>Help Center</li><li>Sitemap</li></ul></div>',
]

def indices_to_intervals(indices):
    """Convert a sorted list of indices to intervals [[start,end], ...]"""
    if not indices:
        return "NA"
    indices = sorted(set(indices))
    intervals = []
    start = indices[0]
    end = indices[0]
    for i in indices[1:]:
        if i == end + 1:
            end = i
        else:
            intervals.append([start, end])
            start = i
            end = i
    intervals.append([start, end])
    return json.dumps(intervals)
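
# Worked example (illustrative, not in the original script):
#   indices_to_intervals([2, 3, 4, 7, 10, 11, 12]) -> '[[2, 4], [7, 7], [10, 12]]'
#   indices_to_intervals([]) -> 'NA'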

# ============================================================
# SOURCE 1: HtmlRAG-train (Real Bing-scraped web HTML)
# ============================================================
def extract_text_content(html_str):
    """Extract visible text from an HTML string."""
    try:
        soup = BeautifulSoup(html_str, 'html.parser')
        return soup.get_text(separator=' ', strip=True)
    except Exception:
        # Fallback: strip tags with regex
        clean = re.sub(r'<[^>]+>', ' ', html_str)
        return re.sub(r'\s+', ' ', clean).strip()


def segment_html_to_blocks(html_content):
    """
    Segment real HTML content into indexed blocks.
    Splits by block-level HTML tags and line boundaries.
    """
    blocks = []
    # Strategy: split by block-level closing/opening tags
    # HtmlRAG uses tags like <div0>, <p>, <h20>, <li>, etc.
    # Split at positions where block-level tags start
    block_tag_pattern = r'(<(?:div|p|h[1-6]|li|ul|ol|table|tr|td|th|article|section|header|footer|nav|aside|main|blockquote|pre|form|figure|figcaption|details|summary|option|title|button|label|select|textarea|hgroup|dl|dd|dt|caption|thead|tbody|tfoot)\b[^>]*>)'
    # Also handle HtmlRAG numbered tags like <div0>, <h20>, etc.
    block_tag_pattern_numbered = r'(<(?:div|p|h|li|ul|ol|table|tr|td|th|article|section|header|footer|nav|aside|main|blockquote|pre|form|figure|option|title|button|hgroup)\d*[^>]*>)'
    # Split content by block-level tags
    parts = re.split(block_tag_pattern_numbered, html_content)
    current_block = ''
    for part in parts:
        part = part.strip()
        if not part:
            continue
        # Check if this part is a block-level opening tag
        if re.match(block_tag_pattern_numbered, part):
            # Save previous block if it has content
            if current_block.strip():
                blocks.append(current_block.strip())
            current_block = part
        else:
            current_block += ' ' + part
    # Don't forget the last block
    if current_block.strip():
        blocks.append(current_block.strip())
    # If tag-based splitting yields too few blocks, fall back to line-based
    if len(blocks) < 5:
        blocks = []
        lines = html_content.split('\n')
        for line in lines:
            line = line.strip()
            if line and len(line) > 5:
                blocks.append(line)
    # If still too few, split by multiple tags on same line
    if len(blocks) < 5:
        new_blocks = []
        for block in blocks:
            # Try splitting long blocks by inner tags
            if len(block) > 200:
                inner_parts = re.split(r'(</(?:div|p|h[1-6]|li|td|th|article|section)\d*>)', block)
                current = ''
                for ip in inner_parts:
                    current += ip
                    if re.match(r'</(?:div|p|h[1-6]|li|td|th|article|section)\d*>', ip):
                        if current.strip():
                            new_blocks.append(current.strip())
                        current = ''
                if current.strip():
                    new_blocks.append(current.strip())
            else:
                new_blocks.append(block)
        if len(new_blocks) > len(blocks):
            blocks = new_blocks

    # Filter: extract text and remove blocks with no meaningful content
    def extract_text_simple(s):
        clean = re.sub(r'<[^>]+>', ' ', s)
        return re.sub(r'\s+', ' ', clean).strip()

    blocks = [b for b in blocks if len(extract_text_simple(b)) > 5]
    return blocks


def classify_block_as_noise(block_text):
    """Heuristic: classify if a block is likely noise (nav, ad, etc.)."""
    text_lower = block_text.lower()
    noise_indicators = [
        'cookie', 'privacy policy', 'terms of service', 'advertisement',
        'subscribe', 'newsletter', 'sign up', 'log in', 'login',
        'copyright ©', 'all rights reserved', 'skip to', 'accessibility',
        'share on twitter', 'share on facebook', 'social media',
        'related articles', 'you may also like', 'trending now',
        'app download', 'sponsored content', 'affiliate',
    ]
    nav_patterns = ['<nav', '<footer', '<aside', 'class="ad"', 'class="sidebar"',
                    'class="menu"', 'class="social"', 'class="cookie"']
    for indicator in noise_indicators:
        if indicator in text_lower:
            return True
    for pattern in nav_patterns:
        if pattern in text_lower:
            return True
    return False
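
# Illustrative behaviour (these calls are not in the original script):
#   classify_block_as_noise('<nav>Home | About | Contact</nav>')      -> True  (matches the '<nav' pattern)
#   classify_block_as_noise('<p>Paris is the capital of France.</p>') -> False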

def process_htmlrag_example(row):
    """Convert an HtmlRAG example to IndexLM format."""
    user_content = row['messages'][0]['content']
    assistant_content = row['messages'][1]['content']
    score = row.get('score', 0)
    # Skip low-quality examples
    if score < 0.5:
        return None
    # Parse out HTML and question
    parts = user_content.split('**Question**:')
    if len(parts) < 2:
        parts = user_content.split('**Question**')
    if len(parts) < 2:
        return None
    html_raw = parts[0]
    question_raw = parts[1].strip()
    # Clean up the HTML marker
    html_raw = html_raw.replace('**HTML**: ```', '').rstrip('`').strip()
    # Extract just the question (remove the instruction part)
    question = question_raw.split('\n')[0].strip().strip('*').strip()
    if not question:
        return None
    # Segment HTML into blocks
    blocks = segment_html_to_blocks(html_raw)
    if len(blocks) < 3:
        return None
    # Get the relevant content from assistant output
    relevant_text = extract_text_content(assistant_content)
    relevant_words = set(relevant_text.lower().split())
    # Build indexed blocks and find relevant ones
    indexed_blocks = []
    relevant_indices = []
    content_indices = []
    for idx, block in enumerate(blocks, 1):
        # Determine the best tag for this block
        tag_match = re.match(r'<(\w+)', block)
        if tag_match:
            tag = tag_match.group(1)
            # Normalize numbered tags (div0 -> div, h20 -> h2)
            tag = re.sub(r'\d+$', '', tag)
            if not tag:
                tag = 'div'
        else:
            tag = 'p'
        text = extract_text_content(block)
        if not text or len(text) < 3:
            continue
        indexed_blocks.append(f"[{idx}] <{tag}>{text}</{tag}>")
        # Check if this block is noise
        is_noise = classify_block_as_noise(block)
        if not is_noise:
            content_indices.append(idx)
        # Check relevance by substring matching with assistant output
        # Use the full relevant text as a search target
        text_lower = text.lower()
        relevant_lower = relevant_text.lower()
        # Method 1: Check if significant portions of relevant text appear in block
        # Split relevant text into 3-word ngrams and check for matches
        rel_words_list = relevant_lower.split()
        matched = False
        # Check 3-gram overlap
        for i in range(len(rel_words_list) - 2):
            trigram = ' '.join(rel_words_list[i:i+3])
            if trigram in text_lower:
                matched = True
                break
        # Also check: does the block text appear as a substring in the relevant text?
        if not matched and len(text) > 15:
            # Check if meaningful portion of block appears in relevant output
            block_sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 10]
            for sent in block_sentences:
                if sent.lower() in relevant_lower:
                    matched = True
                    break
        # Also check word overlap with a more lenient threshold
        if not matched:
            block_words = set(text_lower.split())
            if relevant_words and block_words:
                overlap_count = len(block_words & relevant_words)
                # At least 2 content words overlap (excluding stopwords)
                stopwords = {'the','a','an','is','are','was','were','in','on','at','to','for','of','and','or','but','with','by','from','as','it','this','that','be','has','have','had','do','does','did','not','no'}
                content_overlap = len((block_words - stopwords) & (relevant_words - stopwords))
                if content_overlap >= 2:
                    matched = True
        if matched:
            relevant_indices.append(idx)
    if not indexed_blocks or len(indexed_blocks) < 3:
        return None
    block_text = "\n".join(indexed_blocks)
    results = []
    # Query-relevant extraction example
    if relevant_indices:
        intervals = indices_to_intervals(relevant_indices)
        user_msg = f"URL: https://example.com\nQuery: {question}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
        results.append({
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_QE},
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": intervals}
            ],
            "task_type": "query_relevant",
            "source": "htmlrag"
        })
    # Main content extraction example (30% of the time to balance)
    if content_indices and random.random() < 0.3:
        intervals = indices_to_intervals(content_indices)
        user_msg = f"URL: https://example.com\nTitle: Web Page\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
        results.append({
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_ME},
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": intervals}
            ],
            "task_type": "main_content",
            "source": "htmlrag"
        })
    return results
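
# Each record returned above follows the conversational format that TRL's SFTTrainer
# consumes: {"messages": [system, user, assistant], "task_type": ..., "source": "htmlrag"},
# with at most one query-relevant and one main-content example per source page.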

def load_htmlrag_data():
    """Load and convert HtmlRAG-train data."""
    print("Loading HtmlRAG-train (real web HTML)...")
    # Use 4k and 8k token variants - good balance of context
    files = [
        'nq-4k.jsonl', 'nq-8k.jsonl',
        'asqa-4k.jsonl', 'asqa-8k.jsonl',
        'trivia-qa-4k.jsonl', 'trivia-qa-8k.jsonl',
        'musique-4k.jsonl', 'musique-8k.jsonl',
        'hotpot-qa-4k.jsonl', 'hotpot-qa-8k.jsonl',
    ]
    all_examples = []
    for file in files:
        print(f" Processing {file}...")
        try:
            ds = load_dataset('zstanjj/HtmlRAG-train', data_files=file, split='train')
            count = 0
            for row in ds:
                results = process_htmlrag_example(row)
                if results:
                    all_examples.extend(results)
                    count += len(results)
            print(f" Got {count} examples from {file}")
        except Exception as e:
            print(f" Error loading {file}: {e}")
    print(f" Total HtmlRAG examples: {len(all_examples)}")
    return all_examples

# ============================================================
# SOURCE 2: MultiHopRAG (News domain)
# ============================================================
def process_multihoprag():
    """Convert MultiHopRAG news articles into IndexLM format."""
    print("Loading MultiHopRAG (news domain)...")
    corpus = load_dataset("yixuantt/MultiHopRAG", name="corpus", split="train")
    queries = load_dataset("yixuantt/MultiHopRAG", name="MultiHopRAG", split="train")
    # Build URL->article lookup
    url_to_article = {}
    for article in corpus:
        url_to_article[article['url']] = article
    all_examples = []
    for q_row in queries:
        query = q_row['query']
        evidence_list = q_row['evidence_list']
        for evidence in evidence_list:
            url = evidence.get('url', '')
            fact = evidence.get('fact', '')
            if url not in url_to_article or not fact:
                continue
            article = url_to_article[url]
            title = article.get('title', 'News Article')
            body = article.get('body', '')
            source = article.get('source', 'Unknown')
            category = article.get('category', 'general')
            if not body or len(body) < 100:
                continue
            # Split article body into paragraphs
            paragraphs = [p.strip() for p in body.split('\n') if p.strip() and len(p.strip()) > 20]
            if not paragraphs:
                continue
            # Build indexed blocks with realistic web structure
            blocks = []
            content_indices = []
            relevant_indices = []
            idx = 1
            # Add realistic header noise
            num_header = random.randint(1, 3)
            for _ in range(num_header):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                idx += 1
            # Article title
            blocks.append(f"[{idx}] <h1>{title}</h1>")
            content_indices.append(idx)
            idx += 1
            # Author/date line
            author = article.get('author', 'Staff Writer')
            published = article.get('published_at', '2024-01-01')
            blocks.append(f"[{idx}] <div class=\"byline\">By {author} | {source} | {published} | Category: {category}</div>")
            content_indices.append(idx)
            idx += 1
            # Article paragraphs
            fact_words = set(fact.lower().split())
            for para in paragraphs:
                # Determine tag
                if len(para) < 60 and not para.endswith('.'):
                    tag = 'h2'
                else:
                    tag = 'p'
                blocks.append(f"[{idx}] <{tag}>{para}</{tag}>")
                content_indices.append(idx)
                # Check if paragraph contains the evidence fact
                para_words = set(para.lower().split())
                overlap = len(para_words & fact_words)
                if overlap > 5 or (fact_words and overlap / len(fact_words) > 0.3):
                    relevant_indices.append(idx)
                idx += 1
                # Occasional mid-article noise
                if random.random() < 0.15:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                    idx += 1
            # Footer noise
            num_footer = random.randint(1, 4)
            for _ in range(num_footer):
                blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                idx += 1
            block_text = "\n".join(blocks)
            # Query-relevant extraction
            if relevant_indices:
                intervals = indices_to_intervals(relevant_indices)
                user_msg = f"URL: {url}\nQuery: {query}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
                all_examples.append({
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT_QE},
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": intervals}
                    ],
                    "task_type": "query_relevant",
                    "source": "multihoprag_news"
                })
            # Main content extraction
            if content_indices and random.random() < 0.4:
                intervals = indices_to_intervals(content_indices)
                user_msg = f"URL: {url}\nTitle: {title}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
                all_examples.append({
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT_ME},
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": intervals}
                    ],
                    "task_type": "main_content",
                    "source": "multihoprag_news"
                })
    print(f" Total MultiHopRAG examples: {len(all_examples)}")
    return all_examples
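
# Relevance heuristic above, on assumed numbers: if an evidence fact has 10 unique
# words and a paragraph shares 4 of them, overlap / len(fact_words) = 0.4 > 0.3, so
# that paragraph's index joins relevant_indices (a raw overlap of more than 5 words
# also qualifies on its own).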

# ============================================================
# SOURCE 3: HotpotQA (Wikipedia - but balanced as minority)
# ============================================================
def process_hotpotqa():
    """Process HotpotQA — kept but as a smaller proportion."""
    print("Loading HotpotQA (Wikipedia domain)...")
    ds = load_dataset("hotpotqa/hotpot_qa", "distractor", split="train")
    # Reduced from 15K to 5K — wiki should be minority source
    num_samples = min(5000, len(ds))
    ds = ds.shuffle(seed=42).select(range(num_samples))
    all_examples = []
    skipped = 0
    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f" Processing {i}/{num_samples}...")
        try:
            titles = row['context']['title']
            sentences_list = row['context']['sentences']
            sf = row['supporting_facts']
            sf_lookup = defaultdict(set)
            for title, sent_id in zip(sf['title'], sf['sent_id']):
                sf_lookup[title].add(sent_id)
            blocks = []
            relevant_indices = []
            content_indices = []
            idx = 1
            # Header noise
            if random.random() < 0.6:
                for _ in range(random.randint(1, 3)):
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                    idx += 1
            for doc_idx, (title, sentences) in enumerate(zip(titles, sentences_list)):
                blocks.append(f"[{idx}] <h2>{title}</h2>")
                content_indices.append(idx)
                if title in sf_lookup:
                    relevant_indices.append(idx)
                idx += 1
                for sent_idx, sentence in enumerate(sentences):
                    sentence = sentence.strip()
                    if not sentence:
                        continue
                    blocks.append(f"[{idx}] <p>{sentence}</p>")
                    content_indices.append(idx)
                    if title in sf_lookup and sent_idx in sf_lookup[title]:
                        relevant_indices.append(idx)
                    idx += 1
                if random.random() < 0.3 and doc_idx < len(titles) - 1:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                    idx += 1
            # Footer noise
            if random.random() < 0.6:
                for _ in range(random.randint(1, 3)):
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                    idx += 1
            if len(relevant_indices) < 1:
                skipped += 1
                continue
            block_text = "\n".join(blocks)
            # QE example
            intervals = indices_to_intervals(relevant_indices)
            user_msg = f"URL: https://en.wikipedia.org\nQuery: {row['question']}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
            all_examples.append({
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT_QE},
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": intervals}
                ],
                "task_type": "query_relevant",
                "source": "hotpotqa_wiki"
            })
            # ME example (less frequent - wiki is minority)
            if random.random() < 0.3:
                intervals = indices_to_intervals(content_indices)
                user_msg = f"URL: https://en.wikipedia.org\nTitle: {titles[0]}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of main content blocks."
                all_examples.append({
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT_ME},
                        {"role": "user", "content": user_msg},
                        {"role": "assistant", "content": intervals}
                    ],
                    "task_type": "main_content",
                    "source": "hotpotqa_wiki"
                })
        except Exception as e:
            skipped += 1
            continue
    print(f" Total HotpotQA examples: {len(all_examples)} ({skipped} skipped)")
    return all_examples

# ============================================================
# SOURCE 4: MS MARCO (Diverse web QA)
# ============================================================
def process_msmarco():
    """Process MS MARCO for diverse web domain QA examples."""
    print("Loading MS MARCO (diverse web QA)...")
    try:
        ds = load_dataset("microsoft/ms_marco", "v1.1", split="train")
        # Sample a manageable subset
        num_samples = min(5000, len(ds))
        ds = ds.shuffle(seed=99).select(range(num_samples))
    except Exception as e:
        print(f" Could not load MS MARCO: {e}")
        return []
    all_examples = []
    for i, row in enumerate(ds):
        if i % 1000 == 0:
            print(f" Processing {i}/{num_samples}...")
        try:
            query = row['query']
            passages = row['passages']
            if not passages or not passages.get('passage_text'):
                continue
            passage_texts = passages['passage_text']
            is_selected = passages.get('is_selected', [0] * len(passage_texts))
            if not any(is_selected):
                continue
            # Build blocks from passages (these are real web snippets from Bing)
            blocks = []
            relevant_indices = []
            content_indices = []
            idx = 1
            # Header noise
            if random.random() < 0.5:
                for _ in range(random.randint(1, 2)):
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                    idx += 1
            for p_idx, (text, selected) in enumerate(zip(passage_texts, is_selected)):
                text = text.strip()
                if not text:
                    continue
                # Simulate different content types
                if p_idx == 0 and random.random() < 0.3:
                    tag = 'h1'
                elif len(text) < 80:
                    tag = random.choice(['h2', 'h3', 'strong'])
                else:
                    tag = 'p'
                blocks.append(f"[{idx}] <{tag}>{text}</{tag}>")
                content_indices.append(idx)
                if selected:
                    relevant_indices.append(idx)
                idx += 1
                # Between-passage noise
                if random.random() < 0.2:
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                    idx += 1
            # Footer noise
            if random.random() < 0.5:
                for _ in range(random.randint(1, 2)):
                    blocks.append(f"[{idx}] {random.choice(NOISE_BLOCKS_REALISTIC)}")
                    idx += 1
            if not relevant_indices or len(blocks) < 3:
                continue
            block_text = "\n".join(blocks)
            # QE example
            intervals = indices_to_intervals(relevant_indices)
            query_type = row.get('query_type', 'general')
            user_msg = f"URL: https://www.bing.com/search\nQuery: {query}\n\nBlocks:\n{block_text}\n\nOutput the index intervals of blocks relevant to the query."
            all_examples.append({
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT_QE},
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": intervals}
                ],
                "task_type": "query_relevant",
                "source": f"msmarco_{query_type}"
            })
        except Exception as e:
            continue
    print(f" Total MS MARCO examples: {len(all_examples)}")
    return all_examples

# ============================================================
# NA Examples (no relevant content)
# ============================================================
def create_na_examples(all_examples):
    """Create NA examples by mismatching queries with pages."""
    print("Creating NA examples (mismatched query-page pairs)...")
    # Get QE examples
    qe_examples = [e for e in all_examples if e['task_type'] == 'query_relevant']
    if len(qe_examples) < 100:
        print(" Too few QE examples for NA generation")
        return []
    na_examples = []
    for i in range(min(500, len(qe_examples) // 5)):
        # Pick two random QE examples
        idx_a = random.randint(0, len(qe_examples) - 1)
        idx_b = (idx_a + random.randint(100, len(qe_examples) - 1)) % len(qe_examples)
        # Use query from A, blocks from B
        msgs_a = qe_examples[idx_a]['messages']
        msgs_b = qe_examples[idx_b]['messages']
        # Extract query from A
        user_a = msgs_a[1]['content']
        query_match = re.search(r'Query: (.+?)(\n|$)', user_a)
        if not query_match:
            continue
        query = query_match.group(1).strip()
        # Extract blocks from B
        user_b = msgs_b[1]['content']
        blocks_match = re.search(r'Blocks:\n(.+?)(\n\nOutput)', user_b, re.DOTALL)
        if not blocks_match:
            continue
        blocks = blocks_match.group(1)
        user_msg = f"URL: https://example.com\nQuery: {query}\n\nBlocks:\n{blocks}\n\nOutput the index intervals of blocks relevant to the query."
        na_examples.append({
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT_QE},
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": "NA"}
            ],
            "task_type": "query_relevant_na",
            "source": "mismatched"
        })
    print(f" Created {len(na_examples)} NA examples")
    return na_examples

# ============================================================
# Main Pipeline
# ============================================================
def main():
    print("=" * 60)
    print("Building DIVERSE IndexLM Training Data")
    print("=" * 60)
    # Collect from all sources
    htmlrag_examples = load_htmlrag_data()        # Real web HTML (primary)
    multihoprag_examples = process_multihoprag()  # News domain
    hotpotqa_examples = process_hotpotqa()        # Wikipedia (minority)
    msmarco_examples = process_msmarco()          # Diverse web QA
    # Combine
    all_examples = htmlrag_examples + multihoprag_examples + hotpotqa_examples + msmarco_examples
    # Add NA examples
    na_examples = create_na_examples(all_examples)
    all_examples.extend(na_examples)
    random.shuffle(all_examples)
    # Print composition
    print(f"\n{'='*60}")
    print(f"Total examples: {len(all_examples)}")
    source_counts = defaultdict(int)
    type_counts = defaultdict(int)
    for ex in all_examples:
        source_counts[ex.get('source', 'unknown')] += 1
        type_counts[ex['task_type']] += 1
    print("\nBy source:")
    for s, c in sorted(source_counts.items(), key=lambda x: -x[1]):
        pct = c / len(all_examples) * 100
        print(f" {s}: {c} ({pct:.1f}%)")
    print("\nBy task type:")
    for t, c in sorted(type_counts.items(), key=lambda x: -x[1]):
        pct = c / len(all_examples) * 100
        print(f" {t}: {c} ({pct:.1f}%)")
    # Check token lengths
    print("\nChecking token lengths...")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    lengths = []
    for ex in random.sample(all_examples, min(500, len(all_examples))):
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        lengths.append(len(tokens))
    print(f"Token length stats (sample of {len(lengths)}):")
    print(f" Min: {min(lengths)}, Max: {max(lengths)}")
    print(f" Mean: {sum(lengths)/len(lengths):.0f}, Median: {sorted(lengths)[len(lengths)//2]}")
    # Filter by length
    MAX_LEN = 4096
    filtered = []
    too_long = 0
    for ex in all_examples:
        text = tokenizer.apply_chat_template(ex['messages'], tokenize=False)
        tokens = tokenizer.encode(text)
        if len(tokens) <= MAX_LEN:
            filtered.append(ex)
        else:
            too_long += 1
    print(f"\nFiltered: {too_long} examples too long (>{MAX_LEN} tokens)")
    print(f"Final dataset size: {len(filtered)}")
    # Final composition
    final_source_counts = defaultdict(int)
    for ex in filtered:
        final_source_counts[ex.get('source', 'unknown')] += 1
    print("\nFinal composition by source:")
    for s, c in sorted(final_source_counts.items(), key=lambda x: -x[1]):
        pct = c / len(filtered) * 100
        print(f" {s}: {c} ({pct:.1f}%)")
    # Split
    random.shuffle(filtered)
    eval_size = min(500, len(filtered) // 10)
    train_data = filtered[:-eval_size]
    eval_data = filtered[-eval_size:]
    print(f"\nTrain: {len(train_data)}, Eval: {len(eval_data)}")
    # Create HF datasets
    train_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in train_data])
    eval_ds = Dataset.from_list([{"messages": ex["messages"]} for ex in eval_data])
    # Save locally
    train_ds.save_to_disk("/app/indexlm_train_v2")
    eval_ds.save_to_disk("/app/indexlm_eval_v2")
    # Push to Hub
    ds_dict = DatasetDict({"train": train_ds, "eval": eval_ds})
    ds_dict.push_to_hub("OmAlve/indexlm-training-data", token=os.environ.get("HF_TOKEN"))
    print(f"\n{'='*60}")
    print("Done! Dataset pushed to OmAlve/indexlm-training-data")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
```
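The script's docstring targets TRL's SFTTrainer. Below is a minimal sketch of how the pushed dataset could be consumed for fine-tuning; the base model, output directory, and default hyperparameters are assumptions (guided by the qwen3 tag and the Qwen/Qwen3-0.6B tokenizer used above), not the exact recipe behind reading-steiner.

```python
# Minimal SFT sketch (assumed settings, not the exact training recipe for this model).
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("OmAlve/indexlm-training-data")  # splits: "train" and "eval"

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",                       # assumed base model (qwen3 tag)
    train_dataset=dataset["train"],                # conversational "messages" records
    eval_dataset=dataset["eval"],
    args=SFTConfig(output_dir="reading-steiner-sft"),
)
trainer.train()
```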