from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

# The crawler runs in exactly two sizes: a standard pool or a double-size pool.
NORMAL_TOTAL_WORKERS = 12
SUPER_TOTAL_WORKERS = 24

# Hard ceilings on shard output: at most MAX_SHARDS files of MAX_SHARD_ROWS rows.
MAX_SHARD_ROWS = 15_000
MAX_SHARDS = 15


def validate_total_workers(total_workers: int) -> int:
    """Coerce *total_workers* to ``int`` and require one of the two pool sizes.

    Returns the coerced value; raises ``ValueError`` for any other count.
    """
    count = int(total_workers)
    if count in (NORMAL_TOTAL_WORKERS, SUPER_TOTAL_WORKERS):
        return count
    raise ValueError(
        f"total_workers must be {NORMAL_TOTAL_WORKERS} or {SUPER_TOTAL_WORKERS}, got {count}."
    )


def compute_worker_split(total_workers: int) -> tuple[int, int]:
    """Split a validated worker total 5:1 between fetchers and parsers.

    Returns ``(fetch_workers, parser_workers)``; raises ``ValueError`` if
    either side of the split would come out empty.
    """
    total = validate_total_workers(total_workers)
    n_fetch = total * 5 // 6  # five sixths fetch; the remainder parses
    n_parse = total - n_fetch
    if min(n_fetch, n_parse) < 1:
        raise ValueError(f"Invalid worker split for total_workers={total}.")
    return n_fetch, n_parse


@dataclass
class CrawlerConfig:
    """Runtime configuration for the crawler; validated in ``__post_init__``."""

    # Crawl scope / politeness.
    seed_urls: list[str]
    max_links_per_page: int = 250
    request_timeout_seconds: float = 18.0
    max_response_bytes: int = 3_000_000
    user_agent: str = "HFDBContCrawler/1.0 (+https://huggingface.co/datasets)"

    # In-memory queue and cache sizing.
    seen_url_cache_size: int = 2_000_000
    fetch_queue_size: int = 100_000
    parse_queue_size: int = 25_000
    record_queue_size: int = 50_000
    report_every_seconds: float = 5.0

    # Shard output: defaults to a "shards" dir next to this package.
    output_dir: Path = field(
        default_factory=lambda: Path(__file__).resolve().parents[1] / "shards"
    )
    shard_size_rows: int = 10_000
    max_shards: int = MAX_SHARDS
    parquet_compression: str = "zstd"
    parquet_compression_level: int = 9

    # Optional Hugging Face Hub upload.
    enable_hf_upload: bool = False
    upload_incomplete_shards: bool = False
    hf_repo_id: str = ""
    hf_token: str = ""
    hf_repo_type: str = "dataset"
    hf_private_repo: bool = False
    hf_path_prefix: str = "crawl_shards"

    # Worker pool and rate limiting.
    total_workers: int = NORMAL_TOTAL_WORKERS
    request_delay_global_seconds: float = 0.02
    request_delay_per_domain_seconds: float = 0.0
    same_site_delay_per_worker_seconds: float = 0.5

    # robots.txt handling.
    robots_cache_ttl_seconds: float = 3600.0
    robots_fail_closed: bool = True
    robots_max_bytes: int = 300_000

    # Derived from total_workers; never passed by callers.
    fetch_workers: int = field(init=False)
    parser_workers: int = field(init=False)

    def __post_init__(self) -> None:
        # Normalise seeds: strip whitespace and drop blank entries.
        self.seed_urls = [url.strip() for url in self.seed_urls if url and url.strip()]
        if not self.seed_urls:
            raise ValueError("At least one seed URL is required.")

        # Pin the pool size, then derive the fetch/parse split from it.
        self.total_workers = validate_total_workers(self.total_workers)
        self.fetch_workers, self.parser_workers = compute_worker_split(
            self.total_workers
        )

        # Shard sizing must fall within the module-level hard ceilings.
        self.shard_size_rows = int(self.shard_size_rows)
        if not 1 <= self.shard_size_rows <= MAX_SHARD_ROWS:
            raise ValueError(f"shard_size_rows must be between 1 and {MAX_SHARD_ROWS}.")
        self.max_shards = int(self.max_shards)
        if not 1 <= self.max_shards <= MAX_SHARDS:
            raise ValueError(f"max_shards must be between 1 and {MAX_SHARDS}.")

        self.output_dir = Path(self.output_dir).expanduser()

        # Tidy HF settings; an empty prefix falls back to the default.
        self.hf_repo_id = self.hf_repo_id.strip()
        self.hf_token = self.hf_token.strip()
        self.hf_path_prefix = self.hf_path_prefix.strip() or "crawl_shards"

        # Credentials are only mandatory when uploading is switched on.
        if self.enable_hf_upload:
            if not self.hf_repo_id:
                raise ValueError("hf_repo_id is required when enable_hf_upload=True.")
            if not self.hf_token:
                raise ValueError("hf_token is required when enable_hf_upload=True.")