Upload scripts/cache_dataset.py with huggingface_hub
Browse files- scripts/cache_dataset.py +47 -3
scripts/cache_dataset.py
CHANGED
|
@@ -282,10 +282,33 @@ def main():
|
|
| 282 |
)
|
| 283 |
print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
|
| 284 |
|
| 285 |
-
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 286 |
-
|
| 287 |
success_count, skipped_count, error_count = 0, 0, 0
|
| 288 |
class_distribution = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
|
| 290 |
|
| 291 |
import time as _time
|
|
@@ -307,6 +330,10 @@ def main():
|
|
| 307 |
f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
|
| 308 |
)
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
if args.num_workers == 1:
|
| 311 |
print("INFO: Single-threaded mode...")
|
| 312 |
_init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
|
|
@@ -326,7 +353,10 @@ def main():
|
|
| 326 |
skipped_count += 1
|
| 327 |
else:
|
| 328 |
error_count += 1
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
| 330 |
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 331 |
else:
|
| 332 |
print(f"INFO: Running with {args.num_workers} workers...")
|
|
@@ -349,11 +379,25 @@ def main():
|
|
| 349 |
skipped_count += 1
|
| 350 |
else:
|
| 351 |
error_count += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
except Exception as e:
|
| 353 |
error_count += 1
|
| 354 |
tqdm.write(f"WORKER ERROR: {e}")
|
| 355 |
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
print("INFO: Building metadata...")
|
| 358 |
file_class_map = {}
|
| 359 |
for f in sorted(output_dir.glob("sample_*.pt")):
|
|
|
|
| 282 |
)
|
| 283 |
print(f"INFO: Total tasks: {len(tasks)} (expected ~{expected_files} output files, target ~{target_total})")
|
| 284 |
|
|
|
|
|
|
|
| 285 |
success_count, skipped_count, error_count = 0, 0, 0
|
| 286 |
class_distribution = {}
|
| 287 |
+
|
| 288 |
+
# --- Resume support: skip tokens that already have cached files ---
|
| 289 |
+
existing_files = set(f.name for f in output_dir.glob("sample_*.pt"))
|
| 290 |
+
if existing_files:
|
| 291 |
+
pre_resume = len(tasks)
|
| 292 |
+
filtered_tasks = []
|
| 293 |
+
already_cached = 0
|
| 294 |
+
for task in tasks:
|
| 295 |
+
mint_addr = task[1] # task = (idx, mint_addr, ...)
|
| 296 |
+
# Check if any file exists for this mint (context mode: sample_MINT_0.pt, raw mode: sample_MINT.pt)
|
| 297 |
+
mint_prefix = f"sample_{mint_addr[:16]}"
|
| 298 |
+
has_cached = any(ef.startswith(mint_prefix) for ef in existing_files)
|
| 299 |
+
if has_cached:
|
| 300 |
+
already_cached += 1
|
| 301 |
+
# Count existing files toward class distribution
|
| 302 |
+
cid = return_class_map.get(mint_addr)
|
| 303 |
+
if cid is not None:
|
| 304 |
+
class_distribution[cid] = class_distribution.get(cid, 0) + 1
|
| 305 |
+
success_count += 1
|
| 306 |
+
else:
|
| 307 |
+
filtered_tasks.append(task)
|
| 308 |
+
tasks = filtered_tasks
|
| 309 |
+
print(f"INFO: Resume: {already_cached} tokens already cached, {len(tasks)} remaining (was {pre_resume})")
|
| 310 |
+
|
| 311 |
+
print(f"INFO: Starting to cache {len(tasks)} tokens...")
|
| 312 |
process_fn = _process_single_token_context if args.cache_mode == "context" else _process_single_token_raw
|
| 313 |
|
| 314 |
import time as _time
|
|
|
|
| 330 |
f"OK: {success_count} Skip: {skipped_count} Err: {error_count}"
|
| 331 |
)
|
| 332 |
|
| 333 |
+
# Error log file for diagnosing failures
|
| 334 |
+
error_log_path = Path(args.output_dir) / "cache_errors.log"
|
| 335 |
+
error_samples = [] # First 20 unique error messages
|
| 336 |
+
|
| 337 |
if args.num_workers == 1:
|
| 338 |
print("INFO: Single-threaded mode...")
|
| 339 |
_init_worker(db_config, dataset_config, return_class_map, quality_scores_map)
|
|
|
|
| 353 |
skipped_count += 1
|
| 354 |
else:
|
| 355 |
error_count += 1
|
| 356 |
+
err_msg = result.get('error', 'unknown')
|
| 357 |
+
tqdm.write(f"ERROR: {result['mint'][:16]} - {err_msg}")
|
| 358 |
+
if len(error_samples) < 20:
|
| 359 |
+
error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
|
| 360 |
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 361 |
else:
|
| 362 |
print(f"INFO: Running with {args.num_workers} workers...")
|
|
|
|
| 379 |
skipped_count += 1
|
| 380 |
else:
|
| 381 |
error_count += 1
|
| 382 |
+
err_msg = result.get('error', 'unknown')
|
| 383 |
+
if len(error_samples) < 20:
|
| 384 |
+
error_samples.append({'mint': result.get('mint'), 'error': err_msg, 'traceback': result.get('traceback', '')})
|
| 385 |
+
if error_count <= 5:
|
| 386 |
+
tqdm.write(f"ERROR: {result.get('mint', '?')[:16]} - {err_msg}")
|
| 387 |
except Exception as e:
|
| 388 |
error_count += 1
|
| 389 |
tqdm.write(f"WORKER ERROR: {e}")
|
| 390 |
_log_progress(task_num, len(tasks), start_time, recent_times, success_count, skipped_count, error_count)
|
| 391 |
|
| 392 |
+
# Write error log
|
| 393 |
+
if error_samples:
|
| 394 |
+
with open(error_log_path, 'w') as ef:
|
| 395 |
+
for i, es in enumerate(error_samples):
|
| 396 |
+
ef.write(f"=== Error {i+1} === Token: {es['mint']}\n")
|
| 397 |
+
ef.write(f"Error: {es['error']}\n")
|
| 398 |
+
ef.write(f"Traceback:\n{es['traceback']}\n\n")
|
| 399 |
+
print(f"INFO: First {len(error_samples)} error tracebacks saved to {error_log_path}")
|
| 400 |
+
|
| 401 |
print("INFO: Building metadata...")
|
| 402 |
file_class_map = {}
|
| 403 |
for f in sorted(output_dir.glob("sample_*.pt")):
|