Reza2kn commited on
Commit
67ba0d5
·
verified ·
1 Parent(s): 2a60ca8

Update app for GPU-aware model loading and dataset fixes

Browse files
.gitignore CHANGED
@@ -30,6 +30,11 @@ build/
30
  chizzler_cache/
31
  CommonVoice24-FA/
32
  .commonvoice_upload_checkpoint.json
 
 
 
 
 
33
  *.ogg
34
 
35
  # macOS
 
30
  chizzler_cache/
31
  CommonVoice24-FA/
32
  .commonvoice_upload_checkpoint.json
33
+ commonvoice_upload.pid
34
+ commonvoice_upload.log
35
+ commonvoice_progress.log
36
+ commonvoice_progress.pid
37
+ .commonvoice_progress_state.json
38
  *.ogg
39
 
40
  # macOS
app.py CHANGED
@@ -180,10 +180,10 @@ def select_device() -> torch.device:
180
  return torch.device("cpu")
181
 
182
 
183
- def initialize_models():
184
  log_progress("Initializing models...")
185
 
186
- device = select_device()
187
  log_progress(f"Using {device.type.upper()} for all operations", 2)
188
 
189
  log_progress("Loading Silero VAD model...", 2)
@@ -214,7 +214,29 @@ def initialize_models():
214
  return vad_model, utils, mpnet_model, config, device
215
 
216
 
217
- vad_model, vad_utils, mpnet_model, mpnet_config, device = initialize_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
 
220
  def ensure_mono(waveform: torch.Tensor) -> torch.Tensor:
@@ -283,6 +305,7 @@ def get_speech_timestamps(
283
  ) -> List[dict]:
284
  log_progress("Detecting speech segments...", enabled=log)
285
 
 
286
  (get_speech_timestamps_fn, _, _, _, _) = vad_utils
287
 
