File size: 1,759 Bytes
f55f92e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from __future__ import annotations

import asyncio
import contextlib
import logging
from pathlib import Path

from huggingface_hub import HfApi


class HfShardUploader:
    """Upload finished crawl shard files to a Hugging Face Hub repository.

    Each shard is committed as its own file under ``path_prefix`` and the
    local copy is removed once the upload call returns without raising.
    Blocking ``HfApi`` calls are run via :func:`asyncio.to_thread` so they do
    not block the event loop.
    """

    # Used to report upload failures that are otherwise only visible as a
    # ``False`` return value from ``upload_and_delete``.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        *,
        repo_id: str,
        token: str | None,
        repo_type: str = "dataset",
        private_repo: bool = False,
        path_prefix: str = "crawl_shards",
    ) -> None:
        """Store normalised configuration; no network I/O happens here.

        Args:
            repo_id: Hub repository id; surrounding whitespace is stripped.
            token: Hub access token; surrounding whitespace is stripped.
                ``None`` or an empty string means no explicit token is passed
                to ``HfApi`` in :meth:`initialize`.
            repo_type: Repository type passed through to the Hub API calls.
            private_repo: Whether the repository is created as private.
            path_prefix: Folder inside the repository for shards. Leading and
                trailing slashes are removed; an empty value uploads shards
                to the repository root.
        """
        self.repo_id = repo_id.strip()
        # Normalise ``None`` to "" so both mean "no explicit token" below.
        self.token = (token or "").strip()
        self.repo_type = repo_type
        self.private_repo = bool(private_repo)
        self.path_prefix = path_prefix.strip("/")
        # Set by ``initialize``; ``upload_and_delete`` refuses to run while None.
        self.api: HfApi | None = None

    async def initialize(self) -> None:
        """Create the API client and make sure the target repository exists.

        ``create_repo`` is called with ``exist_ok=True`` so an already existing
        repository is accepted. Errors from the Hub client propagate.
        """
        # NOTE(review): ``self.api`` is assigned before ``create_repo`` runs, so
        # if repository creation raises, the uploader still looks initialized.
        self.api = HfApi(token=self.token or None)
        await asyncio.to_thread(
            self.api.create_repo,
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            private=self.private_repo,
            exist_ok=True,
        )

    def _path_in_repo(self, shard_path: Path) -> str:
        """Return the repository-relative destination for *shard_path*."""
        # Only the basename is used; the local directory layout is not mirrored.
        name = shard_path.name
        return f"{self.path_prefix}/{name}" if self.path_prefix else name

    async def upload_and_delete(self, shard_path: Path, rows: int) -> bool:
        """Upload *shard_path* to the Hub, then delete the local file.

        Args:
            shard_path: Local shard file; only its basename is used in the repo.
            rows: Row count, used only in the commit message.

        Returns:
            ``True`` if the upload call succeeded. The local file is then
            unlinked, and a file that is already gone is tolerated.
            ``False`` if the upload raised. The exception is logged and the
            local file is left in place.

        Raises:
            RuntimeError: If :meth:`initialize` has not been called.
            OSError: If deleting the uploaded file fails for any reason other
                than the file already being missing.
        """
        if self.api is None:
            raise RuntimeError("Uploader was not initialized.")

        path_in_repo = self._path_in_repo(shard_path)

        try:
            await asyncio.to_thread(
                self.api.upload_file,
                path_or_fileobj=str(shard_path),
                path_in_repo=path_in_repo,
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                commit_message=f"Add crawl shard {shard_path.name} ({rows} rows)",
            )
        except Exception:
            # Best-effort boundary: the caller is told via the return value,
            # but keep the traceback so the cause of the failure is not lost.
            self._logger.exception(
                "Failed to upload shard %s to %s (%s) as %s",
                shard_path,
                self.repo_id,
                self.repo_type,
                path_in_repo,
            )
            return False

        with contextlib.suppress(FileNotFoundError):
            shard_path.unlink()
        return True