|
|
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::io::{self, BufWriter, Write};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};

use crossbeam_channel::{bounded, Receiver, Sender};
use dashmap::DashMap;
use regex::Regex;
use reqwest::Client;
use scraper::{Html, Selector};
use tokenizers::Tokenizer;
use url::Url;
|
|
|
|
|
lazy_static::lazy_static! { |
|
|
static ref MULTI_SPACE: Regex = Regex::new(r"\s+").unwrap(); |
|
|
} |
|
|
|
|
|
/// Shared crawler state handed to every worker task behind an `Arc`:
/// HTTP client, URL frontier (bounded channel), visited-URL dedupe set,
/// tokenizer, and throughput counters.
struct Feeder {
    // Pooled HTTP client shared by all crawl tasks.
    client: Client,
    // Dedupe set keyed by the 64-bit hash of a URL; the value is unused.
    visited: DashMap<u64, ()>,
    // Producer side of the bounded URL frontier.
    url_queue: Sender<String>,
    // Consumer side of the bounded URL frontier.
    url_recv: Receiver<String>,
    // Sink for tokenized output ids, drained by the stdout writer thread.
    token_tx: Sender<u32>,
    // GPT-2 tokenizer loaded from tokenizer.json at start-up.
    tokenizer: Tokenizer,
    // Total token ids emitted so far (relaxed counter, stats only).
    tokens_out: AtomicU64,
    // Total pages successfully tokenized (relaxed counter, stats only).
    pages: AtomicU64,
    // Start timestamp used for tok/s reporting.
    start: Instant,
}
|
|
|
|
|
/// 64-bit fingerprint of a URL, used as the key into the `visited` set.
/// Uses the standard library's default (SipHash) hasher, so the value is
/// stable within one process run.
fn hash_url(url: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    Hash::hash(url, &mut hasher);
    hasher.finish()
}
|
|
|
|
|
impl Feeder { |
|
|
fn new(token_tx: Sender<u32>, queue_size: usize) -> Self { |
|
|
let (url_tx, url_rx) = bounded(queue_size); |
|
|
|
|
|
|
|
|
let tokenizer = Tokenizer::from_file("/workspace/tokenizer.json") |
|
|
.expect("Failed to load tokenizer - download first with: curl -L 'https://huggingface.co/gpt2/resolve/main/tokenizer.json' -o /workspace/tokenizer.json"); |
|
|
|
|
|
Self { |
|
|
client: Client::builder() |
|
|
.timeout(Duration::from_secs(8)) |
|
|
.pool_max_idle_per_host(100) |
|
|
.gzip(true).brotli(true).deflate(true) |
|
|
.redirect(reqwest::redirect::Policy::limited(3)) |
|
|
.danger_accept_invalid_certs(true) |
|
|
.user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120.0.0.0") |
|
|
.build().unwrap(), |
|
|
visited: DashMap::new(), |
|
|
url_queue: url_tx, |
|
|
url_recv: url_rx, |
|
|
token_tx, |
|
|
tokenizer, |
|
|
tokens_out: AtomicU64::new(0), |
|
|
pages: AtomicU64::new(0), |
|
|
start: Instant::now(), |
|
|
} |
|
|
} |
|
|
|
|
|
fn extract_text(html: &str) -> String { |
|
|
let doc = Html::parse_document(html); |
|
|
let sel = Selector::parse("p,h1,h2,h3,h4,li,td,article,span").unwrap(); |
|
|
|
|
|
let mut text = String::with_capacity(html.len() / 4); |
|
|
for el in doc.select(&sel) { |
|
|
let t: String = el.text().collect(); |
|
|
if t.len() > 20 { |
|
|
text.push_str(t.trim()); |
|
|
text.push(' '); |
|
|
} |
|
|
} |
|
|
|
|
|
MULTI_SPACE.replace_all(&text, " ").trim().to_string() |
|
|
} |
|
|
|
|
|
fn extract_links(html: &str, base: &Url) -> Vec<String> { |
|
|
let doc = Html::parse_document(html); |
|
|
let sel = Selector::parse("a[href]").unwrap(); |
|
|
|
|
|
let mut links = Vec::with_capacity(30); |
|
|
for el in doc.select(&sel) { |
|
|
if let Some(href) = el.value().attr("href") { |
|
|
if let Ok(u) = base.join(href) { |
|
|
let s = u.to_string(); |
|
|
if s.starts_with("http") && !s.contains("login") |
|
|
&& !s.contains(".pdf") && !s.contains(".jpg") { |
|
|
links.push(s); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
links |
|
|
} |
|
|
|
|
|
async fn crawl_and_feed(&self, url: String) { |
|
|
let h = hash_url(&url); |
|
|
if self.visited.contains_key(&h) { return; } |
|
|
self.visited.insert(h, ()); |
|
|
|
|
|
|
|
|
let resp = match self.client.get(&url).send().await { |
|
|
Ok(r) if r.status().is_success() => r, |
|
|
_ => return, |
|
|
}; |
|
|
|
|
|
let ct = resp.headers().get("content-type") |
|
|
.and_then(|v| v.to_str().ok()).unwrap_or(""); |
|
|
if !ct.contains("text/html") { return; } |
|
|
|
|
|
let html = match resp.text().await { Ok(t) => t, _ => return }; |
|
|
let base = match Url::parse(&url) { Ok(u) => u, _ => return }; |
|
|
|
|
|
|
|
|
let links = Self::extract_links(&html, &base); |
|
|
for link in links.into_iter().take(15) { |
|
|
let _ = self.url_queue.try_send(link); |
|
|
} |
|
|
|
|
|
|
|
|
let text = Self::extract_text(&html); |
|
|
if text.len() < 100 { return; } |
|
|
|
|
|
|
|
|
if let Ok(encoding) = self.tokenizer.encode(text, false) { |
|
|
for &id in encoding.get_ids() { |
|
|
if self.token_tx.send(id).is_err() { |
|
|
return; |
|
|
} |
|
|
self.tokens_out.fetch_add(1, Ordering::Relaxed); |
|
|
} |
|
|
} |
|
|
|
|
|
self.pages.fetch_add(1, Ordering::Relaxed); |
|
|
} |
|
|
|
|
|
fn print_stats(&self) { |
|
|
let e = self.start.elapsed().as_secs_f64(); |
|
|
let tok = self.tokens_out.load(Ordering::Relaxed); |
|
|
let pgs = self.pages.load(Ordering::Relaxed); |
|
|
eprintln!("[FEEDER {:.0}s] {} pages | {} tok | {:.0} tok/s", |
|
|
e, pgs, tok, tok as f64 / e); |
|
|
} |
|
|
} |
|
|
|
|
|
#[tokio::main(flavor = "multi_thread", worker_threads = 64)] |
|
|
async fn main() { |
|
|
eprintln!("Wire-Speed Feeder starting..."); |
|
|
|
|
|
|
|
|
let (token_tx, token_rx) = bounded::<u32>(100_000); |
|
|
|
|
|
let feeder = Arc::new(Feeder::new(token_tx, 500_000)); |
|
|
|
|
|
|
|
|
let seeds = vec![ |
|
|
"https://en.wikipedia.org/wiki/Main_Page", |
|
|
"https://news.ycombinator.com/", |
|
|
"https://reddit.com/r/all", |
|
|
"https://bbc.com/news", |
|
|
"https://reuters.com", |
|
|
"https://arxiv.org", |
|
|
"https://stackoverflow.com/questions", |
|
|
"https://medium.com", |
|
|
"https://theguardian.com", |
|
|
"https://nature.com", |
|
|
"https://github.com/explore", |
|
|
"https://nytimes.com", |
|
|
]; |
|
|
|
|
|
for seed in seeds { |
|
|
let _ = feeder.url_queue.try_send(seed.to_string()); |
|
|
} |
|
|
|
|
|
|
|
|
let stats_feeder = Arc::clone(&feeder); |
|
|
tokio::spawn(async move { |
|
|
loop { |
|
|
tokio::time::sleep(Duration::from_secs(10)).await; |
|
|
stats_feeder.print_stats(); |
|
|
} |
|
|
}); |
|
|
|
|
|
|
|
|
std::thread::spawn(move || { |
|
|
let stdout = io::stdout(); |
|
|
let mut out = BufWriter::with_capacity(1 << 16, stdout.lock()); |
|
|
|
|
|
while let Ok(token) = token_rx.recv() { |
|
|
let _ = writeln!(out, "{}", token); |
|
|
} |
|
|
}); |
|
|
|
|
|
|
|
|
let mut handles = Vec::with_capacity(500); |
|
|
for _ in 0..500 { |
|
|
let f = Arc::clone(&feeder); |
|
|
handles.push(tokio::spawn(async move { |
|
|
loop { |
|
|
if let Ok(url) = f.url_recv.try_recv() { |
|
|
f.crawl_and_feed(url).await; |
|
|
} else { |
|
|
tokio::time::sleep(Duration::from_micros(100)).await; |
|
|
} |
|
|
} |
|
|
})); |
|
|
} |
|
|
|
|
|
|
|
|
loop { |
|
|
tokio::time::sleep(Duration::from_secs(3600)).await; |
|
|
} |
|
|
} |
|
|
|