288
  speech_timestamps = get_speech_timestamps_fn(
@@ -332,7 +355,10 @@ def extract_speech_waveform(
332
 
333
 
334
  def denoise_audio_chunk(
335
- audio_tensor: torch.Tensor, chunk_size: int = 5 * DEFAULT_SAMPLE_RATE
 
 
 
336
  ) -> torch.Tensor:
337
  chunks = []
338
  for i in range(0, audio_tensor.size(1), chunk_size):
@@ -375,6 +401,7 @@ def process_waveform(
375
  max_gap: float = 4.0,
376
  log: bool = True,
377
  ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], str, bool]:
 
378
  if waveform.device != device:
379
  waveform = waveform.to(device)
380
  log_progress("Stage 1: Voice Activity Detection", 2, enabled=log)
@@ -414,7 +441,9 @@ def process_waveform(
414
 
415
  log_progress("Stage 2: MP-SENet denoising", 2, enabled=log)
416
  with torch.no_grad():
417
- denoised_waveform = denoise_audio_chunk(vad_waveform)
 
 
418
 
419
  return vad_waveform, denoised_waveform, "\n".join(details), True
420
 
@@ -430,6 +459,7 @@ def process_audio_file(
430
  ) -> Tuple[str, str, str, str]:
431
  log_progress(f"Processing: {Path(audio_path).name}")
432
  waveform, sample_rate = load_audio_file(audio_path)
 
433
 
434
  vad_waveform, denoised_waveform, details, has_speech = process_waveform(
435
  waveform, sample_rate, threshold=threshold, max_gap=max_gap, log=True
@@ -721,6 +751,9 @@ def process_dataset_and_push(
721
  if not dataset_id:
722
  return "Provide a dataset ID or URL."
723
 
 
 
 
724
  config = config.strip() or None
725
  split = split.strip()
726
  audio_column = audio_column.strip()
 
180
  return torch.device("cpu")
181
 
182
 
183
+ def initialize_models(device_override: Optional[torch.device] = None):
184
  log_progress("Initializing models...")
185
 
186
+ device = device_override or select_device()
187
  log_progress(f"Using {device.type.upper()} for all operations", 2)
188
 
189
  log_progress("Loading Silero VAD model...", 2)
 
214
  return vad_model, utils, mpnet_model, config, device
215
 
216
 
217
+ vad_model = None
218
+ vad_utils = None
219
+ mpnet_model = None
220
+ mpnet_config = None
221
+ device = None
222
+
223
+
224
+ def get_models():
225
+ global vad_model, vad_utils, mpnet_model, mpnet_config, device
226
+ desired_device = select_device()
227
+ if vad_model is None or mpnet_model is None or mpnet_config is None:
228
+ vad_model, vad_utils, mpnet_model, mpnet_config, device = (
229
+ initialize_models(desired_device)
230
+ )
231
+ return vad_model, vad_utils, mpnet_model, mpnet_config, device
232
+
233
+ if device is None or str(device) != str(desired_device):
234
+ log_progress(f"Moving models to {desired_device}...", 2)
235
+ vad_model = vad_model.to(desired_device)
236
+ mpnet_model = mpnet_model.to(desired_device)
237
+ device = desired_device
238
+
239
+ return vad_model, vad_utils, mpnet_model, mpnet_config, device
240
 
241
 
242
  def ensure_mono(waveform: torch.Tensor) -> torch.Tensor:
 
305
  ) -> List[dict]:
306
  log_progress("Detecting speech segments...", enabled=log)
307
 
308
+ vad_model, vad_utils, _, _, _ = get_models()
309
  (get_speech_timestamps_fn, _, _, _, _) = vad_utils
310
 
311
  speech_timestamps = get_speech_timestamps_fn(
 
355
 
356
 
357
  def denoise_audio_chunk(
358
+ audio_tensor: torch.Tensor,
359
+ mpnet_model: torch.nn.Module,
360
+ mpnet_config: AttrDict,
361
+ chunk_size: int = 5 * DEFAULT_SAMPLE_RATE,
362
  ) -> torch.Tensor:
363
  chunks = []
364
  for i in range(0, audio_tensor.size(1), chunk_size):
 
401
  max_gap: float = 4.0,
402
  log: bool = True,
403
  ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], str, bool]:
404
+ vad_model, vad_utils, mpnet_model, mpnet_config, device = get_models()
405
  if waveform.device != device:
406
  waveform = waveform.to(device)
407
  log_progress("Stage 1: Voice Activity Detection", 2, enabled=log)
 
441
 
442
  log_progress("Stage 2: MP-SENet denoising", 2, enabled=log)
443
  with torch.no_grad():
444
+ denoised_waveform = denoise_audio_chunk(
445
+ vad_waveform, mpnet_model, mpnet_config
446
+ )
447
 
448
  return vad_waveform, denoised_waveform, "\n".join(details), True
449
 
 
459
  ) -> Tuple[str, str, str, str]:
460
  log_progress(f"Processing: {Path(audio_path).name}")
461
  waveform, sample_rate = load_audio_file(audio_path)
462
+ _, _, _, mpnet_config, _ = get_models()
463
 
464
  vad_waveform, denoised_waveform, details, has_speech = process_waveform(
465
  waveform, sample_rate, threshold=threshold, max_gap=max_gap, log=True
 
751
  if not dataset_id:
752
  return "Provide a dataset ID or URL."
753
 
754
+ # Ensure models are loaded on the correct device before heavy processing.
755
+ get_models()
756
+
757
  config = config.strip() or None
758
  split = split.strip()
759
  audio_column = audio_column.strip()
scripts/publish_commonvoice_dataset.py CHANGED
@@ -2,6 +2,8 @@ import os
2
  from pathlib import Path
3
 
4
  import csv
 
 
5
 
6
  from datasets import Audio, Dataset, DatasetDict
7
  from huggingface_hub import HfApi
@@ -10,21 +12,17 @@ from huggingface_hub import HfApi
10
  DATASET_DIR = Path(os.getenv("COMMONVOICE_DIR", "CommonVoice24-FA")).resolve()
11
  SPLITS = [
12
  split.strip()
13
- for split in os.getenv("COMMONVOICE_SPLITS", "train,dev,test").split(",")
14
  if split.strip()
15
  ]
16
  REPO_OVERRIDE = os.getenv("COMMONVOICE_REPO")
17
  PRIVATE_REPO = os.getenv("COMMONVOICE_PRIVATE", "0") == "1"
18
 
19
- DROP_COLUMNS = {
20
- "client_id",
21
- "sentence_id",
22
- "sentence_domain",
23
- "accents",
24
- "variant",
25
- "segment",
26
- "path",
27
- }
28
 
29
 
30
  def load_env(path: Path) -> dict:
@@ -40,7 +38,8 @@ def load_env(path: Path) -> dict:
40
  return data
41
 
42
 
43
- def dataset_card(repo_id: str) -> str:
 
44
  return f"""---
45
  language:
46
  - fa
@@ -53,11 +52,16 @@ pretty_name: Common Voice 24 (FA) - Audio Column
53
  This dataset is a repackaging of the Persian subset of Mozilla Common Voice 24.0.
54
 
55
  ## What changed
56
- - Added an `audio` column pointing to `clips/*.mp3` for easy playback in the Hub UI.
57
- - Removed columns: `client_id`, `sentence_id`, `sentence_domain`, `accents`,
58
- `variant`, `segment`, and `path`.
59
- - Kept columns like `sentence`, `up_votes`, `down_votes`, `age`, `gender`, and
60
- `locale`.
 
 
 
 
 
61
 
62
  ## Source
63
  Original data: https://huggingface.co/datasets/mozilla-foundation/common_voice_24_0
@@ -85,11 +89,23 @@ def main() -> None:
85
  if not DATASET_DIR.exists():
86
  raise SystemExit(f"Dataset dir not found: {DATASET_DIR}")
87
 
 
 
 
 
 
 
 
 
88
  data_files = {}
89
- for split in SPLITS:
90
- tsv_path = DATASET_DIR / f"{split}.tsv"
91
- if tsv_path.exists():
92
- data_files[split] = str(tsv_path)
 
 
 
 
93
 
94
  if not data_files:
95
  raise SystemExit(
@@ -104,14 +120,33 @@ def main() -> None:
104
  repo_id, repo_type="dataset", private=PRIVATE_REPO, exist_ok=True
105
  )
106
 
 
 
 
 
 
 
 
107
  def tsv_generator(path: str):
108
  with open(path, "r", encoding="utf-8", errors="replace") as handle:
109
  reader = csv.reader(handle, delimiter="\t")
110
- header = next(reader)
 
 
 
 
111
  for row in reader:
112
  if len(row) != len(header):
113
  continue
114
- yield dict(zip(header, row))
 
 
 
 
 
 
 
 
115
 
116
  dataset_splits = {}
117
  for split, path in data_files.items():
@@ -121,20 +156,9 @@ def main() -> None:
121
 
122
  dataset = DatasetDict(dataset_splits)
123
 
124
- def add_audio(batch):
125
- return {
126
- "audio": [f"clips/{path}" for path in batch["path"]]
127
- }
128
-
129
- dataset = dataset.map(add_audio, batched=True)
130
  dataset = dataset.cast_column("audio", Audio())
131
-
132
  for split, split_ds in dataset.items():
133
- columns_to_drop = [
134
- col for col in split_ds.column_names if col in DROP_COLUMNS
135
- ]
136
- if columns_to_drop:
137
- dataset[split] = split_ds.remove_columns(columns_to_drop)
138
 
139
  current_dir = os.getcwd()
140
  os.chdir(str(DATASET_DIR))
@@ -144,7 +168,7 @@ def main() -> None:
144
  os.chdir(current_dir)
145
 
146
  api.upload_file(
147
- path_or_fileobj=dataset_card(repo_id).encode("utf-8"),
148
  path_in_repo="README.md",
149
  repo_id=repo_id,
150
  repo_type="dataset",
 
2
  from pathlib import Path
3
 
4
  import csv
5
+ import re
6
+ import sys
7
 
8
  from datasets import Audio, Dataset, DatasetDict
9
  from huggingface_hub import HfApi
 
12
  DATASET_DIR = Path(os.getenv("COMMONVOICE_DIR", "CommonVoice24-FA")).resolve()
13
  SPLITS = [
14
  split.strip()
15
+ for split in os.getenv("COMMONVOICE_SPLITS", "").split(",")
16
  if split.strip()
17
  ]
18
  REPO_OVERRIDE = os.getenv("COMMONVOICE_REPO")
19
  PRIVATE_REPO = os.getenv("COMMONVOICE_PRIVATE", "0") == "1"
20
 
21
+ REQUIRED_COLUMNS = {"path", "sentence"}
22
+ csv.field_size_limit(min(sys.maxsize, 10**7))
23
+ PREFIX_RE = re.compile(r"^common_voice_fa_(\d+)\.mp3$")
24
+ BUCKET_COUNT = int(os.getenv("COMMONVOICE_BUCKETS", "100"))
25
+ BUCKET_WIDTH = max(2, len(str(max(BUCKET_COUNT - 1, 0))))
 
 
 
 
26
 
27
 
28
  def load_env(path: Path) -> dict:
 
38
  return data
39
 
40
 
41
+ def dataset_card(repo_id: str, split_names: list[str]) -> str:
42
+ splits = ", ".join(split_names)
43
  return f"""---
44
  language:
45
  - fa
 
52
  This dataset is a repackaging of the Persian subset of Mozilla Common Voice 24.0.
53
 
54
  ## What changed
55
+ - Added an `audio` column pointing to `clips/<bucket>/*.mp3` for easy playback in the Hub UI.
56
+ - Only kept `audio` and `sentence` columns (in that order).
57
+
58
+ ## Splits
59
+ {splits}
60
+
61
+ ## Notes
62
+ Additional TSV files that do not include audio paths (e.g. reports or sentence
63
+ metadata) are kept as raw files in the repo but are not exposed as dataset
64
+ splits.
65
 
66
  ## Source
67
  Original data: https://huggingface.co/datasets/mozilla-foundation/common_voice_24_0
 
89
  if not DATASET_DIR.exists():
90
  raise SystemExit(f"Dataset dir not found: {DATASET_DIR}")
91
 
92
+ tsv_files = sorted(DATASET_DIR.glob("*.tsv"))
93
+ if SPLITS:
94
+ tsv_files = [
95
+ DATASET_DIR / f"{name}.tsv"
96
+ for name in SPLITS
97
+ if (DATASET_DIR / f"{name}.tsv").exists()
98
+ ]
99
+
100
  data_files = {}
101
+ for path in tsv_files:
102
+ with path.open("r", encoding="utf-8", errors="replace") as handle:
103
+ reader = csv.reader(handle, delimiter="\t")
104
+ header = next(reader, [])
105
+ if not REQUIRED_COLUMNS.issubset(header):
106
+ continue
107
+ split_name = path.stem
108
+ data_files[split_name] = str(path)
109
 
110
  if not data_files:
111
  raise SystemExit(
 
120
  repo_id, repo_type="dataset", private=PRIVATE_REPO, exist_ok=True
121
  )
122
 
123
+ def bucket_for_clip(clip_path: str) -> str:
124
+ match = PREFIX_RE.match(clip_path)
125
+ if not match:
126
+ return "misc"
127
+ clip_id = int(match.group(1))
128
+ return f"{clip_id % BUCKET_COUNT:0{BUCKET_WIDTH}d}"
129
+
130
  def tsv_generator(path: str):
131
  with open(path, "r", encoding="utf-8", errors="replace") as handle:
132
  reader = csv.reader(handle, delimiter="\t")
133
+ header = next(reader, [])
134
+ if not REQUIRED_COLUMNS.issubset(header):
135
+ return
136
+ path_idx = header.index("path")
137
+ sentence_idx = header.index("sentence")
138
  for row in reader:
139
  if len(row) != len(header):
140
  continue
141
+ clip_path = row[path_idx].strip()
142
+ sentence = row[sentence_idx].strip()
143
+ if not clip_path:
144
+ continue
145
+ bucket = bucket_for_clip(clip_path)
146
+ yield {
147
+ "audio": f"clips/{bucket}/{clip_path}",
148
+ "sentence": sentence,
149
+ }
150
 
151
  dataset_splits = {}
152
  for split, path in data_files.items():
 
156
 
157
  dataset = DatasetDict(dataset_splits)
158
 
 
 
 
 
 
 
159
  dataset = dataset.cast_column("audio", Audio())
 
160
  for split, split_ds in dataset.items():
161
+ dataset[split] = split_ds.select_columns(["audio", "sentence"])
 
 
 
 
162
 
163
  current_dir = os.getcwd()
164
  os.chdir(str(DATASET_DIR))
 
168
  os.chdir(current_dir)
169
 
170
  api.upload_file(
171
+ path_or_fileobj=dataset_card(repo_id, sorted(data_files)).encode("utf-8"),
172
  path_in_repo="README.md",
173
  repo_id=repo_id,
174
  repo_type="dataset",
scripts/upload_commonvoice_chunks.py CHANGED
@@ -1,9 +1,15 @@
1
  import json
2
  import os
3
  import re
 
4
  from pathlib import Path
5
 
6
- from huggingface_hub import CommitOperationAdd, HfApi
 
 
 
 
 
7
 
8
 
9
  DATASET_DIR = Path(os.getenv("COMMONVOICE_DIR", "CommonVoice24-FA"))
@@ -14,6 +20,12 @@ REPO_OVERRIDE = os.getenv("COMMONVOICE_REPO")
14
  PREFIX_RE = re.compile(r"^common_voice_fa_(\d+)\.mp3$")
15
  CHUNK_SIZE = int(os.getenv("COMMONVOICE_CHUNK_SIZE", "2000"))
16
  MAX_CHUNKS = int(os.getenv("COMMONVOICE_MAX_CHUNKS", "0"))
 
 
 
 
 
 
17
 
18
 
19
  def load_env(path: Path) -> dict:
@@ -33,8 +45,20 @@ def load_env(path: Path) -> dict:
33
 
34
  def load_checkpoint(path: Path) -> dict:
35
  if not path.exists():
36
- return {"metadata_uploaded": False, "prefixes": []}
37
- return json.loads(path.read_text())
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  def save_checkpoint(path: Path, data: dict) -> None:
@@ -52,6 +76,82 @@ def get_clip_files(clip_dir: Path) -> list[Path]:
52
  return sorted(files)
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def main() -> None:
56
  env = load_env(Path(".env"))
57
  token = (
@@ -73,6 +173,13 @@ def main() -> None:
73
  api.create_repo(repo_id, repo_type="dataset", exist_ok=True)
74
 
75
  checkpoint = load_checkpoint(CHECKPOINT_FILE)
 
 
 
 
 
 
 
76
 
77
  if not checkpoint.get("metadata_uploaded"):
78
  api.upload_folder(
@@ -88,6 +195,8 @@ def main() -> None:
88
  checkpoint["metadata_uploaded"] = True
89
  save_checkpoint(CHECKPOINT_FILE, checkpoint)
90
 
 
 
91
  clip_dir = DATASET_DIR / "clips"
92
  clip_files = get_clip_files(clip_dir)
93
  total = len(clip_files)
@@ -101,12 +210,13 @@ def main() -> None:
101
  batch = clip_files[start:end]
102
  operations = [
103
  CommitOperationAdd(
104
- path_in_repo=f"clips/{path.name}",
105
  path_or_fileobj=str(path),
106
  )
107
  for path in batch
108
  ]
109
- api.create_commit(
 
110
  repo_id=repo_id,
111
  repo_type="dataset",
112
  operations=operations,
 
1
  import json
2
  import os
3
  import re
4
+ import time
5
  from pathlib import Path
6
 
7
+ from huggingface_hub import (
8
+ CommitOperationAdd,
9
+ CommitOperationCopy,
10
+ CommitOperationDelete,
11
+ HfApi,
12
+ )
13
 
14
 
15
  DATASET_DIR = Path(os.getenv("COMMONVOICE_DIR", "CommonVoice24-FA"))
 
20
  PREFIX_RE = re.compile(r"^common_voice_fa_(\d+)\.mp3$")
21
  CHUNK_SIZE = int(os.getenv("COMMONVOICE_CHUNK_SIZE", "2000"))
22
  MAX_CHUNKS = int(os.getenv("COMMONVOICE_MAX_CHUNKS", "0"))
23
+ BUCKET_COUNT = int(os.getenv("COMMONVOICE_BUCKETS", "100"))
24
+ BUCKET_WIDTH = max(2, len(str(max(BUCKET_COUNT - 1, 0))))
25
+ MOVE_BATCH_SIZE = int(os.getenv("COMMONVOICE_MOVE_BATCH", "100"))
26
+ MIGRATE_EXISTING = os.getenv("COMMONVOICE_MIGRATE", "1") == "1"
27
+ COMMIT_RETRIES = int(os.getenv("COMMONVOICE_COMMIT_RETRIES", "3"))
28
+ COMMIT_SLEEP = float(os.getenv("COMMONVOICE_COMMIT_SLEEP", "5"))
29
 
30
 
31
  def load_env(path: Path) -> dict:
 
45
 
46
  def load_checkpoint(path: Path) -> dict:
47
  if not path.exists():
48
+ return {
49
+ "metadata_uploaded": False,
50
+ "prefixes": [],
51
+ "clip_index": 0,
52
+ "bucketed": False,
53
+ "bucket_count": BUCKET_COUNT,
54
+ }
55
+ data = json.loads(path.read_text())
56
+ data.setdefault("metadata_uploaded", False)
57
+ data.setdefault("prefixes", [])
58
+ data.setdefault("clip_index", 0)
59
+ data.setdefault("bucketed", False)
60
+ data.setdefault("bucket_count", BUCKET_COUNT)
61
+ return data
62
 
63
 
64
  def save_checkpoint(path: Path, data: dict) -> None:
 
76
  return sorted(files)
77
 
78
 
79
+ def bucket_for_filename(filename: str) -> str:
80
+ match = PREFIX_RE.match(filename)
81
+ if not match:
82
+ return "misc"
83
+ clip_id = int(match.group(1))
84
+ return f"{clip_id % BUCKET_COUNT:0{BUCKET_WIDTH}d}"
85
+
86
+
87
+ def bucketed_repo_path(filename: str) -> str:
88
+ bucket = bucket_for_filename(filename)
89
 + return f"clips/{bucket}/{filename}"
90
+
91
+
92
+ def create_commit_with_retry(api: HfApi, **kwargs) -> None:
93
+ for attempt in range(1, COMMIT_RETRIES + 1):
94
+ try:
95
+ api.create_commit(**kwargs)
96
+ return
97
+ except Exception as exc:
98
+ if attempt >= COMMIT_RETRIES:
99
+ raise
100
+ print(
101
+ "Commit failed, retrying "
102
+ f"({attempt}/{COMMIT_RETRIES}): {exc}"
103
+ )
104
+ time.sleep(COMMIT_SLEEP)
105
+
106
+
107
+ def migrate_root_clips(
108
+ api: HfApi, repo_id: str, checkpoint: dict
109
+ ) -> None:
110
+ if checkpoint.get("bucketed"):
111
+ return
112
+ if not MIGRATE_EXISTING:
113
+ return
114
+
115
+ repo_files = api.list_repo_files(repo_id, repo_type="dataset")
116
+ root_clips = [
117
+ path
118
+ for path in repo_files
119
+ if path.startswith("clips/")
120
+ and path.count("/") == 1
121
+ and PREFIX_RE.match(Path(path).name)
122
+ ]
123
+ if not root_clips:
124
+ checkpoint["bucketed"] = True
125
+ save_checkpoint(CHECKPOINT_FILE, checkpoint)
126
+ return
127
+
128
+ for start in range(0, len(root_clips), MOVE_BATCH_SIZE):
129
+ batch = root_clips[start:start + MOVE_BATCH_SIZE]
130
+ operations = []
131
+ for path in batch:
132
+ new_path = bucketed_repo_path(Path(path).name)
133
+ operations.append(
134
+ CommitOperationCopy(
135
+ src_path_in_repo=path,
136
+ path_in_repo=new_path,
137
+ )
138
+ )
139
+ operations.append(CommitOperationDelete(path_in_repo=path))
140
+ create_commit_with_retry(
141
+ api,
142
+ repo_id=repo_id,
143
+ repo_type="dataset",
144
+ operations=operations,
145
+ commit_message=(
146
+ "Move Common Voice clips into bucketed subfolders"
147
+ ),
148
+ )
149
+
150
+ checkpoint["bucketed"] = True
151
+ checkpoint["bucket_count"] = BUCKET_COUNT
152
+ save_checkpoint(CHECKPOINT_FILE, checkpoint)
153
+
154
+
155
  def main() -> None:
156
  env = load_env(Path(".env"))
157
  token = (
 
173
  api.create_repo(repo_id, repo_type="dataset", exist_ok=True)
174
 
175
  checkpoint = load_checkpoint(CHECKPOINT_FILE)
176
+ if int(checkpoint.get("bucket_count", BUCKET_COUNT)) != BUCKET_COUNT:
177
+ raise SystemExit(
178
+ "Bucket count mismatch. "
179
+ f"Checkpoint has {checkpoint.get('bucket_count')}, "
180
+ f"env has {BUCKET_COUNT}. "
181
+ "Set COMMONVOICE_BUCKETS to match the existing upload."
182
+ )
183
 
184
  if not checkpoint.get("metadata_uploaded"):
185
  api.upload_folder(
 
195
  checkpoint["metadata_uploaded"] = True
196
  save_checkpoint(CHECKPOINT_FILE, checkpoint)
197
 
198
+ migrate_root_clips(api, repo_id, checkpoint)
199
+
200
  clip_dir = DATASET_DIR / "clips"
201
  clip_files = get_clip_files(clip_dir)
202
  total = len(clip_files)
 
210
  batch = clip_files[start:end]
211
  operations = [
212
  CommitOperationAdd(
213
+ path_in_repo=bucketed_repo_path(path.name),
214
  path_or_fileobj=str(path),
215
  )
216
  for path in batch
217
  ]
218
+ create_commit_with_retry(
219
+ api,
220
  repo_id=repo_id,
221
  repo_type="dataset",
222
  operations=operations,