# Indrohelper / app.py — Hugging Face Space worker (commit 755efa1, verified)
# Uploaded by abhinav337463; renamed from main.py to app.py.
import os, time, json, gc, threading, numpy as np, signal, hashlib, random, sys
import zstandard as zstd
from flask import Flask, jsonify
from datasets import load_dataset
from transformers import AutoTokenizer
from huggingface_hub import HfApi, hf_hub_download
# --- πŸ”± CONFIGURATION ---
REPO_ID = "abhinav337463/indro-web-data"
TOKENIZER_ID = "gpt2"
TARGET_TOKENS = 3_000_000_000
CHUNK_SIZE_MB = 100
BATCH_SIZE = 1000
WARMUP_TOKENS = 1_000_000
RATIO = {"data": 0.70, "code": 0.20, "math": 0.09, "training": 0.01}
STATE_FILE = "state.json"
MAX_RETRIES = 5 # Audit Point: Bound upload failures
# --- πŸ”± GLOBAL STATE & LOCKS ---
state_lock = threading.Lock()
state = {
"tokens_processed": 0,
"domain_tokens": {k: 0 for k in RATIO.keys()},
"distribution_audit": [], # Audit Point: Drift tracking
"files_uploaded": 0,
"file_hashes": {},
"status": "Initializing",
"start_time": time.time(),
"stop_requested": False,
"critical_error": False
}
token_buffer = []
app = Flask(__name__)
api = HfApi()
hf_token = os.environ.get("HF_TOKEN")
cctx = zstd.ZstdCompressor(level=3)
# --- πŸ”± SYSTEM HANDLERS ---
def graceful_shutdown(signum, frame):
    """Signal handler: request a clean stop so the worker can flush and exit.

    Only sets a flag. Deliberately does NOT take `state_lock`: signal
    handlers run on the main thread and would deadlock if they interrupted
    code that already holds the (non-reentrant) lock. A single dict-key
    assignment is atomic under the GIL, so no lock is needed here.
    """
    print("\n🛑 SIGTERM/SIGINT! Emergency saving...")
    state["stop_requested"] = True

signal.signal(signal.SIGTERM, graceful_shutdown)
signal.signal(signal.SIGINT, graceful_shutdown)
def atomic_save_state():
    """Persist `state` to disk atomically, then best-effort mirror it to the Hub.

    Writing to a temp file and swapping it in with os.replace guarantees
    readers never observe a half-written state.json. The Hub upload runs
    outside the lock; any failure is logged and swallowed, since the local
    copy remains authoritative.
    """
    scratch = STATE_FILE + ".tmp"
    with state_lock:
        with open(scratch, "w") as fh:
            json.dump(state, fh, indent=2)
        os.replace(scratch, STATE_FILE)
    try:
        api.upload_file(
            path_or_fileobj=STATE_FILE,
            path_in_repo=f"Tokenized/{STATE_FILE}",
            repo_id=REPO_ID,
            repo_type="dataset",
            token=hf_token,
        )
    except Exception as e:
        print(f"⚠️ State sync error: {e}")
def load_state():
    """Restore persisted progress from the Hub, falling back to a fresh start.

    Downloads Tokenized/state.json from the dataset repo and merges it into
    the in-memory state, then resets per-run transient fields: a checkpoint
    saved during a previous shutdown carries stop_requested=True (which
    would make the worker exit immediately on resume) and an old start_time
    (which would skew the /health speed and ETA figures).

    Catches Exception (not a bare except, which would also swallow
    SystemExit/KeyboardInterrupt) — any failure simply means a fresh run.
    """
    global state
    try:
        path = hf_hub_download(repo_id=REPO_ID, filename=f"Tokenized/{STATE_FILE}",
                               repo_type="dataset", token=hf_token)
        with open(path, "r") as f:
            saved = json.load(f)
        with state_lock:
            state.update(saved)
            # Reset fields that must not survive a restart.
            state["stop_requested"] = False
            state["critical_error"] = False
            state["start_time"] = time.time()
        print(f"🔄 Resuming from {state['tokens_processed']} tokens...")
    except Exception:
        print("🆕 Fresh start.")
@app.route('/health')
def health():
    """Return the full pipeline state plus derived throughput metrics as JSON."""
    with state_lock:
        runtime = time.time() - state["start_time"]
        rate = state["tokens_processed"] / runtime if runtime > 0 else 0
        remaining = TARGET_TOKENS - state["tokens_processed"]
        payload = dict(state)
        payload["progress_percent"] = round((state['tokens_processed'] / TARGET_TOKENS) * 100, 4)
        payload["speed_tokens_sec"] = round(rate, 2)
        payload["eta_hours"] = round(remaining / rate / 3600, 2) if rate > 0 else 0
        return jsonify(payload)
