import os, time, json, gc, threading, signal, hashlib, random
import numpy as np
import zstandard as zstd
from flask import Flask, jsonify
from datasets import load_dataset
from transformers import AutoTokenizer
from huggingface_hub import HfApi, hf_hub_download
# --- CONFIGURATION ---
REPO_ID = "abhinav337463/indro-web-data"
TOKENIZER_ID = "gpt2"
TARGET_TOKENS = 3_000_000_000
CHUNK_SIZE_MB = 100
BATCH_SIZE = 1000
WARMUP_TOKENS = 1_000_000
RATIO = {"data": 0.70, "code": 0.20, "math": 0.09, "training": 0.01}
STATE_FILE = "state.json"
MAX_RETRIES = 5  # Audit Point: bound the number of upload retries
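# A quick sanity check on these numbers: with 4-byte (uint32) token IDs, a
# 100 MB chunk holds 100 * 1024 * 1024 / 4 = 26,214,400 tokens, so reaching
# TARGET_TOKENS takes roughly 3_000_000_000 / 26_214_400 ≈ 115 chunk uploads.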
# --- GLOBAL STATE & LOCKS ---
# An RLock (re-entrant) is used because atomic_save_state() can run while the
# same thread already holds the lock; a plain Lock would deadlock on that
# nested acquisition.
state_lock = threading.RLock()
state = {
    "tokens_processed": 0,
    "domain_tokens": {k: 0 for k in RATIO.keys()},
    "distribution_audit": [],  # Audit Point: drift tracking
    "files_uploaded": 0,
    "file_hashes": {},
    "status": "Initializing",
    "start_time": time.time(),
    "stop_requested": False,
    "critical_error": False,
}
token_buffer = []
app = Flask(__name__)
api = HfApi()
hf_token = os.environ.get("HF_TOKEN")
cctx = zstd.ZstdCompressor(level=3)
# --- SYSTEM HANDLERS ---
def graceful_shutdown(signum, frame):
    with state_lock:
        print("\nSIGTERM/SIGINT received, requesting graceful stop...")
        state["stop_requested"] = True

signal.signal(signal.SIGTERM, graceful_shutdown)
signal.signal(signal.SIGINT, graceful_shutdown)
def atomic_save_state():
    with state_lock:
        temp_file = STATE_FILE + ".tmp"
        with open(temp_file, "w") as f:
            json.dump(state, f, indent=2)
        os.replace(temp_file, STATE_FILE)  # atomic rename on POSIX
    # The network sync runs after releasing the lock so a slow upload cannot
    # stall the worker or the health endpoint.
    try:
        api.upload_file(path_or_fileobj=STATE_FILE, path_in_repo=f"Tokenized/{STATE_FILE}",
                        repo_id=REPO_ID, repo_type="dataset", token=hf_token)
    except Exception as e:
        print(f"⚠️ State sync error: {e}")
def load_state():
    global state
    try:
        path = hf_hub_download(repo_id=REPO_ID, filename=f"Tokenized/{STATE_FILE}",
                               repo_type="dataset", token=hf_token)
        with open(path, "r") as f:
            with state_lock:
                state.update(json.load(f))
                # Reset per-run fields: a state saved after SIGTERM carries
                # stop_requested=True and a stale start_time.
                state["stop_requested"] = False
                state["critical_error"] = False
                state["start_time"] = time.time()
        print(f"Resuming from {state['tokens_processed']} tokens...")
    except Exception:
        print("No previous state found, starting fresh.")
@app.route("/")  # health endpoint; the path "/" is assumed (the decorator is required for Flask to serve this handler)
def health():
    with state_lock:
        elapsed = time.time() - state["start_time"]
        speed = state["tokens_processed"] / elapsed if elapsed > 0 else 0
        return jsonify({
            **state,
            "progress_percent": round((state["tokens_processed"] / TARGET_TOKENS) * 100, 4),
            "speed_tokens_sec": round(speed, 2),
            "eta_hours": round((TARGET_TOKENS - state["tokens_processed"]) / speed / 3600, 2) if speed > 0 else 0,
        })
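# To poll the status endpoint from another shell (a sketch, assuming the
# default host/port configured at the bottom of this file):
#
#   python -c "import urllib.request, json; print(json.load(urllib.request.urlopen('http://localhost:7860/')))"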
def upload_chunk(data_to_process, file_idx):
    """Audit Point: copy-then-clear pattern (heavy work happens outside the lock)."""
    zstd_name = f"tokens_{file_idx}.bin.zst"
    # Audit Point: hash the raw token bytes so consumers can verify integrity
    raw_bytes = data_to_process.tobytes()
    sha256_hash = hashlib.sha256(raw_bytes).hexdigest()
    # Compression
    compressed_data = cctx.compress(raw_bytes)
    with open(zstd_name, "wb") as f:
        f.write(compressed_data)
    success = False
    for i in range(MAX_RETRIES):
        try:
            api.upload_file(path_or_fileobj=zstd_name, path_in_repo=f"Tokenized/{zstd_name}",
                            repo_id=REPO_ID, repo_type="dataset", token=hf_token)
            success = True
            break
        except Exception as e:
            print(f"⚠️ Upload fail {i + 1}: {e}")
            time.sleep(20 * (i + 1))  # linear backoff
    if success:
        with state_lock:
            state["file_hashes"][zstd_name] = sha256_hash
            state["files_uploaded"] += 1
            # Audit: take a distribution snapshot alongside each file
            audit_entry = {k: round(state["domain_tokens"][k] / state["tokens_processed"], 4) for k in RATIO.keys()}
            state["distribution_audit"].append({"file": zstd_name, "mix": audit_entry})
        atomic_save_state()  # called outside the lock; it takes the lock itself
        os.remove(zstd_name)
        gc.collect()
        return True
    else:
        with state_lock:
            state["critical_error"] = True
            state["status"] = "CRITICAL_FAILURE: Upload loop broken"
        os.remove(zstd_name)  # don't leave the orphaned chunk on disk
        return False
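# A minimal consumer-side sketch of how a chunk written above could be read
# back and verified. The function name load_and_verify_chunk is hypothetical
# (it is not part of this pipeline); it relies only on the same libraries and
# on the layout upload_chunk produces: zstd-compressed uint32 token IDs in
# native byte order, with a SHA-256 of the raw bytes recorded in state.json.
def load_and_verify_chunk(file_idx, expected_sha256):
    name = f"tokens_{file_idx}.bin.zst"
    path = hf_hub_download(repo_id=REPO_ID, filename=f"Tokenized/{name}",
                           repo_type="dataset", token=hf_token)
    with open(path, "rb") as f:
        raw = zstd.ZstdDecompressor().decompress(f.read())
    assert hashlib.sha256(raw).hexdigest() == expected_sha256, "hash mismatch"
    return np.frombuffer(raw, dtype=np.uint32)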
def worker():
    global state, token_buffer
    load_state()
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
    assert tokenizer.eos_token_id is not None
    streams = {k: iter(load_dataset(REPO_ID, data_dir=k, split="train", streaming=True)) for k in RATIO.keys()}
    while state["tokens_processed"] < TARGET_TOKENS and not state["stop_requested"] and not state["critical_error"]:
        # Ratio selection: random sampling during warmup, then deficit-based
        with state_lock:
            if state["tokens_processed"] < WARMUP_TOKENS:
                target_domain = random.choices(list(RATIO.keys()), weights=list(RATIO.values()))[0]
            else:
                current_dist = {k: state["domain_tokens"][k] / (state["tokens_processed"] + 1) for k in RATIO.keys()}
                # Pick the domain furthest below its target share
                target_domain = min(RATIO.keys(), key=lambda k: current_dist[k] / RATIO[k])
        batch_texts = []
        for _ in range(BATCH_SIZE):
            try:
                example = next(streams[target_domain])
                batch_texts.append(example.get("text") or example.get("content") or "")
            except StopIteration:
                # Stream exhausted: restart it once and process the partial batch
                streams[target_domain] = iter(load_dataset(REPO_ID, data_dir=target_domain, split="train", streaming=True))
                break
        if batch_texts:
            encodings = tokenizer(batch_texts, add_special_tokens=False)
            local_processed = 0
            local_ids = []
            for ids in encodings["input_ids"]:
                chunk_ids = ids + [tokenizer.eos_token_id]  # EOS separates documents
                local_ids.extend(chunk_ids)
                local_processed += len(chunk_ids)
            with state_lock:
                token_buffer.extend(local_ids)
                state["tokens_processed"] += local_processed
                state["domain_tokens"][target_domain] += local_processed
        # Buffer threshold check, 4 bytes per uint32 token (the unlocked read
        # is fine: this worker thread is the only writer of token_buffer)
        if (len(token_buffer) * 4) / (1024 * 1024) >= CHUNK_SIZE_MB:
            with state_lock:
                state["status"] = f"Compressing chunk {state['files_uploaded']}..."
                # Audit Point: copy the buffer and clear it while holding the lock
                data_to_save = np.array(token_buffer, dtype=np.uint32)
                token_buffer = []
                idx = state["files_uploaded"]
            # Heavy task OUTSIDE the lock
            if not upload_chunk(data_to_save, idx):
                break
            with state_lock:
                state["status"] = "Processing"
    # Final cleanup: flush whatever remains in the buffer
    if token_buffer and not state["critical_error"]:
        upload_chunk(np.array(token_buffer, dtype=np.uint32), state["files_uploaded"])
    with state_lock:
        state["status"] = "Finished" if not state["critical_error"] else "FAILED"
    atomic_save_state()
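# A worked example of the deficit-based selection inside worker(), under
# assumed counts: with targets data=0.70 / code=0.20 / math=0.09 /
# training=0.01 and, say, 2.0M tokens processed of which data=1.30M (65%),
# code=0.45M (22.5%), math=0.21M (10.5%), training=0.04M (2%), the
# current/target ratios are 0.93, 1.13, 1.17 and 2.00, so "data" (the
# furthest below its target) is selected for the next batch.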
threading.Thread(target=worker, daemon=True).start()

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)