Neon-tech commited on
Commit
a6a417b
Β·
verified Β·
1 Parent(s): b3181f1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -0
app.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import socket
5
+ import threading
6
+ import gc
7
+ import ctypes
8
+ import multiprocessing as mp
9
+ from pathlib import Path
10
+ import numpy as np
11
+ from tokenizers import Tokenizer
12
+
13
+ # ── Config ───────────────────────────────────────────────────────────────────
14
+ STATE_FILE = "/data/state.json"
15
+ RAW_DIR = "/data/raw"
16
+ OUT_DIR = "/data/tokenized"
17
+ TOK_PATH = "/data/tokenizer.json"
18
+ WORKER_ID = socket.gethostname()
19
+ POLL_INTERVAL = 15
20
+ BATCH_SIZE = 2 # 2 lines at a time across 2 cores
21
+
22
+ os.makedirs(OUT_DIR, exist_ok=True)
23
+
24
+ # ── Keep-alive ────────────────────────────────────────────────────────────────
25
+ def serve():
26
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
27
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
28
+ s.bind(("0.0.0.0", 7860))
29
+ s.listen(5)
30
+ print(f"βœ“ [{WORKER_ID}] Listening on port 7860")
31
+ while True:
32
+ conn, _ = s.accept()
33
+ conn.send(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
34
+ conn.close()
35
+
36
+ # ── State ─────────────────────────────────────────────────────────────────────
37
+ def load_state():
38
+ with open(STATE_FILE) as f:
39
+ return json.load(f)
40
+
41
+ def save_state(state):
42
+ tmp = STATE_FILE + f".tmp.{WORKER_ID}"
43
+ with open(tmp, "w") as f:
44
+ json.dump(state, f, indent=2)
45
+ os.replace(tmp, STATE_FILE)
46
+
47
+ # ── Claim ─────────────────────────────────────────────────────────────────────
48
+ def claim_shard(state):
49
+ for name, info in state["shards"].items():
50
+ if info["status"] == "pending":
51
+ raw_path = Path(RAW_DIR) / name
52
+ if raw_path.exists():
53
+ info["status"] = "claimed"
54
+ info["worker"] = WORKER_ID
55
+ info["claimed_at"] = time.time()
56
+ save_state(state)
57
+ return name, raw_path
58
+ return None, None
59
+
60
+ # ── Tokenizer subprocess ──────────────────────────────────────────────────────
61
+ _worker_tok = None
62
+ _worker_sep = None
63
+
64
+ def init_worker(tok_path):
65
+ global _worker_tok, _worker_sep
66
+ _worker_tok = Tokenizer.from_file(tok_path)
67
+ _worker_sep = _worker_tok.token_to_id("<sep>")
68
+
69
+ def tokenize_texts(texts):
70
+ """Tokenize a list of texts, append <sep> to each."""
71
+ encs = _worker_tok.encode_batch(texts)
72
+ result = []
73
+ for enc in encs:
74
+ ids = enc.ids
75
+ if len(ids) >= 2:
76
+ ids.append(_worker_sep)
77
+ result.append(ids)
78
+ return result
79
+
80
+ # ── Process shard ─────────────────────────────────────────────────────────────
81
+ def process_shard(name, raw_path, pool):
82
+ print(f" [{WORKER_ID}] Processing: {name}")
83
+
84
+ out_name = name.replace(".jsonl", ".bin")
85
+ out_path = Path(OUT_DIR) / out_name
86
+ tmp_path = Path(OUT_DIR) / f"{out_name}.tmp"
87
+
88
+ # Crash recovery β€” delete any partial output from previous attempt
89
+ tmp_path.unlink(missing_ok=True)
90
+ out_path.unlink(missing_ok=True)
91
+
92
+ total_tokens = 0
93
+ total_docs = 0
94
+
95
+ try:
96
+ with open(raw_path, "r", encoding="utf-8") as fin, \
97
+ open(tmp_path, "wb") as fout:
98
+
99
+ batch_texts = []
100
+
101
+ for line in fin:
102
+ line = line.strip()
103
+ if not line:
104
+ continue
105
+ try:
106
+ obj = json.loads(line)
107
+ text = obj.get("text", "").strip()
108
+ except Exception:
109
+ continue
110
+ if not text:
111
+ continue
112
+
113
+ batch_texts.append(text)
114
+
115
+ if len(batch_texts) >= BATCH_SIZE:
116
+ try:
117
+ results = pool.apply(tokenize_texts, (batch_texts,))
118
+ except Exception as e:
119
+ tmp_path.unlink(missing_ok=True)
120
+ return False, f"tokenize_failed: {e}"
121
+
122
+ for ids in results:
123
+ arr = np.array(ids, dtype=np.uint16)
124
+ arr.tofile(fout)
125
+ total_tokens += len(ids)
126
+ total_docs += 1
127
+
128
+ batch_texts = []
129
+
130
+ # Flush remaining
131
+ if batch_texts:
132
+ try:
133
+ results = pool.apply(tokenize_texts, (batch_texts,))
134
+ except Exception as e:
135
+ tmp_path.unlink(missing_ok=True)
136
+ return False, f"tokenize_failed_flush: {e}"
137
+ for ids in results:
138
+ arr = np.array(ids, dtype=np.uint16)
139
+ arr.tofile(fout)
140
+ total_tokens += len(ids)
141
+ total_docs += 1
142
+
143
+ except Exception as e:
144
+ tmp_path.unlink(missing_ok=True)
145
+ return False, f"process_failed: {e}"
146
+
147
+ # Atomic rename β€” only visible when complete
148
+ tmp_path.rename(out_path)
149
+ print(f" βœ“ [{WORKER_ID}] {out_name} | {total_docs:,} docs | {total_tokens:,} tokens")
150
+ return True, None
151
+
152
+ # ── Memory flush ──────────────────────────────────────────────────────────────
153
+ def flush_memory():
154
+ gc.collect()
155
+ try:
156
+ ctypes.CDLL("libc.so.6").malloc_trim(0)
157
+ except Exception:
158
+ pass
159
+
160
+ # ── Worker loop ───────────────────────────────────────────────────────────────
161
+ def worker_loop():
162
+ print(f"βœ“ [{WORKER_ID}] Starting worker...")
163
+ pool = mp.Pool(processes=2, initializer=init_worker, initargs=(TOK_PATH,))
164
+ print(f"βœ“ [{WORKER_ID}] 2-core tokenizer pool ready")
165
+
166
+ try:
167
+ while True:
168
+ if not os.path.exists(STATE_FILE):
169
+ print(f" [{WORKER_ID}] Waiting for state.json...")
170
+ time.sleep(POLL_INTERVAL)
171
+ continue
172
+
173
+ try:
174
+ state = load_state()
175
+ except Exception as e:
176
+ print(f" [{WORKER_ID}] State read error: {e}")
177
+ time.sleep(POLL_INTERVAL)
178
+ continue
179
+
180
+ total = len(state["shards"]) + len(state.get("queue", []))
181
+ done = sum(1 for v in state["shards"].values() if v["status"] == "done")
182
+ if total > 0 and done == total:
183
+ print(f" [{WORKER_ID}] All done. Sleeping.")
184
+ time.sleep(300)
185
+ continue
186
+
187
+ name, raw_path = claim_shard(state)
188
+ if not name:
189
+ print(f" [{WORKER_ID}] Nothing ready β€” polling in {POLL_INTERVAL}s")
190
+ time.sleep(POLL_INTERVAL)
191
+ continue
192
+
193
+ print(f" [{WORKER_ID}] Claimed: {name}")
194
+ success, error = process_shard(name, raw_path, pool)
195
+
196
+ try:
197
+ state = load_state()
198
+ except Exception:
199
+ pass
200
+
201
+ if success:
202
+ state["shards"][name]["status"] = "done"
203
+ state["shards"][name]["error"] = None
204
+ save_state(state)
205
+ try:
206
+ raw_path.unlink()
207
+ print(f" [{WORKER_ID}] Deleted raw: {raw_path.name}")
208
+ except Exception as e:
209
+ print(f" [{WORKER_ID}] Delete failed: {e}")
210
+ else:
211
+ state["shards"][name]["status"] = "pending"
212
+ state["shards"][name]["worker"] = None
213
+ state["shards"][name]["claimed_at"] = None
214
+ state["shards"][name]["error"] = error
215
+ save_state(state)
216
+ print(f" [{WORKER_ID}] Failed ({error}) β€” reset to pending: {name}")
217
+
218
+ flush_memory()
219
+ time.sleep(5)
220
+
221
+ finally:
222
+ pool.terminate()
223
+ pool.join()
224
+
225
+ # ── Entry point ───────────────────────────────────────────────────────────────
226
+ if __name__ == "__main__":
227
+ threading.Thread(target=serve, daemon=True).start()
228
+ threading.Thread(target=worker_loop, daemon=True).start()
229
+ while True:
230
+ time.sleep(60)