from __future__ import annotations

import asyncio
import contextlib
from collections import deque
from typing import Any
from urllib.parse import urlsplit

import aiohttp

from .config import CrawlerConfig
from .fetch import fetch_url
from .models import CrawlStats, FetchResult
from .parse import parse_page
from .rate_limit import RequestRateLimiter
from .robots import RobotsPolicy
from .shards import ParquetShardWriter, ShardLimitReached
from .utils import has_binary_extension, normalize_url
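

# Pipeline: seed URLs -> fetch_queue -> fetcher workers -> parse_queue ->
# parser workers -> record_queue -> ParquetShardWriter. A shared stop_event
# plus None sentinels on each queue coordinate shutdown, and a bounded
# seen-URL cache deduplicates the frontier.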
class AsyncCrawler:
    def __init__(self, config: CrawlerConfig):
        self.config = config
        self.stats = CrawlStats()
        self.stop_event = asyncio.Event()
        self.stop_reason = ""
        self.fetch_queue: asyncio.Queue[str | None] = asyncio.Queue(
            maxsize=config.fetch_queue_size
        )
        self.parse_queue: asyncio.Queue[FetchResult | None] = asyncio.Queue(
            maxsize=config.parse_queue_size
        )
        self.record_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue(
            maxsize=config.record_queue_size
        )
        self.seen_urls: set[str] = set()
        self.seen_order: deque[str] = deque()
        self.seen_lock = asyncio.Lock()
        self.counter_lock = asyncio.Lock()
        self.active_fetchers = 0
        self.active_parsers = 0
        self.writer = ParquetShardWriter(config=config, stats=self.stats)
        self.rate_limiter: RequestRateLimiter | None = None
        self.robots_policy: RobotsPolicy | None = None
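
    # run() owns the crawl lifecycle: enqueue seeds, open one shared aiohttp
    # session, start fetcher/parser/writer/reporter tasks, wait for completion
    # or a stop signal, then shut the pipeline down stage by stage.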
    async def run(self) -> None:
        await self.writer.initialize()
        for seed in self.config.seed_urls:
            await self.try_enqueue(seed)
        connector = aiohttp.TCPConnector(
            limit=max(200, self.config.fetch_workers * 4),
            ttl_dns_cache=300,
        )
        timeout = aiohttp.ClientTimeout(total=self.config.request_timeout_seconds)
        async with aiohttp.ClientSession(
            connector=connector,
            timeout=timeout,
            headers={"User-Agent": self.config.user_agent},
        ) as session:
            self.rate_limiter = RequestRateLimiter(
                global_interval_seconds=self.config.request_delay_global_seconds,
                per_domain_interval_seconds=self.config.request_delay_per_domain_seconds,
            )
            self.robots_policy = RobotsPolicy(
                session=session,
                user_agent=self.config.user_agent,
                cache_ttl_seconds=self.config.robots_cache_ttl_seconds,
                fail_closed=self.config.robots_fail_closed,
                max_bytes=self.config.robots_max_bytes,
            )
            fetchers = [
                asyncio.create_task(self.fetcher_worker(worker_id=i, session=session))
                for i in range(self.config.fetch_workers)
            ]
            parsers = [
                asyncio.create_task(self.parser_worker(worker_id=i))
                for i in range(self.config.parser_workers)
            ]
            writer_task = asyncio.create_task(self.writer.consume(self.record_queue))
            reporter_task = asyncio.create_task(self.progress_reporter())
            try:
                await self.wait_until_complete(writer_task)
                await self._graceful_shutdown(fetchers, parsers, writer_task)
            except ShardLimitReached:
                self.stop_reason = "shard_cap_reached"
                self.stop_event.set()
                await self._hard_shutdown(fetchers, parsers, writer_task)
            finally:
                reporter_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await reporter_task

    def request_stop(self, reason: str = "user_requested_stop") -> None:
        if not self.stop_reason:
            self.stop_reason = reason
        self.stop_event.set()

    async def wait_until_complete(self, writer_task: asyncio.Task[None]) -> None:
        while True:
            if writer_task.done():
                exc = writer_task.exception()
                if exc is not None:
                    raise exc
                return
            if self.stop_event.is_set():
                if not self.stop_reason:
                    self.stop_reason = "stop_event_set"
                if self._is_pipeline_idle():
                    return
                await asyncio.sleep(0.2)
                continue
            if self._is_pipeline_idle():
                self.stop_reason = "frontier_exhausted"
                return
            await asyncio.sleep(0.5)
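
    # Graceful shutdown drains the pipeline in stage order: one None sentinel
    # per fetcher, then one per parser, then a single sentinel for the writer,
    # so queued work is fully processed before the tasks exit.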
    async def _graceful_shutdown(
        self,
        fetchers: list[asyncio.Task[None]],
        parsers: list[asyncio.Task[None]],
        writer_task: asyncio.Task[None],
    ) -> None:
        for _ in fetchers:
            await self.fetch_queue.put(None)
        await asyncio.gather(*fetchers, return_exceptions=True)
        for _ in parsers:
            await self.parse_queue.put(None)
        await asyncio.gather(*parsers, return_exceptions=True)
        await self.record_queue.put(None)
        await writer_task

    async def _hard_shutdown(
        self,
        fetchers: list[asyncio.Task[None]],
        parsers: list[asyncio.Task[None]],
        writer_task: asyncio.Task[None],
    ) -> None:
        for task in fetchers + parsers:
            task.cancel()
        await asyncio.gather(*fetchers, *parsers, return_exceptions=True)
        if not writer_task.done():
            writer_task.cancel()
        await asyncio.gather(writer_task, return_exceptions=True)

    async def progress_reporter(self) -> None:
        while True:
            await asyncio.sleep(self.config.report_every_seconds)
            print(
                "[stats]"
                f" workers={self.config.total_workers}"
                f" split={self.config.fetch_workers}/{self.config.parser_workers}"
                f" queued={self.stats.queued_urls}"
                f" fetched={self.stats.fetch_reserved}"
                f" fetch_ok={self.stats.fetch_succeeded}"
                f" fetch_fail={self.stats.fetch_failed}"
                f" parsed={self.stats.parsed_pages}"
                f" parse_fail={self.stats.parse_failed}"
                f" robots_blocked={self.stats.robots_blocked}"
                f" rows={self.stats.stored_rows}"
                f" shards={self.stats.written_shards}/{self.config.max_shards}"
                f" tok_shards={self.stats.tokenized_shards}"
                f" tok_rows={self.stats.tokenized_rows}"
                f" tok_total={self.stats.tokenized_tokens}"
                f" uploaded={self.stats.uploaded_shards}"
                f" fetch_q={self.fetch_queue.qsize()}"
                f" parse_q={self.parse_queue.qsize()}"
                f" record_q={self.record_queue.qsize()}"
            )
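
    # Each fetcher pulls URLs until it receives a None sentinel. It reserves a
    # fetch slot (refused once stop_event is set), adds an extra per-worker
    # delay when hitting the same domain twice in a row, and routes outcomes:
    # HTML results go to parse_queue, failures and robots blocks update stats.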
    async def fetcher_worker(self, worker_id: int, session: aiohttp.ClientSession) -> None:
        del worker_id
        assert self.rate_limiter is not None
        assert self.robots_policy is not None
        loop = asyncio.get_running_loop()
        last_domain = ""
        last_request_started = 0.0
        while True:
            url = await self.fetch_queue.get()
            if url is None:
                self.fetch_queue.task_done()
                return
            slot_reserved = await self.reserve_fetch_slot()
            if not slot_reserved:
                self.fetch_queue.task_done()
                continue
            self.active_fetchers += 1
            try:
                requested_domain = (urlsplit(url).hostname or "").lower().strip(".")
                if requested_domain and requested_domain == last_domain:
                    elapsed = loop.time() - last_request_started
                    wait = self.config.same_site_delay_per_worker_seconds - elapsed
                    if wait > 0:
                        await asyncio.sleep(wait)
                if requested_domain:
                    last_domain = requested_domain
                last_request_started = loop.time()
                outcome = await fetch_url(
                    session,
                    url,
                    config=self.config,
                    mark_seen=self._mark_seen,
                    rate_limiter=self.rate_limiter,
                    robots_policy=self.robots_policy,
                )
                if outcome.robots_blocked:
                    self.stats.robots_blocked += 1
                if outcome.result is not None:
                    self.stats.fetch_succeeded += 1
                    if outcome.result.html:
                        await self.parse_queue.put(outcome.result)
                else:
                    self.stats.fetch_failed += 1
            finally:
                self.active_fetchers -= 1
                self.fetch_queue.task_done()
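
    # Parsers turn fetched pages into writer records and feed newly discovered
    # links back into the frontier, capped at max_links_per_page per page.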
    async def parser_worker(self, worker_id: int) -> None:
        del worker_id
        while True:
            item = await self.parse_queue.get()
            if item is None:
                self.parse_queue.task_done()
                return
            self.active_parsers += 1
            try:
                record, links = parse_page(item)
                if record is not None:
                    await self.record_queue.put(record)
                self.stats.parsed_pages += 1
                extracted = 0
                for link in links:
                    if extracted >= self.config.max_links_per_page:
                        break
                    if await self.try_enqueue(link):
                        extracted += 1
                self.stats.extracted_links += extracted
            except Exception:
                self.stats.parse_failed += 1
            finally:
                self.active_parsers -= 1
                self.parse_queue.task_done()

    async def reserve_fetch_slot(self) -> bool:
        async with self.counter_lock:
            if self.stop_event.is_set():
                return False
            self.stats.fetch_reserved += 1
            return True
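
    # try_enqueue is the frontier gate: normalize the URL, drop obvious binary
    # assets, dedupe against the bounded FIFO cache of seen URLs, then queue it.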
    async def try_enqueue(self, raw_url: str) -> bool:
        if self.stop_event.is_set():
            return False
        normalized = normalize_url(raw_url)
        if not normalized:
            self.stats.dropped_urls += 1
            return False
        if has_binary_extension(normalized):
            self.stats.dropped_urls += 1
            return False
        async with self.seen_lock:
            if self.config.seen_url_cache_size > 0 and normalized in self.seen_urls:
                return False
            self._remember_seen_locked(normalized)
        self.stats.queued_urls += 1
        await self.fetch_queue.put(normalized)
        return True

    async def _mark_seen(self, url: str) -> None:
        async with self.seen_lock:
            self._remember_seen_locked(url)

    def _remember_seen_locked(self, url: str) -> None:
        if self.config.seen_url_cache_size <= 0:
            return
        if url in self.seen_urls:
            return
        self.seen_urls.add(url)
        self.seen_order.append(url)
        while len(self.seen_order) > self.config.seen_url_cache_size:
            expired = self.seen_order.popleft()
            self.seen_urls.discard(expired)

    def _is_pipeline_idle(self) -> bool:
        return (
            self.fetch_queue.empty()
            and self.parse_queue.empty()
            and self.active_fetchers == 0
            and self.active_parsers == 0
        )
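

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal way one might drive AsyncCrawler. It assumes CrawlerConfig can be
# constructed with seed_urls as a keyword argument; its real signature lives
# in .config and is not shown in this module, so treat this as a sketch rather
# than the supported entry point.
#
#     import asyncio
#     from .config import CrawlerConfig
#
#     async def main() -> None:
#         config = CrawlerConfig(seed_urls=["https://example.com/"])
#         crawler = AsyncCrawler(config)
#         await crawler.run()
#         print("stopped:", crawler.stop_reason, crawler.stats)
#
#     asyncio.run(main())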