Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any | |
| import pyarrow as pa | |
| import pyarrow.parquet as pq | |
| from .config import CrawlerConfig | |
| from .models import CrawlStats | |
| from .tokenizer import LiveShardTokenizer | |
| from .upload import HfShardUploader | |
class ShardLimitReached(RuntimeError):
    """Signals that the configured maximum number of parquet shards was produced."""
# Column layout shared by every emitted shard; all four fields are UTF-8 strings.
PARQUET_SCHEMA = pa.schema(
    [(column, pa.string()) for column in ("text", "url", "domain", "timestamp")]
)
class ParquetShardWriter:
    """Buffers crawl records and persists them as parquet shards.

    Records are pulled off an asyncio queue by :meth:`consume`, buffered
    until ``config.shard_size_rows`` rows accumulate, then written to a
    parquet file in ``config.output_dir`` by :meth:`flush`. Each written
    shard is passed through the live tokenizer for token accounting and,
    when HF upload is enabled, handed to the uploader (which also deletes
    the local file — see :meth:`_upload_and_delete`).
    """

    def __init__(self, config: CrawlerConfig, stats: CrawlStats):
        self.config = config
        self.stats = stats
        # Rows received but not yet written to a shard.
        self.buffer: list[dict[str, Any]] = []
        # Count of shards written so far; also used to enforce config.max_shards.
        self.shard_index = 0
        # Created in initialize() only when config.enable_hf_upload is set.
        self.uploader: HfShardUploader | None = None
        self.live_tokenizer = LiveShardTokenizer()

    async def initialize(self) -> None:
        """Create the output directory and, if HF upload is enabled, the uploader."""
        self.config.output_dir.mkdir(parents=True, exist_ok=True)
        if not self.config.enable_hf_upload:
            # Local-only mode: no uploader is constructed.
            return
        self.uploader = HfShardUploader(
            repo_id=self.config.hf_repo_id,
            token=self.config.hf_token,
            repo_type=self.config.hf_repo_type,
            private_repo=self.config.hf_private_repo,
            path_prefix=self.config.hf_path_prefix,
        )
        await self.uploader.initialize()

    async def consume(self, record_queue: asyncio.Queue[dict[str, Any] | None]) -> None:
        """Drain ``record_queue`` until a ``None`` sentinel is received.

        Every dequeued item — including the sentinel — is acknowledged with
        ``task_done()`` so a producer-side ``queue.join()`` can complete.
        After the sentinel, a partially filled buffer is flushed unless HF
        upload is enabled and ``upload_incomplete_shards`` is off, in which
        case the leftover rows are discarded.

        Raises:
            ShardLimitReached: propagated from :meth:`flush` when the shard
                cap is hit (the triggering item is still acknowledged).
        """
        while True:
            item = await record_queue.get()
            if item is None:
                record_queue.task_done()
                break
            try:
                self.buffer.append(item)
                # Flush as soon as a full shard's worth of rows is buffered.
                if len(self.buffer) >= self.config.shard_size_rows:
                    await self.flush()
            finally:
                # Ack even if flush() raised, so the queue's unfinished-task
                # count stays consistent.
                record_queue.task_done()
        if self.buffer:
            should_flush_incomplete = (
                (not self.config.enable_hf_upload)
                or self.config.upload_incomplete_shards
            )
            if should_flush_incomplete:
                await self.flush()
            else:
                # Drop the partial shard rather than upload an undersized file.
                self.buffer = []

    async def flush(self) -> None:
        """Write the buffered rows out as one parquet shard.

        No-op when the buffer is empty or when every buffered row has an
        empty/missing ``"text"`` field (such rows are dropped during
        normalization). The parquet write and the tokenization both run in
        worker threads via ``asyncio.to_thread`` to keep the event loop free.

        Raises:
            ShardLimitReached: before writing if the shard cap is already
                met, and again right after the cap-th shard has been fully
                processed, so the caller stops promptly.
        """
        if not self.buffer:
            return
        if self.shard_index >= self.config.max_shards:
            raise ShardLimitReached(f"Reached shard cap of {self.config.max_shards}.")
        # Swap the buffer out first so new rows can accumulate while this
        # flush's threaded work is in flight.
        rows = self.buffer
        self.buffer = []
        # Coerce every field to str to match PARQUET_SCHEMA; rows whose
        # "text" is falsy are silently dropped.
        normalized_rows = [
            {
                "text": str(row.get("text", "")),
                "url": str(row.get("url", "")),
                "domain": str(row.get("domain", "")),
                "timestamp": str(row.get("timestamp", "")),
            }
            for row in rows
            if row.get("text")
        ]
        if not normalized_rows:
            return
        # Timestamped shard name keeps files unique across restarts even
        # though shard_index resets to 0 on a fresh writer.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        shard_name = f"shard-{timestamp}-{self.shard_index:04d}.parquet"
        shard_path = self.config.output_dir / shard_name
        table = pa.Table.from_pylist(normalized_rows, schema=PARQUET_SCHEMA)
        await asyncio.to_thread(
            pq.write_table,
            table,
            shard_path,
            compression=self.config.parquet_compression,
            compression_level=self.config.parquet_compression_level,
            use_dictionary=True,
        )
        self.shard_index += 1
        self.stats.written_shards = self.shard_index
        self.stats.stored_rows += len(normalized_rows)
        # Token accounting runs on the freshly written file, off the loop.
        token_rows, token_count = await asyncio.to_thread(
            self.live_tokenizer.tokenize_shard_text, shard_path
        )
        self.stats.tokenized_shards += 1
        self.stats.tokenized_rows += token_rows
        self.stats.tokenized_tokens += token_count
        if self.config.enable_hf_upload:
            ok = await self._upload_and_delete(shard_path, rows=len(normalized_rows))
            if ok:
                self.stats.uploaded_shards += 1
        # Early-stop: raise as soon as the cap-th shard is complete instead
        # of waiting for the next flush's pre-check.
        if self.shard_index >= self.config.max_shards:
            raise ShardLimitReached(f"Reached shard cap of {self.config.max_shards}.")

    async def _upload_and_delete(self, shard_path: Path, rows: int) -> bool:
        """Delegate shard upload (and local deletion) to the HF uploader.

        Raises:
            RuntimeError: if called before :meth:`initialize` created the
                uploader (i.e. upload was not enabled).
        """
        if self.uploader is None:
            raise RuntimeError("Uploader not initialized.")
        return await self.uploader.upload_and_delete(shard_path, rows)