zirobtc committed on
Commit
3a4bd03
·
1 Parent(s): 780c5db

Upload scripts/cache_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/cache_dataset.py +37 -2
scripts/cache_dataset.py CHANGED
@@ -288,11 +288,37 @@ def main():
288
  class_distribution = {}
289
  process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  if args.num_workers == 1:
292
  print("INFO: Single-threaded mode...")
293
  _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
294
- for task in tqdm(tasks, desc="Caching"):
 
 
 
295
  result = process_fn(task)
 
 
 
 
296
  if result['status'] == 'success':
297
  success_count += 1
298
  class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
@@ -301,13 +327,21 @@ def main():
301
  else:
302
  error_count += 1
303
  tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")
 
304
  else:
305
  print(f"INFO: Running with {args.num_workers} workers...")
 
 
306
  with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
307
  futures = {executor.submit(process_fn, task): task for task in tasks}
308
- for future in tqdm(as_completed(futures), total=len(futures), desc="Caching"):
 
309
  try:
310
  result = future.result(timeout=300)
 
 
 
 
311
  if result['status'] == 'success':
312
  success_count += 1
313
  class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
@@ -318,6 +352,7 @@ def main():
318
  except Exception as e:
319
  error_count += 1
320
  tqdm.write(f"WORKER ERROR: {e}")
 
321
 
322
  print("INFO: Building metadata...")
323
  file_class_map = {}
 
288
  class_distribution = {}
289
  process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
290
 
291
+ import time as _time
292
+
293
+ def _log_progress(task_num, total, start_time, recent_times, success_count, skipped_count, error_count):
294
+ """Print progress with rolling ETA every 10 tokens."""
295
+ if (task_num + 1) % 10 == 0 and recent_times:
296
+ avg_time = sum(recent_times) / len(recent_times)
297
+ remaining = total - (task_num + 1)
298
+ eta_seconds = avg_time * remaining
299
+ eta_hours = eta_seconds / 3600
300
+ wall_elapsed = _time.perf_counter() - start_time
301
+ speed = (task_num + 1) / wall_elapsed
302
+ tqdm.write(
303
+ f" [PROGRESS] {task_num+1}/{total} | "
304
+ f"Speed: {speed:.1f} tok/s ({speed*60:.0f} tok/min) | "
305
+ f"Avg: {avg_time:.1f}s/tok | "
306
+ f"ETA: {eta_hours:.1f}h | "
307
+ f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
308
+ )
309
+
310
  if args.num_workers == 1:
311
  print("INFO: Single-threaded mode...")
312
  _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
313
+ start_time = _time.perf_counter()
314
+ recent_times = []
315
+ for task_num, task in enumerate(tqdm(tasks, desc="Caching", unit="tok")):
316
+ t0 = _time.perf_counter()
317
  result = process_fn(task)
318
+ elapsed = _time.perf_counter() - t0
319
+ recent_times.append(elapsed)
320
+ if len(recent_times) > 50:
321
+ recent_times.pop(0)
322
  if result['status'] == 'success':
323
  success_count += 1
324
  class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
 
327
  else:
328
  error_count += 1
329
  tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")
330
+ _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
331
  else:
332
  print(f"INFO: Running with {args.num_workers} workers...")
333
+ start_time = _time.perf_counter()
334
+ recent_times = []
335
  with ProcessPoolExecutor(max_workers=args.num_workers, initializer=_init_worker, initargs=(db_config, dataset_config, return_class_map, quality_scores_map)) as executor:
336
  futures = {executor.submit(process_fn, task): task for task in tasks}
337
+ for task_num, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Caching", unit="tok")):
338
+ t0 = _time.perf_counter()
339
  try:
340
  result = future.result(timeout=300)
341
+ elapsed = _time.perf_counter() - t0
342
+ recent_times.append(elapsed)
343
+ if len(recent_times) > 50:
344
+ recent_times.pop(0)
345
  if result['status'] == 'success':
346
  success_count += 1
347
  class_distribution[result['class_id']] = class_distribution.get(result['class_id'], 0) + 1
 
352
  except Exception as e:
353
  error_count += 1
354
  tqdm.write(f"WORKER ERROR: {e}")
355
+ _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
356
 
357
  print("INFO: Building metadata...")
358
  file_class_map = {}