pedroapfilho's picture
Use HF dataset repo as source of truth for dataset.json
6c32e21 unverified
from typing import List, Tuple
from .models import AudioSample
class LabelAllMixin:
    """Mixin that labels the samples held by the host object in one pass.

    The host class must provide ``self.samples`` (a sequence of sample
    objects exposing ``labeled``, ``caption`` and ``filename``) and a
    ``label_sample`` method that returns a 2-tuple whose second element is
    a status string containing "✅" on success.
    """

    def label_all_samples(
        self,
        dit_handler,
        llm_handler,
        format_lyrics: bool = False,
        transcribe_lyrics: bool = False,
        skip_metas: bool = False,
        only_unlabeled: bool = False,
        max_count: int = 0,
        progress_callback=None,
    ) -> "Tuple[List[AudioSample], str]":
        """Label the samples in the dataset, optionally in a limited batch.

        Args:
            dit_handler: Forwarded unchanged to ``label_sample``.
            llm_handler: Forwarded unchanged to ``label_sample``.
            format_lyrics: Forwarded unchanged to ``label_sample``.
            transcribe_lyrics: Forwarded unchanged to ``label_sample``.
            skip_metas: Forwarded unchanged to ``label_sample``.
            only_unlabeled: When true, only process samples that are not
                flagged ``labeled`` or that have an empty ``caption``.
            max_count: When > 0, stop after labeling this many samples
                (batch mode). ``0``, ``None`` or a negative value means
                "no limit"; whole-number floats are accepted.
            progress_callback: Optional callable receiving one progress
                string per sample; it is also forwarded to ``label_sample``.

        Returns:
            ``(samples, status_message)``. ``samples`` is ``self.samples``,
            or an empty list when there is nothing to scan. The status
            message starts with "✅" unless there were no samples at all or
            every attempted sample failed, in which case it starts with "❌".
        """
        if not self.samples:
            return [], "❌ No samples to label. Please scan a directory first."

        # Keep the original index: label_sample addresses samples by position.
        pending = [
            (index, sample)
            for index, sample in enumerate(self.samples)
            if not only_unlabeled or not sample.labeled or not sample.caption
        ]
        if not pending:
            return self.samples, "✅ All samples already labeled"

        # Slice bounds must be ints, so normalise None / float inputs first.
        limit = int(max_count or 0)
        if limit > 0:
            pending = pending[:limit]

        total = len(pending)
        success_count = 0
        fail_count = 0
        for position, (index, sample) in enumerate(pending, start=1):
            if progress_callback:
                progress_callback(f"Labeling {position}/{total}: {sample.filename}")

            _, status = self.label_sample(
                index,
                dit_handler,
                llm_handler,
                format_lyrics,
                transcribe_lyrics,
                skip_metas,
                progress_callback,
            )

            # label_sample signals success through the check-mark in its status.
            if "✅" in status:
                success_count += 1
            else:
                fail_count += 1

        total_samples = len(self.samples)
        total_labeled = sum(1 for sample in self.samples if sample.labeled)
        remaining = total_samples - total_labeled

        # Do not report a green check when the whole batch failed.
        icon = "❌" if fail_count > 0 and success_count == 0 else "✅"
        status_msg = f"{icon} Labeled {success_count} this batch"
        if fail_count > 0:
            status_msg += f" ({fail_count} failed)"
        status_msg += (
            f" | {total_labeled}/{total_samples} labeled total, {remaining} remaining"
        )
        return self.samples, status_msg