OpenTransformer committed on
Commit
54a950a
·
verified ·
1 Parent(s): d1eeb5d

Upload feeder/src/main.rs with huggingface_hub

Browse files
Files changed (1) hide show
  1. feeder/src/main.rs +217 -0
feeder/src/main.rs ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::io::{self, BufWriter, Write};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};

use crossbeam_channel::{bounded, Receiver, Sender};
use dashmap::DashMap;
use regex::Regex;
use reqwest::Client;
use scraper::{Html, Selector};
use tokenizers::Tokenizer;
use url::Url;
14
+
15
+ lazy_static::lazy_static! {
16
+ static ref MULTI_SPACE: Regex = Regex::new(r"\s+").unwrap();
17
+ }
18
+
19
// Shared crawler state: one instance is wrapped in an Arc and used by all
// worker tasks, the stats task, and main.
struct Feeder {
    // Shared HTTP client; reqwest pools connections across all workers.
    client: Client,
    // Hashes of already-claimed URLs; value type () — used as a concurrent set.
    visited: DashMap<u64, ()>,
    // Producer side of the bounded URL frontier; discovered links go here.
    url_queue: Sender<String>,
    // Consumer side of the URL frontier; workers pull the next URL to fetch.
    url_recv: Receiver<String>,
    // Outgoing token ids; drained by the stdout writer thread in main.
    token_tx: Sender<u32>,
    // Tokenizer loaded from /workspace/tokenizer.json (GPT-2 per the
    // download hint in `new`).
    tokenizer: Tokenizer,
    // Running counter of token ids sent, for throughput stats.
    tokens_out: AtomicU64,
    // Running counter of pages fully processed.
    pages: AtomicU64,
    // Process start time, used to compute tok/s in print_stats.
    start: Instant,
}
30
+
31
/// Hash a URL into a compact `u64` key for the visited-set.
///
/// Uses the standard-library `DefaultHasher`; the value is stable within a
/// single process run, which is all the dedup set requires.
fn hash_url(url: &str) -> u64 {
    let mut state = DefaultHasher::default();
    Hash::hash(url, &mut state);
    Hasher::finish(&state)
}
36
+
37
+ impl Feeder {
38
+ fn new(token_tx: Sender<u32>, queue_size: usize) -> Self {
39
+ let (url_tx, url_rx) = bounded(queue_size);
40
+
41
+ // Load tokenizer from file (pre-downloaded)
42
+ let tokenizer = Tokenizer::from_file("/workspace/tokenizer.json")
43
+ .expect("Failed to load tokenizer - download first with: curl -L 'https://huggingface.co/gpt2/resolve/main/tokenizer.json' -o /workspace/tokenizer.json");
44
+
45
+ Self {
46
+ client: Client::builder()
47
+ .timeout(Duration::from_secs(8))
48
+ .pool_max_idle_per_host(100)
49
+ .gzip(true).brotli(true).deflate(true)
50
+ .redirect(reqwest::redirect::Policy::limited(3))
51
+ .danger_accept_invalid_certs(true)
52
+ .user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/120.0.0.0")
53
+ .build().unwrap(),
54
+ visited: DashMap::new(),
55
+ url_queue: url_tx,
56
+ url_recv: url_rx,
57
+ token_tx,
58
+ tokenizer,
59
+ tokens_out: AtomicU64::new(0),
60
+ pages: AtomicU64::new(0),
61
+ start: Instant::now(),
62
+ }
63
+ }
64
+
65
+ fn extract_text(html: &str) -> String {
66
+ let doc = Html::parse_document(html);
67
+ let sel = Selector::parse("p,h1,h2,h3,h4,li,td,article,span").unwrap();
68
+
69
+ let mut text = String::with_capacity(html.len() / 4);
70
+ for el in doc.select(&sel) {
71
+ let t: String = el.text().collect();
72
+ if t.len() > 20 {
73
+ text.push_str(t.trim());
74
+ text.push(' ');
75
+ }
76
+ }
77
+
78
+ MULTI_SPACE.replace_all(&text, " ").trim().to_string()
79
+ }
80
+
81
+ fn extract_links(html: &str, base: &Url) -> Vec<String> {
82
+ let doc = Html::parse_document(html);
83
+ let sel = Selector::parse("a[href]").unwrap();
84
+
85
+ let mut links = Vec::with_capacity(30);
86
+ for el in doc.select(&sel) {
87
+ if let Some(href) = el.value().attr("href") {
88
+ if let Ok(u) = base.join(href) {
89
+ let s = u.to_string();
90
+ if s.starts_with("http") && !s.contains("login")
91
+ && !s.contains(".pdf") && !s.contains(".jpg") {
92
+ links.push(s);
93
+ }
94
+ }
95
+ }
96
+ }
97
+ links
98
+ }
99
+
100
+ async fn crawl_and_feed(&self, url: String) {
101
+ let h = hash_url(&url);
102
+ if self.visited.contains_key(&h) { return; }
103
+ self.visited.insert(h, ());
104
+
105
+ // Fetch
106
+ let resp = match self.client.get(&url).send().await {
107
+ Ok(r) if r.status().is_success() => r,
108
+ _ => return,
109
+ };
110
+
111
+ let ct = resp.headers().get("content-type")
112
+ .and_then(|v| v.to_str().ok()).unwrap_or("");
113
+ if !ct.contains("text/html") { return; }
114
+
115
+ let html = match resp.text().await { Ok(t) => t, _ => return };
116
+ let base = match Url::parse(&url) { Ok(u) => u, _ => return };
117
+
118
+ // Extract and queue links
119
+ let links = Self::extract_links(&html, &base);
120
+ for link in links.into_iter().take(15) {
121
+ let _ = self.url_queue.try_send(link);
122
+ }
123
+
124
+ // Extract text
125
+ let text = Self::extract_text(&html);
126
+ if text.len() < 100 { return; }
127
+
128
+ // Tokenize and send
129
+ if let Ok(encoding) = self.tokenizer.encode(text, false) {
130
+ for &id in encoding.get_ids() {
131
+ if self.token_tx.send(id).is_err() {
132
+ return; // Consumer closed
133
+ }
134
+ self.tokens_out.fetch_add(1, Ordering::Relaxed);
135
+ }
136
+ }
137
+
138
+ self.pages.fetch_add(1, Ordering::Relaxed);
139
+ }
140
+
141
+ fn print_stats(&self) {
142
+ let e = self.start.elapsed().as_secs_f64();
143
+ let tok = self.tokens_out.load(Ordering::Relaxed);
144
+ let pgs = self.pages.load(Ordering::Relaxed);
145
+ eprintln!("[FEEDER {:.0}s] {} pages | {} tok | {:.0} tok/s",
146
+ e, pgs, tok, tok as f64 / e);
147
+ }
148
+ }
149
+
150
+ #[tokio::main(flavor = "multi_thread", worker_threads = 64)]
151
+ async fn main() {
152
+ eprintln!("Wire-Speed Feeder starting...");
153
+
154
+ // Channel to Python trainer
155
+ let (token_tx, token_rx) = bounded::<u32>(100_000);
156
+
157
+ let feeder = Arc::new(Feeder::new(token_tx, 500_000));
158
+
159
+ // Seed URLs
160
+ let seeds = vec![
161
+ "https://en.wikipedia.org/wiki/Main_Page",
162
+ "https://news.ycombinator.com/",
163
+ "https://reddit.com/r/all",
164
+ "https://bbc.com/news",
165
+ "https://reuters.com",
166
+ "https://arxiv.org",
167
+ "https://stackoverflow.com/questions",
168
+ "https://medium.com",
169
+ "https://theguardian.com",
170
+ "https://nature.com",
171
+ "https://github.com/explore",
172
+ "https://nytimes.com",
173
+ ];
174
+
175
+ for seed in seeds {
176
+ let _ = feeder.url_queue.try_send(seed.to_string());
177
+ }
178
+
179
+ // Stats printer
180
+ let stats_feeder = Arc::clone(&feeder);
181
+ tokio::spawn(async move {
182
+ loop {
183
+ tokio::time::sleep(Duration::from_secs(10)).await;
184
+ stats_feeder.print_stats();
185
+ }
186
+ });
187
+
188
+ // Output thread - writes tokens to stdout
189
+ std::thread::spawn(move || {
190
+ let stdout = io::stdout();
191
+ let mut out = BufWriter::with_capacity(1 << 16, stdout.lock());
192
+
193
+ while let Ok(token) = token_rx.recv() {
194
+ let _ = writeln!(out, "{}", token);
195
+ }
196
+ });
197
+
198
+ // Worker pool - 500 concurrent fetchers
199
+ let mut handles = Vec::with_capacity(500);
200
+ for _ in 0..500 {
201
+ let f = Arc::clone(&feeder);
202
+ handles.push(tokio::spawn(async move {
203
+ loop {
204
+ if let Ok(url) = f.url_recv.try_recv() {
205
+ f.crawl_and_feed(url).await;
206
+ } else {
207
+ tokio::time::sleep(Duration::from_micros(100)).await;
208
+ }
209
+ }
210
+ }));
211
+ }
212
+
213
+ // Run forever
214
+ loop {
215
+ tokio::time::sleep(Duration::from_secs(3600)).await;
216
+ }
217
+ }