Spaces:
Running
Running
Add max files limit to HF dataset download
Browse files
List repo files first via HfApi, take first N audio files, download
individually with hf_hub_download, and symlink into working dir.
Avoids downloading all 18k files when only 50 are needed for training.
- app.py +9 -3
- src/lora_trainer.py +49 -20
app.py
CHANGED
|
@@ -323,7 +323,7 @@ def lora_upload_and_scan(files, training_state):
|
|
| 323 |
return f"Error: {e}", training_state or {}
|
| 324 |
|
| 325 |
|
| 326 |
-
def lora_download_hf(dataset_id, hf_token, training_state):
|
| 327 |
"""Download HuggingFace dataset and scan for audio files."""
|
| 328 |
try:
|
| 329 |
if not dataset_id or not dataset_id.strip():
|
|
@@ -332,7 +332,7 @@ def lora_download_hf(dataset_id, hf_token, training_state):
|
|
| 332 |
token = hf_token.strip() if hf_token else None
|
| 333 |
|
| 334 |
local_dir, dl_status = download_hf_dataset(
|
| 335 |
-
dataset_id.strip(), hf_token=token
|
| 336 |
)
|
| 337 |
|
| 338 |
if not local_dir:
|
|
@@ -768,6 +768,12 @@ def create_ui():
|
|
| 768 |
label="HF Token (optional, for private repos)",
|
| 769 |
type="password",
|
| 770 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
lora_hf_btn = gr.Button(
|
| 772 |
"Download & Scan", variant="primary"
|
| 773 |
)
|
|
@@ -899,7 +905,7 @@ def create_ui():
|
|
| 899 |
|
| 900 |
lora_hf_btn.click(
|
| 901 |
fn=lora_download_hf,
|
| 902 |
-
inputs=[lora_hf_id, lora_hf_token, training_state],
|
| 903 |
outputs=[lora_source_status, training_state],
|
| 904 |
)
|
| 905 |
|
|
|
|
| 323 |
return f"Error: {e}", training_state or {}
|
| 324 |
|
| 325 |
|
| 326 |
+
def lora_download_hf(dataset_id, hf_token, max_files, training_state):
|
| 327 |
"""Download HuggingFace dataset and scan for audio files."""
|
| 328 |
try:
|
| 329 |
if not dataset_id or not dataset_id.strip():
|
|
|
|
| 332 |
token = hf_token.strip() if hf_token else None
|
| 333 |
|
| 334 |
local_dir, dl_status = download_hf_dataset(
|
| 335 |
+
dataset_id.strip(), max_files=int(max_files), hf_token=token
|
| 336 |
)
|
| 337 |
|
| 338 |
if not local_dir:
|
|
|
|
| 768 |
label="HF Token (optional, for private repos)",
|
| 769 |
type="password",
|
| 770 |
)
|
| 771 |
+
lora_hf_max = gr.Number(
|
| 772 |
+
label="Max files",
|
| 773 |
+
value=50,
|
| 774 |
+
minimum=1,
|
| 775 |
+
precision=0,
|
| 776 |
+
)
|
| 777 |
lora_hf_btn = gr.Button(
|
| 778 |
"Download & Scan", variant="primary"
|
| 779 |
)
|
|
|
|
| 905 |
|
| 906 |
lora_hf_btn.click(
|
| 907 |
fn=lora_download_hf,
|
| 908 |
+
inputs=[lora_hf_id, lora_hf_token, lora_hf_max, training_state],
|
| 909 |
outputs=[lora_source_status, training_state],
|
| 910 |
)
|
| 911 |
|
src/lora_trainer.py
CHANGED
|
@@ -11,48 +11,77 @@ from typing import Optional, Tuple
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
|
| 17 |
def download_hf_dataset(
|
| 18 |
dataset_id: str,
|
|
|
|
| 19 |
hf_token: Optional[str] = None,
|
| 20 |
) -> Tuple[str, str]:
|
| 21 |
"""
|
| 22 |
-
Download
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
avoiding unnecessary file duplication.
|
| 27 |
|
| 28 |
Args:
|
| 29 |
dataset_id: HuggingFace dataset repo ID (e.g. "pedroapfilho/lofi-tracks")
|
|
|
|
| 30 |
hf_token: Optional HuggingFace token for private repos
|
| 31 |
|
| 32 |
Returns:
|
| 33 |
-
Tuple of (
|
| 34 |
"""
|
| 35 |
try:
|
| 36 |
-
from huggingface_hub import
|
| 37 |
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
for ext in AUDIO_EXTENSIONS
|
| 50 |
-
for _ in Path(cached_dir).rglob(ext)
|
| 51 |
)
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
logger.info(status)
|
| 55 |
-
return
|
| 56 |
|
| 57 |
except ImportError:
|
| 58 |
msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"
|
|
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
+
AUDIO_SUFFIXES = {".wav", ".mp3", ".flac", ".ogg", ".opus"}
|
| 15 |
|
| 16 |
|
| 17 |
def download_hf_dataset(
|
| 18 |
dataset_id: str,
|
| 19 |
+
max_files: int = 50,
|
| 20 |
hf_token: Optional[str] = None,
|
| 21 |
) -> Tuple[str, str]:
|
| 22 |
"""
|
| 23 |
+
Download a subset of audio files from a HuggingFace dataset repo.
|
| 24 |
|
| 25 |
+
Lists repo contents first, picks the first N audio files,
|
| 26 |
+
then downloads them individually to the HF cache.
|
|
|
|
| 27 |
|
| 28 |
Args:
|
| 29 |
dataset_id: HuggingFace dataset repo ID (e.g. "pedroapfilho/lofi-tracks")
|
| 30 |
+
max_files: Maximum number of audio files to download
|
| 31 |
hf_token: Optional HuggingFace token for private repos
|
| 32 |
|
| 33 |
Returns:
|
| 34 |
+
Tuple of (output_dir, status_message)
|
| 35 |
"""
|
| 36 |
try:
|
| 37 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 38 |
|
| 39 |
+
api = HfApi()
|
| 40 |
+
token = hf_token or None
|
| 41 |
|
| 42 |
+
logger.info(f"Listing files in '{dataset_id}'...")
|
| 43 |
+
|
| 44 |
+
all_files = [
|
| 45 |
+
f.rfilename
|
| 46 |
+
for f in api.list_repo_tree(
|
| 47 |
+
dataset_id, repo_type="dataset", token=token, recursive=True
|
| 48 |
+
)
|
| 49 |
+
if hasattr(f, "rfilename")
|
| 50 |
+
and Path(f.rfilename).suffix.lower() in AUDIO_SUFFIXES
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
total_available = len(all_files)
|
| 54 |
+
selected = all_files[:max_files]
|
| 55 |
+
|
| 56 |
+
if not selected:
|
| 57 |
+
return "", f"No audio files found in {dataset_id}"
|
| 58 |
|
| 59 |
+
logger.info(
|
| 60 |
+
f"Downloading {len(selected)}/{total_available} audio files..."
|
|
|
|
|
|
|
| 61 |
)
|
| 62 |
|
| 63 |
+
output_dir = Path("lora_training") / "hf" / dataset_id.replace("/", "_")
|
| 64 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 65 |
+
|
| 66 |
+
for i, filename in enumerate(selected):
|
| 67 |
+
logger.info(f"  [{i + 1}/{len(selected)}] {filename}")
|
| 68 |
+
cached_path = hf_hub_download(
|
| 69 |
+
repo_id=dataset_id,
|
| 70 |
+
filename=filename,
|
| 71 |
+
repo_type="dataset",
|
| 72 |
+
token=token,
|
| 73 |
+
)
|
| 74 |
+
# Symlink from cache into our working dir so scan_directory finds them
|
| 75 |
+
dest = output_dir / Path(filename).name
|
| 76 |
+
if not dest.exists():
|
| 77 |
+
dest.symlink_to(cached_path)
|
| 78 |
+
|
| 79 |
+
status = (
|
| 80 |
+
f"Downloaded {len(selected)} of {total_available} "
|
| 81 |
+
f"audio files from {dataset_id}"
|
| 82 |
+
)
|
| 83 |
logger.info(status)
|
| 84 |
+
return str(output_dir), status
|
| 85 |
|
| 86 |
except ImportError:
|
| 87 |
msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"
|