use crossbeam_channel::{bounded, Sender, Receiver}; use dashmap::DashMap; use regex::Regex; use reqwest::Client; use scraper::{Html, Selector}; use std::io::{self, Write, BufWriter}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; use tokenizers::Tokenizer; use url::Url; lazy_static::lazy_static! { static ref MULTI_SPACE: Regex = Regex::new(r"\s+").unwrap(); } struct Feeder { client: Client, visited: DashMap, url_queue: Sender, url_recv: Receiver, token_tx: Sender, tokenizer: Tokenizer, tokens_out: AtomicU64, pages: AtomicU64, start: Instant, } fn hash_url(url: &str) -> u64 { let mut h = DefaultHasher::new(); url.hash(&mut h); h.finish() } impl Feeder { fn new(token_tx: Sender, queue_size: usize) -> Self { let (url_tx, url_rx) = bounded(queue_size); // Load tokenizer from file (pre-downloaded) let tokenizer = Tokenizer::from_file("/workspace/tokenizer.json") .expect("Failed to load tokenizer - download first with: curl -L 'https://huggingface.co/gpt2/resolve/main/tokenizer.json' -o /workspace/tokenizer.json"); Self { client: Client::builder() .timeout(Duration::from_secs(8)) .pool_max_idle_per_host(100) .gzip(true).brotli(true).deflate(true) .redirect(reqwest::redirect::Policy::limited(3)) .danger_accept_invalid_certs(true) .user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120.0.0.0") .build().unwrap(), visited: DashMap::new(), url_queue: url_tx, url_recv: url_rx, token_tx, tokenizer, tokens_out: AtomicU64::new(0), pages: AtomicU64::new(0), start: Instant::now(), } } fn extract_text(html: &str) -> String { let doc = Html::parse_document(html); let sel = Selector::parse("p,h1,h2,h3,h4,li,td,article,span").unwrap(); let mut text = String::with_capacity(html.len() / 4); for el in doc.select(&sel) { let t: String = el.text().collect(); if t.len() > 20 { text.push_str(t.trim()); text.push(' '); } } MULTI_SPACE.replace_all(&text, " ").trim().to_string() } fn extract_links(html: &str, base: &Url) -> Vec { let doc = Html::parse_document(html); let sel = Selector::parse("a[href]").unwrap(); let mut links = Vec::with_capacity(30); for el in doc.select(&sel) { if let Some(href) = el.value().attr("href") { if let Ok(u) = base.join(href) { let s = u.to_string(); if s.starts_with("http") && !s.contains("login") && !s.contains(".pdf") && !s.contains(".jpg") { links.push(s); } } } } links } async fn crawl_and_feed(&self, url: String) { let h = hash_url(&url); if self.visited.contains_key(&h) { return; } self.visited.insert(h, ()); // Fetch let resp = match self.client.get(&url).send().await { Ok(r) if r.status().is_success() => r, _ => return, }; let ct = resp.headers().get("content-type") .and_then(|v| v.to_str().ok()).unwrap_or(""); if !ct.contains("text/html") { return; } let html = match resp.text().await { Ok(t) => t, _ => return }; let base = match Url::parse(&url) { Ok(u) => u, _ => return }; // Extract and queue links let links = Self::extract_links(&html, &base); for link in links.into_iter().take(15) { let _ = self.url_queue.try_send(link); } // Extract text let text = Self::extract_text(&html); if text.len() < 100 { return; } // Tokenize and send if let Ok(encoding) = self.tokenizer.encode(text, false) { for &id in encoding.get_ids() { if self.token_tx.send(id).is_err() { return; // Consumer closed } self.tokens_out.fetch_add(1, Ordering::Relaxed); } } self.pages.fetch_add(1, Ordering::Relaxed); } fn print_stats(&self) { let e = self.start.elapsed().as_secs_f64(); let tok = self.tokens_out.load(Ordering::Relaxed); let pgs = self.pages.load(Ordering::Relaxed); eprintln!("[FEEDER {:.0}s] {} pages | {} tok | {:.0} tok/s", e, pgs, tok, tok as f64 / e); } } #[tokio::main(flavor = "multi_thread", worker_threads = 64)] async fn main() { eprintln!("Wire-Speed Feeder starting..."); // Channel to Python trainer let (token_tx, token_rx) = bounded::(100_000); let feeder = Arc::new(Feeder::new(token_tx, 500_000)); // Seed URLs let seeds = vec![ "https://en.wikipedia.org/wiki/Main_Page", "https://news.ycombinator.com/", "https://reddit.com/r/all", "https://bbc.com/news", "https://reuters.com", "https://arxiv.org", "https://stackoverflow.com/questions", "https://medium.com", "https://theguardian.com", "https://nature.com", "https://github.com/explore", "https://nytimes.com", ]; for seed in seeds { let _ = feeder.url_queue.try_send(seed.to_string()); } // Stats printer let stats_feeder = Arc::clone(&feeder); tokio::spawn(async move { loop { tokio::time::sleep(Duration::from_secs(10)).await; stats_feeder.print_stats(); } }); // Output thread - writes tokens to stdout std::thread::spawn(move || { let stdout = io::stdout(); let mut out = BufWriter::with_capacity(1 << 16, stdout.lock()); while let Ok(token) = token_rx.recv() { let _ = writeln!(out, "{}", token); } }); // Worker pool - 500 concurrent fetchers let mut handles = Vec::with_capacity(500); for _ in 0..500 { let f = Arc::clone(&feeder); handles.push(tokio::spawn(async move { loop { if let Ok(url) = f.url_recv.try_recv() { f.crawl_and_feed(url).await; } else { tokio::time::sleep(Duration::from_micros(100)).await; } } })); } // Run forever loop { tokio::time::sleep(Duration::from_secs(3600)).await; } }