Upload scripts/cache_dataset.py with huggingface_hub
Browse files — scripts/cache_dataset.py (+37 −2)
scripts/cache_dataset.py
CHANGED
|
@@ -288,11 +288,37 @@ def main():
|
|
| 288 |
class_distribution = {}
|
| 289 |
process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
if args.num_workers == 1:
|
| 292 |
print("INFO: Single-threaded mode...")
|
| 293 |
_init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
| 295 |
result = process_fn(task)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
if result['status'] == 'success':
|
| 297 |
success_count += 1
|
| 298 |
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
|
@@ -301,13 +327,21 @@ def main():
|
|
| 301 |
else:
|
| 302 |
error_count += 1
|
| 303 |
tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")
|
|
|
|
| 304 |
else:
|
| 305 |
print(f"INFO: Running with {args.num_workers} workers...")
|
|
|
|
|
|
|
| 306 |
with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
|
| 307 |
futures = {executor.submit(process_fn, task): task for task in tasks}
|
| 308 |
-
for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
|
|
|
|
| 309 |
try:
|
| 310 |
result = future.result(timeout=300)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
if result['status'] == 'success':
|
| 312 |
success_count += 1
|
| 313 |
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
|
@@ -318,6 +352,7 @@ def main():
|
|
| 318 |
except Exception as e:
|
| 319 |
error_count += 1
|
| 320 |
tqdm.write(f"WORKER ERROR: {e}")
|
|
|
|
| 321 |
|
| 322 |
print("INFO: Building metadata...")
|
| 323 |
file_class_map = {}
|
|
|
|
| 288 |
class_distribution = {}
|
| 289 |
process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
|
| 290 |
|
| 291 |
+
import time as _time
|
| 292 |
+
|
| 293 |
+
def _log_progress(task_num, total, start_time, recent_times, success_count, skipped_count, error_count):
    """Emit a rolling-average progress line through ``tqdm.write``.

    Fires only on every 10th completed task, and only once at least one
    per-task duration has been recorded in ``recent_times``. The ETA is
    derived from the mean of the recent per-task times, while the overall
    speed is computed against wall-clock time since ``start_time``.

    NOTE(review): relies on ``tqdm`` and ``_time`` (``import time as _time``)
    from the enclosing scope — presumably ``_time.perf_counter()`` matches
    the clock used to produce ``start_time``; confirm at the call sites.
    """
    done = task_num + 1
    # Guard clause: stay silent between every-10th checkpoints or before
    # any timing sample exists (avoids division by zero on the mean).
    if done % 10 != 0 or not recent_times:
        return
    mean_task_time = sum(recent_times) / len(recent_times)
    eta_hours = mean_task_time * (total - done) / 3600
    throughput = done / (_time.perf_counter() - start_time)
    tqdm.write(
        f" [PROGRESS] {done}/{total} | "
        f"Speed: {throughput:.1f} tok/s ({throughput*60:.0f} tok/min) | "
        f"Avg: {mean_task_time:.1f}s/tok | "
        f"ETA: {eta_hours:.1f}h | "
        f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
    )
|
| 309 |
+
|
| 310 |
if args.num_workers == 1:
|
| 311 |
print("INFO: Single-threaded mode...")
|
| 312 |
_init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
|
| 313 |
+
start_time = _time.perf_counter()
|
| 314 |
+
recent_times = []
|
| 315 |
+
for task_num, task in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
|
| 316 |
+
t0 = _time.perf_counter()
|
| 317 |
result = process_fn(task)
|
| 318 |
+
elapsed = _time.perf_counter() - t0
|
| 319 |
+
recent_times.append(elapsed)
|
| 320 |
+
if len(recent_times) > 50:
|
| 321 |
+
recent_times.pop(0)
|
| 322 |
if result['status'] == 'success':
|
| 323 |
success_count += 1
|
| 324 |
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
|
|
|
| 327 |
else:
|
| 328 |
error_count += 1
|
| 329 |
tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")
|
| 330 |
+
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 331 |
else:
|
| 332 |
print(f"INFO: Running with {args.num_workers} workers...")
|
| 333 |
+
start_time = _time.perf_counter()
|
| 334 |
+
recent_times = []
|
| 335 |
with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
|
| 336 |
futures = {executor.submit(process_fn, task): task for task in tasks}
|
| 337 |
+
for task_num, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Caching", unit="tok")):
|
| 338 |
+
t0 = _time.perf_counter()
|
| 339 |
try:
|
| 340 |
result = future.result(timeout=300)
|
| 341 |
+
elapsed = _time.perf_counter() - t0
|
| 342 |
+
recent_times.append(elapsed)
|
| 343 |
+
if len(recent_times) > 50:
|
| 344 |
+
recent_times.pop(0)
|
| 345 |
if result['status'] == 'success':
|
| 346 |
success_count += 1
|
| 347 |
class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
|
|
|
|
| 352 |
except Exception as e:
|
| 353 |
error_count += 1
|
| 354 |
tqdm.write(f"WORKER ERROR: {e}")
|
| 355 |
+
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 356 |
|
| 357 |
print("INFO: Building metadata...")
|
| 358 |
file_class_map = {}
|