add retry loop around load_dataset for transient HF Hub 5xx
Browse files- distill.py +16 -1
distill.py
CHANGED
|
@@ -293,7 +293,22 @@ class StreamingTextLoader:
|
|
| 293 |
from datasets import load_dataset
|
| 294 |
from datasets.distributed import split_dataset_by_node
|
| 295 |
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer)
|
| 298 |
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
|
| 299 |
self._ds = iter(ds)
|
|
|
|
| 293 |
from datasets import load_dataset
|
| 294 |
from datasets.distributed import split_dataset_by_node
|
| 295 |
|
| 296 |
+
# HF Hub occasionally returns 5xx during dataset metadata crawl. Retry.
|
| 297 |
+
last_err = None
|
| 298 |
+
for attempt in range(8):
|
| 299 |
+
try:
|
| 300 |
+
ds = load_dataset(name, split="train", streaming=True)
|
| 301 |
+
break
|
| 302 |
+
except Exception as e:
|
| 303 |
+
last_err = e
|
| 304 |
+
wait = min(2 ** attempt, 30)
|
| 305 |
+
log.warning(
|
| 306 |
+
f"load_dataset({name!r}) failed (attempt {attempt + 1}/8): "
|
| 307 |
+
f"{type(e).__name__}: {e}; sleeping {wait}s"
|
| 308 |
+
)
|
| 309 |
+
time.sleep(wait)
|
| 310 |
+
else:
|
| 311 |
+
raise RuntimeError(f"load_dataset failed after 8 retries") from last_err
|
| 312 |
ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer)
|
| 313 |
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
|
| 314 |
self._ds = iter(ds)
|