File size: 4,393 Bytes
f55f92e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f99323c
378a0c0
f55f92e
 
 
 
 
 
 
 
 
 
 
 
f99323c
 
 
 
 
 
 
 
f55f92e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import annotations

import asyncio
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import pyarrow as pa
import pyarrow.parquet as pq

from .config import CrawlerConfig
from .models import CrawlStats
from .tokenizer import LiveShardTokenizer
from .upload import HfShardUploader


class ShardLimitReached(RuntimeError):
    """Raised when the configured maximum number of shards has been written."""


# Output schema for every shard: four uniform UTF-8 string columns.
PARQUET_SCHEMA = pa.schema(
    [(column, pa.string()) for column in ("text", "url", "domain", "timestamp")]
)


class ParquetShardWriter:
    """Accumulates crawl records and persists them as Parquet shards.

    Records arrive through an asyncio queue (``consume``), are buffered in
    memory, and once ``config.shard_size_rows`` records are pending they are
    flushed: written to a Parquet file, tokenized, and — when HF upload is
    enabled — uploaded and deleted locally. ``ShardLimitReached`` is raised
    once ``config.max_shards`` shards have been produced.
    """

    def __init__(self, config: CrawlerConfig, stats: CrawlStats):
        self.config = config
        self.stats = stats
        # Records waiting to be flushed into the next shard.
        self.buffer: list[dict[str, Any]] = []
        # Count of shards written so far; also used in shard file names.
        self.shard_index = 0
        self.uploader: HfShardUploader | None = None
        self.live_tokenizer = LiveShardTokenizer()

    async def initialize(self) -> None:
        """Create the output directory and, if enabled, the HF uploader."""
        self.config.output_dir.mkdir(parents=True, exist_ok=True)
        if not self.config.enable_hf_upload:
            return
        self.uploader = HfShardUploader(
            repo_id=self.config.hf_repo_id,
            token=self.config.hf_token,
            repo_type=self.config.hf_repo_type,
            private_repo=self.config.hf_private_repo,
            path_prefix=self.config.hf_path_prefix,
        )
        await self.uploader.initialize()

    async def consume(self, record_queue: asyncio.Queue[dict[str, Any] | None]) -> None:
        """Drain ``record_queue`` until the ``None`` sentinel arrives.

        Each record is buffered and a flush is triggered whenever the buffer
        reaches the configured shard size. After the sentinel, any leftover
        partial shard is flushed or discarded depending on configuration.
        """
        while (record := await record_queue.get()) is not None:
            try:
                self.buffer.append(record)
                if len(self.buffer) >= self.config.shard_size_rows:
                    await self.flush()
            finally:
                # Acknowledge the record even if flush raised.
                record_queue.task_done()
        # Acknowledge the None sentinel as well.
        record_queue.task_done()

        if not self.buffer:
            return
        # A trailing partial shard is kept only when we write locally, or
        # when uploading incomplete shards is explicitly allowed.
        keep_partial = (
            (not self.config.enable_hf_upload)
            or self.config.upload_incomplete_shards
        )
        if keep_partial:
            await self.flush()
        else:
            self.buffer = []

    async def flush(self) -> None:
        """Write buffered records to a new shard, tokenize, and maybe upload.

        Raises:
            ShardLimitReached: when ``config.max_shards`` is hit, either
                before writing (cap already reached) or right after the
                shard that reaches the cap has been written.
        """
        if not self.buffer:
            return
        if self.shard_index >= self.config.max_shards:
            raise ShardLimitReached(f"Reached shard cap of {self.config.max_shards}.")

        # Take ownership of the pending rows so new records can keep arriving.
        pending, self.buffer = self.buffer, []
        clean_rows = self._normalize(pending)
        if not clean_rows:
            return

        shard_path = self._next_shard_path()
        await self._write_parquet(clean_rows, shard_path)

        self.shard_index += 1
        self.stats.written_shards = self.shard_index
        self.stats.stored_rows += len(clean_rows)

        # Tokenization reads the shard file back; run it off the event loop.
        row_count, token_total = await asyncio.to_thread(
            self.live_tokenizer.tokenize_shard_text, shard_path
        )
        self.stats.tokenized_shards += 1
        self.stats.tokenized_rows += row_count
        self.stats.tokenized_tokens += token_total

        if self.config.enable_hf_upload:
            if await self._upload_and_delete(shard_path, rows=len(clean_rows)):
                self.stats.uploaded_shards += 1

        if self.shard_index >= self.config.max_shards:
            raise ShardLimitReached(f"Reached shard cap of {self.config.max_shards}.")

    @staticmethod
    def _normalize(rows: list[dict[str, Any]]) -> list[dict[str, str]]:
        """Coerce the schema columns to str, dropping rows with empty text."""
        return [
            {key: str(row.get(key, "")) for key in ("text", "url", "domain", "timestamp")}
            for row in rows
            if row.get("text")
        ]

    def _next_shard_path(self) -> Path:
        """Build a unique, timestamped path for the next shard file."""
        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        return self.config.output_dir / f"shard-{stamp}-{self.shard_index:04d}.parquet"

    async def _write_parquet(self, rows: list[dict[str, str]], path: Path) -> None:
        """Write ``rows`` to ``path`` as Parquet in a worker thread."""
        table = pa.Table.from_pylist(rows, schema=PARQUET_SCHEMA)
        await asyncio.to_thread(
            pq.write_table,
            table,
            path,
            compression=self.config.parquet_compression,
            compression_level=self.config.parquet_compression_level,
            use_dictionary=True,
        )

    async def _upload_and_delete(self, shard_path: Path, rows: int) -> bool:
        """Delegate shard upload (and local deletion) to the HF uploader."""
        if self.uploader is None:
            raise RuntimeError("Uploader not initialized.")
        return await self.uploader.upload_and_delete(shard_path, rows)