zirobtc committed on
Commit
fdd6265
·
1 Parent(s): 3a4bd03

Upload scripts/cache_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/cache_dataset.py +47 -3
scripts/cache_dataset.py CHANGED
@@ -282,10 +282,33 @@ def main():
282
  )
283
  print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
284
 
285
- print(f"INFO: Starting to cache {len(tasks)} tokens...")
286
-
287
  success_count, skipped_count, error_count = 0, 0, 0
288
  class_distribution = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
290
 
291
  import time as _time
@@ -307,6 +330,10 @@ def main():
307
  f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
308
  )
309
 
 
 
 
 
310
  if args.num_workers == 1:
311
  print("INFO: Single-threaded mode...")
312
  _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
@@ -326,7 +353,10 @@ def main():
326
  skipped_count += 1
327
  else:
328
  error_count += 1
329
- tqdm.write(f"ERROR: {result['mint'][:16]} - {result['error']}")
 
 
 
330
  _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
331
  else:
332
  print(f"INFO: Running with {args.num_workers} workers...")
@@ -349,11 +379,25 @@ def main():
349
  skipped_count += 1
350
  else:
351
  error_count += 1
 
 
 
 
 
352
  except Exception as e:
353
  error_count += 1
354
  tqdm.write(f"WORKER ERROR: {e}")
355
  _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
356
 
 
 
 
 
 
 
 
 
 
357
  print("INFO: Building metadata...")
358
  file_class_map = {}
359
  for f in sorted(output_dir.glob("sample_*.pt")):
 
282
  )
283
  print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
284
 
 
 
285
  success_count, skipped_count, error_count = 0, 0, 0
286
  class_distribution = {}
287
+
288
+ # --- Resume support: skip tokens that already have cached files ---
289
+ existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
290
+ if existing_files:
291
+ pre_resume = len(tasks)
292
+ filtered_tasks = []
293
+ already_cached = 0
294
+ for task in tasks:
295
+ mint_addr = task[1] # task = (idx, mint_addr, ...)
296
+ # Check if any file exists for this mint (context mode: sample_MINT_0.pt, raw mode: sample_MINT.pt)
297
+ mint_prefix = f"sample_{mint_addr[:16]}"
298
+ has_cached = any(ef.startswith(mint_prefix) for ef in existing_files)
299
+ if has_cached:
300
+ already_cached += 1
301
+ # Count existing files toward class distribution
302
+ cid = return_class_map.get(mint_addr)
303
+ if cid is not None:
304
+ class_distribution[cid] = class_distribution.get(cid, 0) + 1
305
+ success_count += 1
306
+ else:
307
+ filtered_tasks.append(task)
308
+ tasks = filtered_tasks
309
+ print(f"INFO: Resume: {already_cached} tokens already cached, {len(tasks)} remaining (was {pre_resume})")
310
+
311
+ print(f"INFO: Starting to cache {len(tasks)} tokens...")
312
  process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
313
 
314
  import time as _time
 
330
  f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
331
  )
332
 
333
+ # Error log file for diagnosing failures
334
+ error_log_path = Path(args.output_dir) / "cache_errors.log"
335
+ error_samples = [] # First 20 unique error messages
336
+
337
  if args.num_workers == 1:
338
  print("INFO: Single-threaded mode...")
339
  _init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
 
353
  skipped_count += 1
354
  else:
355
  error_count += 1
356
+ err_msg = result.get('error', 'unknown')
357
+ tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
358
+ if len(error_samples) < 20:
359
+ error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
360
  _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
361
  else:
362
  print(f"INFO: Running with {args.num_workers} workers...")
 
379
  skipped_count += 1
380
  else:
381
  error_count += 1
382
+ err_msg = result.get('error', 'unknown')
383
+ if len(error_samples) < 20:
384
+ error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
385
+ if error_count <= 5:
386
+ tqdm.write(f"ERROR: {result.get('mint', '?')[:16]} - {err_msg}")
387
  except Exception as e:
388
  error_count += 1
389
  tqdm.write(f"WORKER ERROR: {e}")
390
  _log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
391
 
392
+ # Write error log
393
+ if error_samples:
394
+ with open(error_log_path, 'w') as ef:
395
+ for i, es in enumerate(error_samples):
396
+ ef.write(f"=== Error {i+1} === Token: {es['mint']}\n")
397
+ ef.write(f"Error: {es['error']}\n")
398
+ ef.write(f"Traceback:\n{es['traceback']}\n\n")
399
+ print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
400
+
401
  print("INFO: Building metadata...")
402
  file_class_map = {}
403
  for f in sorted(output_dir.glob("sample_*.pt")):