pedroapfilho commited on
Commit
eb72117
·
unverified ·
1 Parent(s): ad1969b

Add max files limit to HF dataset download

Browse files

List repo files first via HfApi, take first N audio files, download
individually with hf_hub_download, and symlink into working dir.
Avoids downloading all 18k files when only 50 are needed for training.

Files changed (2) hide show
  1. app.py +9 -3
  2. src/lora_trainer.py +49 -20
app.py CHANGED
@@ -323,7 +323,7 @@ def lora_upload_and_scan(files, training_state):
323
  return f"Error: {e}", training_state or {}
324
 
325
 
326
- def lora_download_hf(dataset_id, hf_token, training_state):
327
  """Download HuggingFace dataset and scan for audio files."""
328
  try:
329
  if not dataset_id or not dataset_id.strip():
@@ -332,7 +332,7 @@ def lora_download_hf(dataset_id, hf_token, training_state):
332
  token = hf_token.strip() if hf_token else None
333
 
334
  local_dir, dl_status = download_hf_dataset(
335
- dataset_id.strip(), hf_token=token
336
  )
337
 
338
  if not local_dir:
@@ -768,6 +768,12 @@ def create_ui():
768
  label="HF Token (optional, for private repos)",
769
  type="password",
770
  )
 
 
 
 
 
 
771
  lora_hf_btn = gr.Button(
772
  "Download & Scan", variant="primary"
773
  )
@@ -899,7 +905,7 @@ def create_ui():
899
 
900
  lora_hf_btn.click(
901
  fn=lora_download_hf,
902
- inputs=[lora_hf_id, lora_hf_token, training_state],
903
  outputs=[lora_source_status, training_state],
904
  )
905
 
 
323
  return f"Error: {e}", training_state or {}
324
 
325
 
326
+ def lora_download_hf(dataset_id, hf_token, max_files, training_state):
327
  """Download HuggingFace dataset and scan for audio files."""
328
  try:
329
  if not dataset_id or not dataset_id.strip():
 
332
  token = hf_token.strip() if hf_token else None
333
 
334
  local_dir, dl_status = download_hf_dataset(
335
+ dataset_id.strip(), max_files=int(max_files), hf_token=token
336
  )
337
 
338
  if not local_dir:
 
768
  label="HF Token (optional, for private repos)",
769
  type="password",
770
  )
771
+ lora_hf_max = gr.Number(
772
+ label="Max files",
773
+ value=50,
774
+ minimum=1,
775
+ precision=0,
776
+ )
777
  lora_hf_btn = gr.Button(
778
  "Download & Scan", variant="primary"
779
  )
 
905
 
906
  lora_hf_btn.click(
907
  fn=lora_download_hf,
908
+ inputs=[lora_hf_id, lora_hf_token, lora_hf_max, training_state],
909
  outputs=[lora_source_status, training_state],
910
  )
911
 
src/lora_trainer.py CHANGED
@@ -11,48 +11,77 @@ from typing import Optional, Tuple
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
- AUDIO_EXTENSIONS = ["*.wav", "*.mp3", "*.flac", "*.ogg", "*.opus"]
15
 
16
 
17
  def download_hf_dataset(
18
  dataset_id: str,
 
19
  hf_token: Optional[str] = None,
20
  ) -> Tuple[str, str]:
21
  """
22
- Download an audio dataset from HuggingFace Hub.
23
 
24
- Uses snapshot_download without local_dir so HF's built-in cache
25
- handles storage. On HF Spaces this is co-located with HF storage,
26
- avoiding unnecessary file duplication.
27
 
28
  Args:
29
  dataset_id: HuggingFace dataset repo ID (e.g. "pedroapfilho/lofi-tracks")
 
30
  hf_token: Optional HuggingFace token for private repos
31
 
32
  Returns:
33
- Tuple of (cached_dir, status_message)
34
  """
35
  try:
36
- from huggingface_hub import snapshot_download
37
 
38
- logger.info(f"Fetching dataset '{dataset_id}' via HF cache...")
 
39
 
40
- cached_dir = snapshot_download(
41
- repo_id=dataset_id,
42
- repo_type="dataset",
43
- token=hf_token or None,
44
- allow_patterns=AUDIO_EXTENSIONS,
45
- )
 
 
 
 
 
 
 
 
 
 
46
 
47
- audio_count = sum(
48
- 1
49
- for ext in AUDIO_EXTENSIONS
50
- for _ in Path(cached_dir).rglob(ext)
51
  )
52
 
53
- status = f"Loaded {audio_count} audio files from {dataset_id}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  logger.info(status)
55
- return cached_dir, status
56
 
57
  except ImportError:
58
  msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
+ AUDIO_SUFFIXES = {".wav", ".mp3", ".flac", ".ogg", ".opus"}
15
 
16
 
17
  def download_hf_dataset(
18
  dataset_id: str,
19
+ max_files: int = 50,
20
  hf_token: Optional[str] = None,
21
  ) -> Tuple[str, str]:
22
  """
23
+ Download a subset of audio files from a HuggingFace dataset repo.
24
 
25
+ Lists repo contents first, picks the first N audio files,
26
+ then downloads them individually to the HF cache.
 
27
 
28
  Args:
29
  dataset_id: HuggingFace dataset repo ID (e.g. "pedroapfilho/lofi-tracks")
30
+ max_files: Maximum number of audio files to download
31
  hf_token: Optional HuggingFace token for private repos
32
 
33
  Returns:
34
+ Tuple of (output_dir, status_message)
35
  """
36
  try:
37
+ from huggingface_hub import HfApi, hf_hub_download
38
 
39
+ api = HfApi()
40
+ token = hf_token or None
41
 
42
+ logger.info(f"Listing files in '{dataset_id}'...")
43
+
44
+ all_files = [
45
+ f.rfilename
46
+ for f in api.list_repo_tree(
47
+ dataset_id, repo_type="dataset", token=token, recursive=True
48
+ )
49
+ if hasattr(f, "rfilename")
50
+ and Path(f.rfilename).suffix.lower() in AUDIO_SUFFIXES
51
+ ]
52
+
53
+ total_available = len(all_files)
54
+ selected = all_files[:max_files]
55
+
56
+ if not selected:
57
+ return "", f"No audio files found in {dataset_id}"
58
 
59
+ logger.info(
60
+ f"Downloading {len(selected)}/{total_available} audio files..."
 
 
61
  )
62
 
63
+ output_dir = Path("lora_training") / "hf" / dataset_id.replace("/", "_")
64
+ output_dir.mkdir(parents=True, exist_ok=True)
65
+
66
+ for i, filename in enumerate(selected):
67
+ logger.info(f" [{i + 1}/{len(selected)}] {filename}")
68
+ cached_path = hf_hub_download(
69
+ repo_id=dataset_id,
70
+ filename=filename,
71
+ repo_type="dataset",
72
+ token=token,
73
+ )
74
+ # Symlink from cache into our working dir so scan_directory finds them
75
+ dest = output_dir / Path(filename).name
76
+ if not dest.exists():
77
+ dest.symlink_to(cached_path)
78
+
79
+ status = (
80
+ f"Downloaded {len(selected)} of {total_available} "
81
+ f"audio files from {dataset_id}"
82
+ )
83
  logger.info(status)
84
+ return str(output_dir), status
85
 
86
  except ImportError:
87
  msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"