czyoung commited on
Commit
e895a63
·
verified ·
1 Parent(s): 689597f

Create ParquetScheduler.py

Browse files
Files changed (1) hide show
  1. ParquetScheduler.py +174 -0
ParquetScheduler.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from https://huggingface.co/spaces/hysts-samples/save-user-preferences
2
+ # Credits to @hysts
3
+ import datetime
4
+ import json
5
+ import shutil
6
+ import tempfile
7
+ import uuid
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List, Optional, Union
10
+
11
+ import pyarrow as pa
12
+ import pyarrow.parquet as pq
13
+ from huggingface_hub import CommitScheduler
14
+ from huggingface_hub.hf_api import HfApi
15
+
16
+ #######################
17
+ # Parquet scheduler #
18
+ # Run in scheduler.py #
19
+ #######################
20
+
21
+
22
class ParquetScheduler(CommitScheduler):
    """
    Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub. 1 `.append`
    call will result in 1 row in your final dataset.

    ```py
    # Start scheduler
    >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")

    # Append some data to be uploaded
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    ```

    The scheduler will automatically infer the schema from the data it pushes.
    Optionally, you can manually set the schema yourself:

    ```py
    >>> scheduler = ParquetScheduler(
    ...     repo_id="my-parquet-dataset",
    ...     schema={
    ...         "prompt": {"_type": "Value", "dtype": "string"},
    ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
    ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
    ...         "image": {"_type": "Image"},
    ...     },
    ... )

    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
    possible values.
    """

    def __init__(
        self,
        *,
        repo_id: str,
        schema: Optional[Dict[str, Dict[str, str]]] = None,
        every: Union[int, float] = 5,
        path_in_repo: Optional[str] = "data",
        repo_type: Optional[str] = "dataset",
        revision: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        allow_patterns: Union[List[str], str, None] = None,
        ignore_patterns: Union[List[str], str, None] = None,
        hf_api: Optional[HfApi] = None,
    ) -> None:
        """Create the scheduler.

        Args:
            repo_id: Hub repo the parquet files are pushed to.
            schema: optional `datasets`-style feature spec per column
                (e.g. ``{"prompt": {"_type": "Value", "dtype": "string"}}``).
                Missing columns are inferred from the first appended value.
            every: commit period in minutes (forwarded to `CommitScheduler`).
            Remaining arguments are forwarded to `CommitScheduler` as-is.
        """
        super().__init__(
            repo_id=repo_id,
            folder_path="dummy",  # not used by the scheduler
            every=every,
            path_in_repo=path_in_repo,
            repo_type=repo_type,
            revision=revision,
            private=private,
            token=token,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            hf_api=hf_api,
        )

        # Rows buffered between two pushes; guarded by `self.lock`.
        self._rows: List[Dict[str, Any]] = []
        self._schema = schema

    def append(self, row: Dict[str, Any]) -> None:
        """Add a new item to be uploaded."""
        with self.lock:
            self._rows.append(row)

    def push_to_hub(self):
        """Flush buffered rows to the Hub as one new parquet file.

        Called periodically by the `CommitScheduler` background thread.
        No-op when nothing was appended since the last push.
        """
        # Atomically take ownership of the pending rows so `append` can
        # keep adding new ones while we upload.
        with self.lock:
            rows = self._rows
            self._rows = []
        if not rows:
            return
        print(f"Got {len(rows)} item(s) to commit.")

        # Load images + create 'features' config for datasets library
        schema: Dict[str, Dict] = self._schema or {}
        path_to_cleanup: List[Path] = []
        for row in rows:
            for key, value in row.items():
                # Infer schema (for `datasets` library)
                if key not in schema:
                    schema[key] = _infer_schema(key, value)

                # Load binary files if necessary
                if schema[key]["_type"] in ("Image", "Audio"):
                    # It's an image or audio: we load the bytes and remember
                    # to clean up the local file after a successful upload.
                    file_path = Path(value)
                    if file_path.is_file():
                        row[key] = {
                            "path": file_path.name,
                            "bytes": file_path.read_bytes(),
                        }
                        path_to_cleanup.append(file_path)

        # Complete rows if needed: every feature must exist in every row.
        for row in rows:
            for feature in schema:
                if feature not in row:
                    row[feature] = None

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (used by datasets library)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": schema}})}
        )

        # Write to a temp parquet file and upload it.
        archive_file = tempfile.NamedTemporaryFile()
        try:
            pq.write_table(table, archive_file.name)

            self.api.upload_file(
                repo_id=self.repo_id,
                repo_type=self.repo_type,
                revision=self.revision,
                # NOTE(review): uploads to the repo root, not under
                # `self.path_in_repo` ("data" by default) — confirm intended.
                path_in_repo=f"{uuid.uuid4()}.parquet",
                path_or_fileobj=archive_file.name,
            )
        finally:
            # Fix: close (and thereby delete) the temp file even when
            # write/upload raises — the original leaked it on failure.
            archive_file.close()
        # Fix: was `print(f"Commit completed.")` — f-string with no placeholder.
        print("Commit completed.")

        # Delete local media files only after the upload succeeded.
        for path in path_to_cleanup:
            path.unlink(missing_ok=True)
153
+
154
+
155
+ def _infer_schema(key: str, value: Any) -> Dict[str, str]:
156
+ """
157
+ Infer schema for the `datasets` library.
158
+
159
+ See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value.
160
+ """
161
+ if "image" in key:
162
+ return {"_type": "Image"}
163
+ if "audio" in key:
164
+ return {"_type": "Audio"}
165
+ if isinstance(value, int):
166
+ return {"_type": "Value", "dtype": "int64"}
167
+ if isinstance(value, float):
168
+ return {"_type": "Value", "dtype": "float64"}
169
+ if isinstance(value, bool):
170
+ return {"_type": "Value", "dtype": "bool"}
171
+ if isinstance(value, bytes):
172
+ return {"_type": "Value", "dtype": "binary"}
173
+ # Otherwise in last resort => convert it to a string
174
+ return {"_type": "Value", "dtype": "string"}