// feeder/src/main.rs
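// Wire-speed feeder: crawls seed pages and their outgoing links, extracts
// readable text, tokenizes it with a GPT-2 tokenizer, and streams token IDs
// to stdout (one decimal ID per line) for a downstream Python trainer.
// Progress stats are printed to stderr.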
use crossbeam_channel::{bounded, Sender, Receiver};
use dashmap::DashMap;
use regex::Regex;
use reqwest::Client;
use scraper::{Html, Selector};
use std::io::{self, Write, BufWriter};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use tokenizers::Tokenizer;
use url::Url;
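// Assumed Cargo.toml dependencies (a sketch; pin versions as needed):
// crossbeam-channel, dashmap, lazy_static, regex, scraper, url, tokenizers,
// tokio (with "rt-multi-thread", "macros", "time"), and reqwest (with the
// "gzip", "brotli", and "deflate" features used by the client builder below).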
lazy_static::lazy_static! {
static ref MULTI_SPACE: Regex = Regex::new(r"\s+").unwrap();
}
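// Shared crawler state: the HTTP client, a set of visited URL hashes for
// deduplication, the bounded URL work queue, the outbound token channel,
// the tokenizer, and throughput counters.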
struct Feeder {
client: Client,
visited: DashMap<u64, ()>,
url_queue: Sender<String>,
url_recv: Receiver<String>,
token_tx: Sender<u32>,
tokenizer: Tokenizer,
tokens_out: AtomicU64,
pages: AtomicU64,
start: Instant,
}
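// Hash a URL to a u64 so the visited set stores 8-byte keys instead of full URLs.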
fn hash_url(url: &str) -> u64 {
let mut h = DefaultHasher::new();
url.hash(&mut h);
h.finish()
}
impl Feeder {
fn new(token_tx: Sender<u32>, queue_size: usize) -> Self {
let (url_tx, url_rx) = bounded(queue_size);
// Load tokenizer from file (pre-downloaded)
let tokenizer = Tokenizer::from_file("/workspace/tokenizer.json")
.expect("Failed to load tokenizer - download first with: curl -L 'https://huggingface.co/gpt2/resolve/main/tokenizer.json' -o /workspace/tokenizer.json");
Self {
client: Client::builder()
.timeout(Duration::from_secs(8))
.pool_max_idle_per_host(100)
.gzip(true).brotli(true).deflate(true)
.redirect(reqwest::redirect::Policy::limited(3))
.danger_accept_invalid_certs(true)
.user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120.0.0.0")
.build().unwrap(),
visited: DashMap::new(),
url_queue: url_tx,
url_recv: url_rx,
token_tx,
tokenizer,
tokens_out: AtomicU64::new(0),
pages: AtomicU64::new(0),
start: Instant::now(),
}
}
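// Collect the text of common content-bearing elements and collapse runs of
// whitespace. Nested matches (e.g. a span inside a selected p) can contribute
// overlapping text; fragments of 20 bytes or less are skipped.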
fn extract_text(html: &str) -> String {
let doc = Html::parse_document(html);
let sel = Selector::parse("p,h1,h2,h3,h4,li,td,article,span").unwrap();
let mut text = String::with_capacity(html.len() / 4);
for el in doc.select(&sel) {
let t: String = el.text().collect();
if t.len() > 20 {
text.push_str(t.trim());
text.push(' ');
}
}
MULTI_SPACE.replace_all(&text, " ").trim().to_string()
}
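// Resolve every href against the page's base URL and keep absolute http(s)
// links, skipping obvious non-HTML targets (.pdf, .jpg) and login pages.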
fn extract_links(html: &str, base: &Url) -> Vec<String> {
let doc = Html::parse_document(html);
let sel = Selector::parse("a[href]").unwrap();
let mut links = Vec::with_capacity(30);
for el in doc.select(&sel) {
if let Some(href) = el.value().attr("href") {
if let Ok(u) = base.join(href) {
let s = u.to_string();
if s.starts_with("http") && !s.contains("login")
&& !s.contains(".pdf") && !s.contains(".jpg") {
links.push(s);
}
}
}
}
links
}
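// Per-URL pipeline: dedupe, fetch, check the content type, enqueue a capped
// number of outgoing links, extract text, tokenize, and push token IDs into
// the output channel.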
async fn crawl_and_feed(&self, url: String) {
let h = hash_url(&url);
// insert() returns the previous value, so marking the URL as visited and
// checking for duplicates happens in one atomic step (avoids a check-then-insert race).
if self.visited.insert(h, ()).is_some() { return; }
// Fetch
let resp = match self.client.get(&url).send().await {
Ok(r) if r.status().is_success() => r,
_ => return,
};
let ct = resp.headers().get("content-type")
.and_then(|v| v.to_str().ok()).unwrap_or("");
if !ct.contains("text/html") { return; }
let html = match resp.text().await { Ok(t) => t, _ => return };
let base = match Url::parse(&url) { Ok(u) => u, _ => return };
// Extract and queue links
let links = Self::extract_links(&html, &base);
for link in links.into_iter().take(15) {
let _ = self.url_queue.try_send(link);
}
// Extract text
let text = Self::extract_text(&html);
if text.len() < 100 { return; }
// Tokenize and send
if let Ok(encoding) = self.tokenizer.encode(text, false) {
for &id in encoding.get_ids() {
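// crossbeam's blocking send stalls this tokio worker thread when the token
// channel is full, which throttles the crawl to the consumer's speed.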
if self.token_tx.send(id).is_err() {
return; // Consumer closed
}
self.tokens_out.fetch_add(1, Ordering::Relaxed);
}
}
self.pages.fetch_add(1, Ordering::Relaxed);
}
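// Report cumulative pages, tokens, and tokens/second on stderr.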
fn print_stats(&self) {
let e = self.start.elapsed().as_secs_f64();
let tok = self.tokens_out.load(Ordering::Relaxed);
let pgs = self.pages.load(Ordering::Relaxed);
eprintln!("[FEEDER {:.0}s] {} pages | {} tok | {:.0} tok/s",
e, pgs, tok, tok as f64 / e);
}
}
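// Pipeline wiring: seed URLs -> 500 async fetch workers -> bounded token
// channel -> a dedicated stdout writer thread, plus a periodic stats task.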
#[tokio::main(flavor = "multi_thread", worker_threads = 64)]
async fn main() {
eprintln!("Wire-Speed Feeder starting...");
// Channel to Python trainer
let (token_tx, token_rx) = bounded::<u32>(100_000);
let feeder = Arc::new(Feeder::new(token_tx, 500_000));
// Seed URLs
let seeds = vec![
"https://en.wikipedia.org/wiki/Main_Page",
"https://news.ycombinator.com/",
"https://reddit.com/r/all",
"https://bbc.com/news",
"https://reuters.com",
"https://arxiv.org",
"https://stackoverflow.com/questions",
"https://medium.com",
"https://theguardian.com",
"https://nature.com",
"https://github.com/explore",
"https://nytimes.com",
];
for seed in seeds {
let _ = feeder.url_queue.try_send(seed.to_string());
}
// Stats printer
let stats_feeder = Arc::clone(&feeder);
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(10)).await;
stats_feeder.print_stats();
}
});
// Output thread - writes tokens to stdout
std::thread::spawn(move || {
let stdout = io::stdout();
let mut out = BufWriter::with_capacity(1 << 16, stdout.lock());
while let Ok(token) = token_rx.recv() {
let _ = writeln!(out, "{}", token);
}
});
// Worker pool - 500 concurrent fetchers
let mut handles = Vec::with_capacity(500);
for _ in 0..500 {
let f = Arc::clone(&feeder);
handles.push(tokio::spawn(async move {
loop {
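// crossbeam receivers have no async recv, so poll with try_recv and yield
// briefly when the queue is empty.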
if let Ok(url) = f.url_recv.try_recv() {
f.crawl_and_feed(url).await;
} else {
tokio::time::sleep(Duration::from_micros(100)).await;
}
}
}));
}
// Run forever
loop {
tokio::time::sleep(Duration::from_secs(3600)).await;
}
}