def upload_chunk(data_to_process, file_idx):
    """Compress a token array, upload it to the Hub, and record the audit trail.

    Audit Point: the caller copies the buffer under the lock and passes the
    copy here, so all heavy work (hashing, zstd, network) runs lock-free.

    Args:
        data_to_process: np.ndarray of uint32 token ids to persist.
        file_idx: sequence number used to name the chunk file.

    Returns:
        True on successful upload (state updated and checkpoint synced);
        False after MAX_RETRIES failed attempts (critical_error is set so
        the worker loop stops).
    """
    zstd_name = f"tokens_{file_idx}.bin.zst"
    raw_bytes = data_to_process.tobytes()
    # Hash the *uncompressed* bytes so integrity can be verified after decompression.
    sha256_hash = hashlib.sha256(raw_bytes).hexdigest()
    with open(zstd_name, "wb") as f:
        f.write(cctx.compress(raw_bytes))
    try:
        success = False
        for attempt in range(MAX_RETRIES):
            try:
                api.upload_file(path_or_fileobj=zstd_name, path_in_repo=f"Tokenized/{zstd_name}",
                                repo_id=REPO_ID, repo_type="dataset", token=hf_token)
                success = True
                break
            except Exception as e:
                print(f"⚠️ Upload fail {attempt+1}: {e}")
                time.sleep(20 * (attempt + 1))  # linear backoff between attempts
        if not success:
            with state_lock:
                state["critical_error"] = True
                state["status"] = "CRITICAL_FAILURE: Upload loop broken"
            return False
        with state_lock:
            state["file_hashes"][zstd_name] = sha256_hash
            state["files_uploaded"] += 1
            # Audit: Take distribution snapshot of the realized domain mix.
            # max(..., 1) guards the (theoretical) zero-token division.
            total = max(state["tokens_processed"], 1)
            audit_entry = {k: round(state["domain_tokens"][k] / total, 4) for k in RATIO.keys()}
            state["distribution_audit"].append({"file": zstd_name, "mix": audit_entry})
        # Must run OUTSIDE state_lock: atomic_save_state acquires the same
        # non-reentrant lock, so calling it while holding it would deadlock.
        atomic_save_state()
        return True
    finally:
        # Always remove the local chunk file, even when the upload failed
        # (the original leaked it on the failure path).
        if os.path.exists(zstd_name):
            os.remove(zstd_name)
        gc.collect()
def worker():
    """Main pipeline loop: stream source text, tokenize, buffer, and upload.

    Runs until TARGET_TOKENS is reached, a stop is requested via signal, or
    an upload fails permanently. Domain selection follows RATIO: random
    weighted sampling during warmup, then deficit-based rebalancing.
    """
    global state, token_buffer
    load_state()
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True)
    # EOS id is appended after every document below; fail fast if missing.
    assert tokenizer.eos_token_id is not None
    # One streaming iterator per domain.
    # NOTE(review): assumes each RATIO key matches a data_dir subfolder in the
    # dataset repo — confirm against the repo layout.
    streams = {k: iter(load_dataset(REPO_ID, data_dir=k, split="train", streaming=True)) for k in RATIO.keys()}
    # Lock-free reads of the loop flags are benign: worst case is one extra iteration.
    while state["tokens_processed"] < TARGET_TOKENS and not state["stop_requested"] and not state["critical_error"]:
        # Ratio Selection Logic
        with state_lock:
            if state["tokens_processed"] < WARMUP_TOKENS:
                # Warmup: plain weighted random pick.
                target_domain = random.choices(list(RATIO.keys()), weights=list(RATIO.values()))[0]
            else:
                # After warmup: pick the domain furthest below its target share.
                # +1 avoids division by zero on the (unlikely) empty state.
                current_dist = {k: state["domain_tokens"][k] / (state["tokens_processed"] + 1) for k in RATIO.keys()}
                target_domain = min(RATIO.keys(), key=lambda k: current_dist[k] / RATIO[k])
        batch_texts = []
        for _ in range(BATCH_SIZE):
            try:
                example = next(streams[target_domain])
                # Accept either "text" or "content" field; empty string otherwise.
                batch_texts.append(example.get("text") or example.get("content") or "")
            except StopIteration:
                # Stream exhausted: restart it. NOTE(review): this slot's
                # sample is skipped, so the batch may be slightly short.
                streams[target_domain] = iter(load_dataset(REPO_ID, data_dir=target_domain, split="train", streaming=True))
        if batch_texts:
            encodings = tokenizer(batch_texts, add_special_tokens=False)
            local_processed = 0
            local_ids = []
            for ids in encodings["input_ids"]:
                # Append EOS as the document separator.
                chunk_ids = ids + [tokenizer.eos_token_id]
                local_ids.extend(chunk_ids)
                local_processed += len(chunk_ids)
            # Merge the batch into shared state in one short critical section.
            with state_lock:
                token_buffer.extend(local_ids)
                state["tokens_processed"] += local_processed
                state["domain_tokens"][target_domain] += local_processed
        # Buffer Threshold check (uint32 => 4 bytes per token).
        # Lock-free len() read is benign: only this thread grows the buffer.
        if (len(token_buffer) * 4) / (1024 * 1024) >= CHUNK_SIZE_MB:
            with state_lock:
                state["status"] = f"Compressing chunk {state['files_uploaded']}..."
                # Audit Point: Copy buffer and clear under lock
                data_to_save = np.array(token_buffer, dtype=np.uint32)
                token_buffer = []
                idx = state['files_uploaded']
            # Heavy task OUTSIDE lock
            if not upload_chunk(data_to_save, idx):
                # Upload permanently failed; critical_error is already set.
                break
            with state_lock: state["status"] = "Processing"
    # Final cleanup: flush whatever remains in the buffer (skip if we broke
    # out of the loop on a failed upload).
    if token_buffer and not state["critical_error"]:
        upload_chunk(np.array(token_buffer, dtype=np.uint32), state['files_uploaded'])
    with state_lock: state["status"] = "Finished" if not state["critical_error"] else "FAILED"
    atomic_save_state()
# Run the pipeline in a background daemon thread so the Flask /health
# endpoint stays responsive; daemon=True lets the process exit without
# joining the worker.
threading.Thread(target=worker, daemon=True).start()
if __name__ == "__main__":
    # 7860 is the conventional Hugging Face Spaces port.
    app.run(host="0.0.0.0", port=7860)