"""Buffered parquet shard writing for the crawler pipeline.

Crawl records are accumulated in memory, flushed to compressed parquet
shards on disk, tokenized shard-by-shard for live token statistics, and
(optionally) handed to a Hugging Face uploader.
"""

from __future__ import annotations

import asyncio
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import pyarrow as pa
import pyarrow.parquet as pq

from .config import CrawlerConfig
from .models import CrawlStats
from .tokenizer import LiveShardTokenizer
from .upload import HfShardUploader


class ShardLimitReached(RuntimeError):
    """Signals that the configured shard cap (``config.max_shards``) is hit."""

    pass


# Fixed column layout for every shard; all four fields are serialized as
# strings (the timestamp column holds a pre-formatted string, not a native
# parquet timestamp).
PARQUET_SCHEMA = pa.schema(
    [
        ("text", pa.string()),
        ("url", pa.string()),
        ("domain", pa.string()),
        ("timestamp", pa.string()),
    ]
)


class ParquetShardWriter:
    """Consumes crawl records from a queue and persists them as parquet shards.

    Lifecycle: construct, ``await initialize()``, then ``await consume(queue)``
    until a ``None`` sentinel arrives on the queue. ``flush()`` raises
    :class:`ShardLimitReached` once ``config.max_shards`` shards have been
    written, which propagates out of ``consume()``.
    """

    def __init__(self, config: CrawlerConfig, stats: CrawlStats) -> None:
        self.config = config
        self.stats = stats
        # In-memory row buffer; flushed once it reaches config.shard_size_rows.
        self.buffer: list[dict[str, Any]] = []
        # Count of shards written so far; also embedded in shard file names.
        self.shard_index = 0
        # Created in initialize() only when HF upload is enabled; stays None otherwise.
        self.uploader: HfShardUploader | None = None
        self.live_tokenizer = LiveShardTokenizer()

    async def initialize(self) -> None:
        """Create the output directory and, if uploads are enabled, the HF uploader."""
        self.config.output_dir.mkdir(parents=True, exist_ok=True)
        if not self.config.enable_hf_upload:
            return
        self.uploader = HfShardUploader(
            repo_id=self.config.hf_repo_id,
            token=self.config.hf_token,
            repo_type=self.config.hf_repo_type,
            private_repo=self.config.hf_private_repo,
            path_prefix=self.config.hf_path_prefix,
        )
        await self.uploader.initialize()

    async def consume(self, record_queue: asyncio.Queue[dict[str, Any] | None]) -> None:
        """Drain ``record_queue`` until a ``None`` sentinel, flushing full shards.

        Every dequeued item is acknowledged with ``task_done()`` — the sentinel
        included, and (via ``finally``) even when ``flush()`` raises — so a
        producer blocked on ``record_queue.join()`` is never deadlocked.
        """
        while True:
            item = await record_queue.get()
            if item is None:
                record_queue.task_done()
                break
            try:
                self.buffer.append(item)
                if len(self.buffer) >= self.config.shard_size_rows:
                    await self.flush()
            finally:
                # Acknowledge even when flush() raises (e.g. ShardLimitReached).
                record_queue.task_done()
        if self.buffer:
            # A trailing partial shard is only written when uploads are off or
            # incomplete-shard uploads are explicitly allowed; otherwise the
            # leftover rows are deliberately discarded.
            should_flush_incomplete = (
                (not self.config.enable_hf_upload)
                or self.config.upload_incomplete_shards
            )
            if should_flush_incomplete:
                await self.flush()
            else:
                self.buffer = []

    async def flush(self) -> None:
        """Write the buffered rows out as one new parquet shard.

        Also tokenizes the shard for live stats and, when HF upload is enabled,
        passes it to the uploader (whose ``upload_and_delete`` name suggests it
        also removes the local file — see that class for the actual contract).

        Raises:
            ShardLimitReached: before writing, if the cap was already reached;
                or after writing, the moment this shard brings the count up to
                ``config.max_shards``.
        """
        if not self.buffer:
            return
        if self.shard_index >= self.config.max_shards:
            raise ShardLimitReached(f"Reached shard cap of {self.config.max_shards}.")
        # Take ownership of the buffered rows and reset the buffer up front,
        # so new records can accumulate while this flush is in progress.
        rows = self.buffer
        self.buffer = []
        # Coerce every field to str and drop rows with empty/missing text.
        normalized_rows = [
            {
                "text": str(row.get("text", "")),
                "url": str(row.get("url", "")),
                "domain": str(row.get("domain", "")),
                "timestamp": str(row.get("timestamp", "")),
            }
            for row in rows
            if row.get("text")
        ]
        if not normalized_rows:
            return
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        shard_name = f"shard-{timestamp}-{self.shard_index:04d}.parquet"
        shard_path = self.config.output_dir / shard_name
        table = pa.Table.from_pylist(normalized_rows, schema=PARQUET_SCHEMA)
        # Parquet writing is blocking I/O; run it off the event loop.
        await asyncio.to_thread(
            pq.write_table,
            table,
            shard_path,
            compression=self.config.parquet_compression,
            compression_level=self.config.parquet_compression_level,
            use_dictionary=True,
        )
        self.shard_index += 1
        self.stats.written_shards = self.shard_index
        self.stats.stored_rows += len(normalized_rows)
        # Tokenization is handed the shard path (presumably it reads the file
        # back — TODO confirm against LiveShardTokenizer); offloaded to a
        # thread like the write above.
        token_rows, token_count = await asyncio.to_thread(
            self.live_tokenizer.tokenize_shard_text, shard_path
        )
        self.stats.tokenized_shards += 1
        self.stats.tokenized_rows += token_rows
        self.stats.tokenized_tokens += token_count
        if self.config.enable_hf_upload:
            ok = await self._upload_and_delete(shard_path, rows=len(normalized_rows))
            if ok:
                self.stats.uploaded_shards += 1
        # Stop the pipeline eagerly once the cap is hit, rather than waiting
        # for the next flush attempt to trip the check at the top.
        if self.shard_index >= self.config.max_shards:
            raise ShardLimitReached(f"Reached shard cap of {self.config.max_shards}.")

    async def _upload_and_delete(self, shard_path: Path, rows: int) -> bool:
        """Delegate ``shard_path`` to the HF uploader; return its success flag.

        Raises:
            RuntimeError: if called before ``initialize()`` created the uploader
                (only reachable when ``enable_hf_upload`` is set but
                ``initialize()`` was skipped).
        """
        if self.uploader is None:
            raise RuntimeError("Uploader not initialized.")
        return await self.uploader.upload_and_delete(shard_path, rows)