Spaces:
Running
Joseph Pollack
committed on
adds additional datasets for phrases
Browse files
- interface.py +186 -62
interface.py CHANGED
@@ -251,8 +251,13 @@ def start_voxtral_training(
         yield line
 
 
-def load_voxpopuli_phrases(language="en", max_phrases=None, split="train"):
-    """Load phrases from VoxPopuli dataset.
+def load_multilingual_phrases(language="en", max_phrases=None, split="train"):
+    """Load phrases from various multilingual speech datasets.
+
+    Tries multiple datasets in order of preference:
+    1. Common Voice (most reliable and up-to-date)
+    2. FLEURS (Google's multilingual dataset)
+    3. Fallback to basic phrases
 
     Args:
         language: Language code (e.g., 'en', 'de', 'fr', etc.)
@@ -262,43 +267,121 @@ def load_voxpopuli_phrases(language="en", max_phrases=None, split="train"):
     Returns:
         List of normalized text phrases
     """
+    from datasets import load_dataset
+    import random
+
+    # Language code mapping for different datasets
+    lang_mappings = {
+        "en": {"common_voice": "en", "fleurs": "en_us"},
+        "de": {"common_voice": "de", "fleurs": "de_de"},
+        "fr": {"common_voice": "fr", "fleurs": "fr_fr"},
+        "es": {"common_voice": "es", "fleurs": "es_419"},
+        "it": {"common_voice": "it", "fleurs": "it_it"},
+        "pt": {"common_voice": "pt", "fleurs": "pt_br"},
+        "pl": {"common_voice": "pl", "fleurs": "pl_pl"},
+        "nl": {"common_voice": "nl", "fleurs": "nl_nl"},
+        "ru": {"common_voice": "ru", "fleurs": "ru_ru"},
+        "ar": {"common_voice": "ar", "fleurs": "ar_eg"},
+        "zh": {"common_voice": "zh-CN", "fleurs": "zh_cn"},
+        "ja": {"common_voice": "ja", "fleurs": "ja_jp"},
+        "ko": {"common_voice": "ko", "fleurs": "ko_kr"},
+    }
+
+    lang_config = lang_mappings.get(language, {"common_voice": language, "fleurs": f"{language}_{language}"})
+
+    # Try Common Voice first (most reliable)
     try:
-        from datasets import load_dataset
-        import random
-
-        # Load the specified language dataset
-        ds = load_dataset("facebook/voxpopuli", language, split=split)
+        print(f"Trying Common Voice dataset for language: {language}")
+        cv_lang = lang_config["common_voice"]
+        ds = load_dataset("mozilla-foundation/common_voice_11_0", cv_lang, split=split, streaming=True)
 
-        # Extract normalized text phrases
         phrases = []
+        count = 0
         for example in ds:
-            text = example.get("normalized_text", "").strip()
+            if max_phrases and count >= max_phrases:
+                break
+            text = example.get("sentence", "").strip()
             if text and len(text) > 10:  # Filter out very short phrases
                 phrases.append(text)
+                count += 1
 
-
-        if max_phrases:
-            phrases = random.sample(phrases, min(max_phrases, len(phrases)))
-        else:
-            # If no limit, shuffle the entire list
+        if phrases:
+            print(f"Successfully loaded {len(phrases)} phrases from Common Voice")
             random.shuffle(phrases)
+            return phrases
+
+    except Exception as e:
+        print(f"Common Voice failed: {e}")
+
+    # Try FLEURS as backup
+    try:
+        print(f"Trying FLEURS dataset for language: {language}")
+        fleurs_lang = lang_config["fleurs"]
+        ds = load_dataset("google/fleurs", fleurs_lang, split=split, streaming=True)
 
-        return phrases
+        phrases = []
+        count = 0
+        for example in ds:
+            if max_phrases and count >= max_phrases:
+                break
+            text = example.get("transcription", "").strip()
+            if text and len(text) > 10:  # Filter out very short phrases
+                phrases.append(text)
+                count += 1
+
+        if phrases:
+            print(f"Successfully loaded {len(phrases)} phrases from FLEURS")
+            random.shuffle(phrases)
+            return phrases
 
     except Exception as e:
-        print(f"
-
-
-
-
-
-
-
-
+        print(f"FLEURS failed: {e}")
+
+    # Final fallback to basic phrases
+    print("All dataset loading attempts failed, using fallback phrases")
+    fallback_phrases = [
+        "The quick brown fox jumps over the lazy dog.",
+        "Please say your full name.",
+        "Today is a good day to learn something new.",
+        "Artificial intelligence helps with many tasks.",
+        "I enjoy reading books and listening to music.",
+        "This is a sample sentence for testing speech.",
+        "Speak clearly and at a normal pace.",
+        "Numbers like one, two, three are easy to say.",
+        "The weather is sunny with a chance of rain.",
+        "Thank you for taking the time to help.",
+        "Hello, how are you today?",
+        "I would like to order a pizza.",
+        "The meeting is scheduled for tomorrow.",
+        "Please call me back as soon as possible.",
+        "Thank you for your assistance.",
+        "Can you help me with this problem?",
+        "I need to make a reservation.",
+        "The weather looks beautiful outside.",
+        "Let's go for a walk in the park.",
+        "I enjoy listening to classical music.",
+        "What time does the store open?",
+        "I forgot my password again.",
+        "Please send me the invoice.",
+        "The project is almost complete.",
+        "I appreciate your hard work.",
+        "Let's schedule a meeting next week.",
+        "The food tastes delicious.",
+        "I need to buy some groceries.",
+        "Please turn off the lights.",
+        "The presentation went very well.",
+    ]
+
+    if max_phrases:
+        fallback_phrases = random.sample(fallback_phrases, min(max_phrases, len(fallback_phrases)))
+    else:
+        random.shuffle(fallback_phrases)
+
+    return fallback_phrases
 
 # Initialize phrases dynamically
-
-ALL_PHRASES = load_voxpopuli_phrases("en", max_phrases=None)
+DEFAULT_LANGUAGE = "en"  # Default to English
+ALL_PHRASES = load_multilingual_phrases(DEFAULT_LANGUAGE, max_phrases=None)
 
 with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
     has_gpu, gpu_msg = detect_nvidia_driver()
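The loader above can be smoke-tested outside the UI. A minimal sketch, assuming the `datasets` library is installed, the Hub is reachable, and your account has accepted the Common Voice terms (the dataset is gated); the helper name here is illustrative, not part of the commit:

from datasets import load_dataset

def sample_common_voice(language: str = "de", limit: int = 5) -> list[str]:
    # Stream instead of downloading the full split, mirroring the
    # streaming=True + counter pattern in load_multilingual_phrases.
    ds = load_dataset(
        "mozilla-foundation/common_voice_11_0", language,
        split="train", streaming=True,
    )
    phrases: list[str] = []
    for example in ds:
        text = example.get("sentence", "").strip()
        if text and len(text) > 10:  # same short-phrase filter as the app
            phrases.append(text)
        if len(phrases) >= limit:
            break
    return phrases

if __name__ == "__main__":
    print(sample_common_voice())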
@@ -337,12 +420,15 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
 
             jsonl_out = gr.Textbox(label="Dataset JSONL path", interactive=False, visible=True)
 
-            # Language selection for VoxPopuli phrases
-
-                choices=[
+            # Language selection for multilingual phrases
+            language_selector = gr.Dropdown(
+                choices=[
+                    "en", "de", "fr", "es", "it", "pt", "pl", "nl", "ru",
+                    "ar", "zh", "ja", "ko", "tr", "ca", "sv", "fi", "da"
+                ],
                 value="en",
-                label="
-                info="Select language for phrases from
+                label="Language for Speech Phrases",
+                info="Select language for phrases from Common Voice, FLEURS, or fallback datasets"
             )
 
             # Recording grid with dynamic text readouts
@@ -383,8 +469,8 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
             return [new_visible] + visibility_updates
 
         def change_language(language):
-            """Change the language and reload phrases from VoxPopuli"""
-            new_phrases = load_voxpopuli_phrases(language, max_phrases=None)
+            """Change the language and reload phrases from multilingual datasets"""
+            new_phrases = load_multilingual_phrases(language, max_phrases=None)
             # Reset visible rows to 10
             visible_count = min(10, len(new_phrases))
 
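The handler's return shape matters: Gradio maps the returned list positionally onto the `outputs` declared below. A self-contained sketch of that contract, with stand-in data in place of the real loader:

import gradio as gr

def change_language_sketch(language: str, n_slots: int = 10):
    # Stand-in for load_multilingual_phrases(language); the app gets
    # real phrases from the datasets tried above.
    new_phrases = [f"[{language}] phrase {i + 1}" for i in range(25)]
    visible_count = min(10, len(new_phrases))
    updates = []
    for i in range(n_slots):
        text = new_phrases[i] if i < len(new_phrases) else ""
        updates.append(gr.update(value=f"**{i + 1}. {text}**", visible=i < visible_count))
    # state values first, then one gr.update per dynamic component
    return [new_phrases, visible_count] + updates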
@@ -407,9 +493,9 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
             return [new_phrases, visible_count] + combined_updates
 
         # Connect language change to phrase reloading
-
+        language_selector.change(
             change_language,
-            inputs=[
+            inputs=[language_selector],
             outputs=[phrase_texts_state, visible_rows_state] + phrase_markdowns + rec_components
         )
 
@@ -482,40 +568,78 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
 
         save_rec_btn.click(_collect_preloaded_recs, rec_components + [phrase_texts_state], [jsonl_out])
 
-        # Quick sample from VoxPopuli
+        # Quick sample from multilingual datasets (Common Voice, etc.)
         with gr.Row():
-            vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "
+            vp_lang = gr.Dropdown(choices=["en", "de", "fr", "es", "it", "pl", "pt", "nl", "ru", "ar", "zh", "ja", "ko"], value="en", label="Sample Language")
             vp_samples = gr.Number(value=20, precision=0, label="Num samples")
             vp_split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
-            vp_btn = gr.Button("Use VoxPopuli Sample")
-
-        def _collect_voxpopuli_sample(lang_code: str, num_samples: int, split: str):
-
-
-            if "__main__" not in sys.modules:
-                sys.modules["__main__"] = sys.modules[__name__]
-            from datasets import load_dataset, Audio  # type: ignore
+            vp_btn = gr.Button("Use Multilingual Dataset Sample")
+
+        def _collect_multilingual_sample(lang_code: str, num_samples: int, split: str):
+            """Load sample from multilingual datasets (Common Voice preferred)"""
+            from datasets import load_dataset, Audio
             import random
-
-
-
-
-
-
-
+
+            # Language code mapping for Common Voice
+            cv_lang_map = {
+                "en": "en", "de": "de", "fr": "fr", "es": "es", "it": "it",
+                "pl": "pl", "pt": "pt", "nl": "nl", "ru": "ru", "ar": "ar",
+                "zh": "zh-CN", "ja": "ja", "ko": "ko"
+            }
+
+            cv_lang = cv_lang_map.get(lang_code, lang_code)
+
+            try:
+                # Try Common Voice first
+                ds = load_dataset("mozilla-foundation/common_voice_11_0", cv_lang, split=split, streaming=True)
+                ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+                dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
+                rows: list[dict] = []
+                texts: list[str] = []
+
+                count = 0
+                for ex in ds:
+                    if count >= num_samples:
+                        break
+
+                    audio = ex.get("audio") or {}
+                    path = audio.get("path")
+                    text = ex.get("sentence", "").strip()
+
+                    if path and text and len(text) > 10:
+                        rows.append({"audio_path": path, "text": text})
+                        texts.append(str(text))
+                        count += 1
+
+                if rows:
+                    jsonl_path = dataset_dir / "data.jsonl"
+                    _write_jsonl(rows, jsonl_path)
+
+                    # Build markdown content updates for on-screen prompts
+                    combined_updates = []
+                    for i in range(len(phrase_markdowns)):
+                        t = texts[i] if i < len(texts) else ""
+                        if i < len(texts):
+                            combined_updates.append(gr.update(value=f"**{i+1}. {t}**", visible=True))
+                        else:
+                            combined_updates.append(gr.update(visible=False))
+
+                    return (str(jsonl_path), texts, *combined_updates)
+
+            except Exception as e:
+                print(f"Common Voice sample loading failed: {e}")
+
+            # Fallback: generate synthetic samples with text only
+            print("Using fallback: generating text-only samples")
+            phrases = load_multilingual_phrases(lang_code, max_phrases=num_samples)
+            texts = phrases[:num_samples]
 
             dataset_dir = PROJECT_ROOT / "datasets" / "voxtral_user"
-            rows: list[dict] = []
-            texts: list[str] = []
-            for ex in ds_sel:
-                audio = ex.get("audio") or {}
-                path = audio.get("path")
-                text = ex.get("normalized_text") or ex.get("raw_text") or ""
-                if path and text is not None:
-                    rows.append({"audio_path": path, "text": text})
-                    texts.append(str(text))
+            rows = [{"audio_path": "", "text": text} for text in texts]
             jsonl_path = dataset_dir / "data.jsonl"
             _write_jsonl(rows, jsonl_path)
+
             # Build markdown content updates for on-screen prompts
             combined_updates = []
             for i in range(len(phrase_markdowns)):
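`_write_jsonl` is project-internal and not shown in this diff; the rows built above imply one JSON object per line with `audio_path` and `text` keys. A stand-in writer under that assumption, with an illustrative name and path:

import json
from pathlib import Path

def write_jsonl(rows: list[dict], path: Path) -> None:
    # One {"audio_path": ..., "text": ...} object per line.
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")

write_jsonl(
    [{"audio_path": "/tmp/clip_0.wav", "text": "Hello, how are you today?"}],
    Path("datasets/voxtral_user/data.jsonl"),
)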
@@ -528,7 +652,7 @@ with gr.Blocks(title="Voxtral ASR Fine-tuning") as demo:
             return (str(jsonl_path), texts, *combined_updates)
 
         vp_btn.click(
-            _collect_voxpopuli_sample,
+            _collect_multilingual_sample,
            [vp_lang, vp_samples, vp_split],
            [jsonl_out, phrase_texts_state] + phrase_markdowns,
        )
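For reference, the `Button.click(fn, inputs, outputs)` wiring used above, reduced to a runnable sketch; the components and the echo-only handler are placeholders, not the app's real ones:

import gradio as gr

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "de", "fr"], value="en", label="Sample Language")
    num = gr.Number(value=20, precision=0, label="Num samples")
    split = gr.Dropdown(choices=["train", "validation", "test"], value="train", label="Split")
    out = gr.Textbox(label="Dataset JSONL path")

    def collect(lang_code, num_samples, split_name):
        # The real handler streams Common Voice and writes data.jsonl;
        # this stub just echoes the request so the sketch runs standalone.
        return f"would sample {int(num_samples)} '{split_name}' rows for '{lang_code}'"

    btn = gr.Button("Use Multilingual Dataset Sample")
    btn.click(collect, [lang, num, split], [out])

if __name__ == "__main__":
    demo.launch()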