# AutoWS/crawler/config.py
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path

# Worker pool presets: the crawler supports exactly two total-worker counts.
NORMAL_TOTAL_WORKERS = 12
SUPER_TOTAL_WORKERS = 24

# Hard caps on shard size and shard count.
MAX_SHARD_ROWS = 15_000
MAX_SHARDS = 15


def validate_total_workers(total_workers: int) -> int:
    """Coerce to int and require one of the two supported worker counts."""
    value = int(total_workers)
    if value not in {NORMAL_TOTAL_WORKERS, SUPER_TOTAL_WORKERS}:
        raise ValueError(
            f"total_workers must be {NORMAL_TOTAL_WORKERS} or {SUPER_TOTAL_WORKERS}, got {value}."
        )
    return value
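
# Illustrative behavior, following only from the constants above:
# validate_total_workers(12) -> 12 and validate_total_workers(24) -> 24;
# any other count (e.g. 16) raises ValueError.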


def compute_worker_split(total_workers: int) -> tuple[int, int]:
    """Split a validated worker total into fetch and parser workers (~5:1)."""
    total = validate_total_workers(total_workers)
    fetch_workers = (total * 5) // 6
    parser_workers = total - fetch_workers
    if fetch_workers < 1 or parser_workers < 1:
        raise ValueError(f"Invalid worker split for total_workers={total}.")
    return fetch_workers, parser_workers
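
# Worked split, following directly from the arithmetic above:
# compute_worker_split(12) -> (10, 2); compute_worker_split(24) -> (20, 4).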


@dataclass
class CrawlerConfig:
    """Runtime configuration for the crawler, validated in __post_init__."""

    # Crawl scope and request behavior.
    seed_urls: list[str]
    max_links_per_page: int = 250
    request_timeout_seconds: float = 18.0
    max_response_bytes: int = 3_000_000
    user_agent: str = "HFDBContCrawler/1.0 (+https://huggingface.co/datasets)"

    # In-memory cache and queue sizing.
    seen_url_cache_size: int = 2_000_000
    fetch_queue_size: int = 100_000
    parse_queue_size: int = 25_000
    record_queue_size: int = 50_000
    report_every_seconds: float = 5.0

    # Shard output settings.
    output_dir: Path = field(
        default_factory=lambda: Path(__file__).resolve().parents[1] / "shards"
    )
    shard_size_rows: int = 10_000
    max_shards: int = MAX_SHARDS
    parquet_compression: str = "zstd"
    parquet_compression_level: int = 9

    # Optional Hugging Face Hub upload.
    enable_hf_upload: bool = False
    upload_incomplete_shards: bool = False
    hf_repo_id: str = ""
    hf_token: str = ""
    hf_repo_type: str = "dataset"
    hf_private_repo: bool = False
    hf_path_prefix: str = "crawl_shards"

    # Worker counts and rate limiting.
    total_workers: int = NORMAL_TOTAL_WORKERS
    request_delay_global_seconds: float = 0.02
    request_delay_per_domain_seconds: float = 0.0
    same_site_delay_per_worker_seconds: float = 0.5

    # robots.txt handling.
    robots_cache_ttl_seconds: float = 3600.0
    robots_fail_closed: bool = True
    robots_max_bytes: int = 300_000

    # Derived in __post_init__ from total_workers.
    fetch_workers: int = field(init=False)
    parser_workers: int = field(init=False)

    def __post_init__(self) -> None:
        """Normalize inputs and enforce this module's validation rules."""
        self.seed_urls = [u.strip() for u in self.seed_urls if u and u.strip()]
        if not self.seed_urls:
            raise ValueError("At least one seed URL is required.")
        self.total_workers = validate_total_workers(self.total_workers)
        self.fetch_workers, self.parser_workers = compute_worker_split(self.total_workers)
        self.shard_size_rows = int(self.shard_size_rows)
        if self.shard_size_rows < 1 or self.shard_size_rows > MAX_SHARD_ROWS:
            raise ValueError(f"shard_size_rows must be between 1 and {MAX_SHARD_ROWS}.")
        self.max_shards = int(self.max_shards)
        if self.max_shards < 1 or self.max_shards > MAX_SHARDS:
            raise ValueError(f"max_shards must be between 1 and {MAX_SHARDS}.")
        self.output_dir = Path(self.output_dir).expanduser()
        self.hf_repo_id = self.hf_repo_id.strip()
        self.hf_token = self.hf_token.strip()
        self.hf_path_prefix = self.hf_path_prefix.strip() or "crawl_shards"
        if self.enable_hf_upload:
            if not self.hf_repo_id:
                raise ValueError("hf_repo_id is required when enable_hf_upload=True.")
            if not self.hf_token:
                raise ValueError("hf_token is required when enable_hf_upload=True.")
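

# A minimal usage sketch. The seed URL below is a placeholder, not part of the
# crawler; everything else uses only names defined in this module.
if __name__ == "__main__":
    demo = CrawlerConfig(
        seed_urls=["https://example.com/"],  # placeholder seed
        total_workers=SUPER_TOTAL_WORKERS,
        shard_size_rows=MAX_SHARD_ROWS,
    )
    # 24 workers split ~5:1 -> 20 fetchers and 4 parsers.
    print(demo.fetch_workers, demo.parser_workers)
    print(demo.output_dir)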