# AutoWS/crawler/engine.py
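"""Asynchronous crawl engine: a fetch -> parse -> shard-write pipeline on asyncio."""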
from __future__ import annotations

import asyncio
import contextlib
from collections import deque
from typing import Any
from urllib.parse import urlsplit

import aiohttp

from .config import CrawlerConfig
from .fetch import fetch_url
from .models import CrawlStats, FetchResult
from .parse import parse_page
from .rate_limit import RequestRateLimiter
from .robots import RobotsPolicy
from .shards import ParquetShardWriter, ShardLimitReached
from .utils import has_binary_extension, normalize_url


class AsyncCrawler:
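    """Coordinates fetcher, parser, and shard-writer tasks over bounded queues."""
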
def __init__(self, config: CrawlerConfig):
self.config = config
self.stats = CrawlStats()
self.stop_event = asyncio.Event()
self.stop_reason = ""
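        # Bounded hand-off queues between pipeline stages; None is the shutdown sentinel.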
self.fetch_queue: asyncio.Queue[str | None] = asyncio.Queue(
maxsize=config.fetch_queue_size
)
self.parse_queue: asyncio.Queue[FetchResult | None] = asyncio.Queue(
maxsize=config.parse_queue_size
)
self.record_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue(
maxsize=config.record_queue_size
)
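        # Seen-URL cache: a set for O(1) membership plus a deque for FIFO eviction.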
self.seen_urls: set[str] = set()
self.seen_order: deque[str] = deque()
self.seen_lock = asyncio.Lock()
self.counter_lock = asyncio.Lock()
self.active_fetchers = 0
self.active_parsers = 0
self.writer = ParquetShardWriter(config=config, stats=self.stats)
self.rate_limiter: RequestRateLimiter | None = None
self.robots_policy: RobotsPolicy | None = None

    async def run(self) -> None:
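        """Seed the frontier, start all workers, and run the crawl to completion."""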
await self.writer.initialize()
for seed in self.config.seed_urls:
await self.try_enqueue(seed)
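        # One shared connection pool, sized so fetch workers are not connection-bound.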
connector = aiohttp.TCPConnector(
limit=max(200, self.config.fetch_workers * 4),
ttl_dns_cache=300,
)
timeout = aiohttp.ClientTimeout(total=self.config.request_timeout_seconds)
async with aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={"User-Agent": self.config.user_agent},
) as session:
self.rate_limiter = RequestRateLimiter(
global_interval_seconds=self.config.request_delay_global_seconds,
per_domain_interval_seconds=self.config.request_delay_per_domain_seconds,
)
self.robots_policy = RobotsPolicy(
session=session,
user_agent=self.config.user_agent,
cache_ttl_seconds=self.config.robots_cache_ttl_seconds,
fail_closed=self.config.robots_fail_closed,
max_bytes=self.config.robots_max_bytes,
)
fetchers = [
asyncio.create_task(self.fetcher_worker(worker_id=i, session=session))
for i in range(self.config.fetch_workers)
]
parsers = [
asyncio.create_task(self.parser_worker(worker_id=i))
for i in range(self.config.parser_workers)
]
writer_task = asyncio.create_task(self.writer.consume(self.record_queue))
reporter_task = asyncio.create_task(self.progress_reporter())
try:
await self.wait_until_complete(writer_task)
await self._graceful_shutdown(fetchers, parsers, writer_task)
except ShardLimitReached:
self.stop_reason = "shard_cap_reached"
self.stop_event.set()
await self._hard_shutdown(fetchers, parsers, writer_task)
finally:
reporter_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await reporter_task

    def request_stop(self, reason: str = "user_requested_stop") -> None:
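        """Request a cooperative stop; the first recorded reason wins."""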
if not self.stop_reason:
self.stop_reason = reason
self.stop_event.set()

    async def wait_until_complete(self, writer_task: asyncio.Task[None]) -> None:
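        """Poll until the writer finishes, a stop is honored, or the frontier drains."""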
while True:
if writer_task.done():
exc = writer_task.exception()
if exc is not None:
raise exc
return
if self.stop_event.is_set():
if not self.stop_reason:
self.stop_reason = "stop_event_set"
if self._is_pipeline_idle():
return
await asyncio.sleep(0.2)
continue
if self._is_pipeline_idle():
self.stop_reason = "frontier_exhausted"
return
await asyncio.sleep(0.5)

    async def _graceful_shutdown(
self,
fetchers: list[asyncio.Task[None]],
parsers: list[asyncio.Task[None]],
writer_task: asyncio.Task[None],
) -> None:
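        """Drain the pipeline stage by stage using one None sentinel per worker."""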
for _ in fetchers:
await self.fetch_queue.put(None)
await asyncio.gather(*fetchers, return_exceptions=True)
for _ in parsers:
await self.parse_queue.put(None)
await asyncio.gather(*parsers, return_exceptions=True)
await self.record_queue.put(None)
await writer_task

    async def _hard_shutdown(
self,
fetchers: list[asyncio.Task[None]],
parsers: list[asyncio.Task[None]],
writer_task: asyncio.Task[None],
) -> None:
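        """Cancel all workers at once, e.g. when the shard cap aborts the run."""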
for task in fetchers + parsers:
task.cancel()
await asyncio.gather(*fetchers, *parsers, return_exceptions=True)
if not writer_task.done():
writer_task.cancel()
await asyncio.gather(writer_task, return_exceptions=True)

    async def progress_reporter(self) -> None:
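        """Periodically print pipeline counters and queue depths."""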
while True:
await asyncio.sleep(self.config.report_every_seconds)
print(
"[stats]"
f" workers={self.config.total_workers}"
f" split={self.config.fetch_workers}/{self.config.parser_workers}"
f" queued={self.stats.queued_urls}"
f" fetched={self.stats.fetch_reserved}"
f" fetch_ok={self.stats.fetch_succeeded}"
f" fetch_fail={self.stats.fetch_failed}"
f" parsed={self.stats.parsed_pages}"
f" parse_fail={self.stats.parse_failed}"
f" robots_blocked={self.stats.robots_blocked}"
f" rows={self.stats.stored_rows}"
f" shards={self.stats.written_shards}/{self.config.max_shards}"
f" tok_shards={self.stats.tokenized_shards}"
f" tok_rows={self.stats.tokenized_rows}"
f" tok_total={self.stats.tokenized_tokens}"
f" uploaded={self.stats.uploaded_shards}"
f" fetch_q={self.fetch_queue.qsize()}"
f" parse_q={self.parse_queue.qsize()}"
f" record_q={self.record_queue.qsize()}"
)

    async def fetcher_worker(self, worker_id: int, session: aiohttp.ClientSession) -> None:
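        """Consume URLs from the fetch queue and hand fetched HTML to the parsers."""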
del worker_id
assert self.rate_limiter is not None
assert self.robots_policy is not None
loop = asyncio.get_running_loop()
last_domain = ""
last_request_started = 0.0
while True:
url = await self.fetch_queue.get()
if url is None:
self.fetch_queue.task_done()
return
slot_reserved = await self.reserve_fetch_slot()
if not slot_reserved:
self.fetch_queue.task_done()
continue
self.active_fetchers += 1
try:
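                # Per-worker politeness: space out back-to-back requests to one host.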
requested_domain = (urlsplit(url).hostname or "").lower().strip(".")
if requested_domain and requested_domain == last_domain:
elapsed = loop.time() - last_request_started
wait = self.config.same_site_delay_per_worker_seconds - elapsed
if wait > 0:
await asyncio.sleep(wait)
if requested_domain:
last_domain = requested_domain
last_request_started = loop.time()
outcome = await fetch_url(
session,
url,
config=self.config,
mark_seen=self._mark_seen,
rate_limiter=self.rate_limiter,
robots_policy=self.robots_policy,
)
if outcome.robots_blocked:
self.stats.robots_blocked += 1
if outcome.result is not None:
self.stats.fetch_succeeded += 1
if outcome.result.html:
await self.parse_queue.put(outcome.result)
else:
self.stats.fetch_failed += 1
finally:
self.active_fetchers -= 1
self.fetch_queue.task_done()

    async def parser_worker(self, worker_id: int) -> None:
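        """Parse fetched pages, forward records to the writer, and enqueue new links."""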
del worker_id
while True:
item = await self.parse_queue.get()
if item is None:
self.parse_queue.task_done()
return
self.active_parsers += 1
try:
record, links = parse_page(item)
if record is not None:
await self.record_queue.put(record)
self.stats.parsed_pages += 1
extracted = 0
for link in links:
if extracted >= self.config.max_links_per_page:
break
if await self.try_enqueue(link):
extracted += 1
self.stats.extracted_links += extracted
except Exception:
self.stats.parse_failed += 1
finally:
self.active_parsers -= 1
self.parse_queue.task_done()

    async def reserve_fetch_slot(self) -> bool:
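        """Atomically claim a fetch slot; refuse once a stop has been requested."""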
async with self.counter_lock:
if self.stop_event.is_set():
return False
self.stats.fetch_reserved += 1
return True

    async def try_enqueue(self, raw_url: str) -> bool:
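        """Normalize, filter, and dedupe a URL before adding it to the fetch queue."""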
if self.stop_event.is_set():
return False
normalized = normalize_url(raw_url)
if not normalized:
self.stats.dropped_urls += 1
return False
if has_binary_extension(normalized):
self.stats.dropped_urls += 1
return False
        async with self.seen_lock:
            if self.config.seen_url_cache_size > 0 and normalized in self.seen_urls:
                return False
            self._remember_seen_locked(normalized)
        # Enqueue outside the lock: a full fetch_queue must not stall seen_lock
        # holders, or fetchers blocked in _mark_seen could deadlock the pipeline.
        self.stats.queued_urls += 1
        await self.fetch_queue.put(normalized)
        return True

    async def _mark_seen(self, url: str) -> None:
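        """Record a URL as seen; passed to fetch_url as its mark_seen callback."""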
async with self.seen_lock:
self._remember_seen_locked(url)

    def _remember_seen_locked(self, url: str) -> None:
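        """Add a URL to the seen cache, evicting the oldest entries beyond the cap.

        Caller must hold seen_lock.
        """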
if self.config.seen_url_cache_size <= 0:
return
if url in self.seen_urls:
return
self.seen_urls.add(url)
self.seen_order.append(url)
while len(self.seen_order) > self.config.seen_url_cache_size:
expired = self.seen_order.popleft()
self.seen_urls.discard(expired)

    def _is_pipeline_idle(self) -> bool:
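        """True when both work queues are empty and no worker is mid-task."""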
return (
self.fetch_queue.empty()
and self.parse_queue.empty()
and self.active_fetchers == 0
and self.active_parsers == 0
)
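

if __name__ == "__main__":
    # Minimal usage sketch, not part of the engine proper. Assumption: CrawlerConfig
    # can be constructed from seed URLs alone, with defaults supplying the remaining
    # fields (queue sizes, worker counts, timeouts); adjust to the real config surface.
    async def _demo() -> None:
        config = CrawlerConfig(seed_urls=["https://example.com/"])
        crawler = AsyncCrawler(config)
        await crawler.run()
        print(f"crawl stopped: {crawler.stop_reason}")

    asyncio.run(_demo())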