VeuReu committed on
Commit 18e066a · verified · 1 Parent(s): 5782742

Update app.py

Files changed (1):
  1. app.py +227 -46
app.py CHANGED
@@ -425,88 +425,269 @@ def extract_audio_ffmpeg(video_file, sr: int = 16000, mono: bool = True):
     return convertir_a_temporal(audio_out+".mp3")


-# =================
-# Demo UI
-# =================
-with gr.Blocks(title="Aina faster-whisper (Català) · ZeroGPU") as demo:
-    gr.Markdown("## Aina faster-whisper (Català) · ZeroGPU\nReconocimiento de voz en catalán finetune projecte-aina.")
     with gr.Row():
         with gr.Column():
-            inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio (WAV/MP3/MP4, etc.)")
-            lang = gr.Textbox(label="language", value="ca")
-            ts = gr.Checkbox(label="timestamps", value=True)
-            vad = gr.Checkbox(label="VAD filter", value=True)
-            btn = gr.Button("Transcribir (ENGINE /predict)", variant="primary")
         with gr.Column():
-            out = gr.JSON(label="Salida /predict")

     btn.click(predict_for_engine, [inp, lang, ts, vad], out, api_name="predict", concurrency_limit=1)

-    # Advanced section
-    gr.Markdown("---\n### Avanzado (/transcribe)")
     with gr.Row():
         with gr.Column():
-            inp2 = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio")
-            lang2 = gr.Textbox(label="language", value="ca")
-            task2 = gr.Dropdown(["transcribe", "translate"], value="transcribe", label="task")
-            vad2 = gr.Checkbox(label="VAD filter", value=True)
-            beam2 = gr.Slider(1, 10, value=5, step=1, label="beam_size")
-            temp2 = gr.Slider(0.0, 1.5, value=0.0, step=0.1, label="temperature")
-            wts2 = gr.Checkbox(label="word_timestamps", value=False)
-            btn2 = gr.Button("Transcribir (avanzado)")
         with gr.Column():
-            out2 = gr.JSON(label="Salida /transcribe")
-
-    btn2.click(transcribe_advanced, [inp2, lang2, task2, vad2, beam2, temp2, wts2], out2, api_name="transcribe", concurrency_limit=1)
-
-    # Diarization

-    gr.Markdown('<h2 style="text-align:center">Diarització del vídeo</h2>')
     with gr.Row():
         audio_input = gr.Audio(label="Àudio per diaritzar", type="filepath")
         process_btn = gr.Button("Diaritzar àudio", variant="primary")
-        clips_output = gr.File(label="Clips d'àudio generats", file_types=[".wav"], file_count="multiple")
         diarization_output = gr.JSON(label="Resultat de la diarització")
-    process_btn.click(diarize_audio, inputs=[audio_input], outputs=[clips_output,diarization_output], api_name="diaritzar_audio", concurrency_limit=1)

-    # Voice embeddings

     gr.Markdown('<h2 style="text-align:center">Obtenir l\'embedding d\'un àudio</h2>')
     with gr.Row():
-        audio_input = gr.Audio(label="Àudio per obtenir l'embedding", type="filepath")
         process_btn = gr.Button("Obtenir embedding", variant="primary")
         clip_out = gr.JSON(label="Embedding de veu (vector)")
-    process_btn.click(voice_embedder, [audio_input], clip_out, api_name="voice_embedding", concurrency_limit=1)
-    gr.Markdown("---")

-    # Speaker identification

     gr.Markdown('<h2 style="text-align:center">Identificació de parlants</h2>')
     with gr.Row():
-        audio_input = gr.Audio(label="Àudio per obtenir l'parlant", type="filepath")
         voice_col_input = gr.Textbox(
-            label="Llistat de diccionaris voice_col (format JSON)",
-            placeholder='[{"nombre": "Anna", "embedding": [0.12, 0.88, ...]}, ...]',
             lines=5
         )
         process_btn = gr.Button("Processar àudio (Persones)", variant="primary")
         output_json = gr.JSON(label="Resultat complet")

-    process_btn.click(identify_speaker, inputs=[audio_input, voice_col_input], outputs=output_json, api_name="identificar_veu", concurrency_limit=1)

     with gr.Row():
-        gr.Markdown("## Extract Audio from Video using FFmpeg")
-        # Input component: user uploads a video file
-        video_input = gr.Video(label="Upload a video")
-        # Output component: returns a WAV file path
-        audio_output = gr.Audio(label="Extracted audio (WAV)", type="filepath")
-        # Button to trigger extraction
-        extract_btn = gr.Button("Extract audio")
-    # Link button click to processing function
     extract_btn.click(
         fn=extract_audio_ffmpeg,
         inputs=video_input,
         outputs=audio_output
     )

 demo.queue(max_size=8).launch()
 
     return convertir_a_temporal(audio_out+".mp3")


+import torch
+import torchaudio
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+def load_audio(path, target_sr=16000):
+    # Load an audio file, resample it to target_sr, and return a mono 1-D array
+    waveform, sr = torchaudio.load(path)
+    if sr != target_sr:
+        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
+    if waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+    return waveform.squeeze(0).numpy()
+
+def transcribe_wav(wav_path: str) -> str:
+    model_name = "projecte-aina/whisper-large-v3-ca-3catparla"
+    # Prefer CUDA, fall back to CPU when no GPU is available
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    processor = WhisperProcessor.from_pretrained(model_name)
+    model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
+    # Load the WAV file
+    waveform, sr = torchaudio.load(wav_path)
+
+    target_sr = 16000
+    if sr != target_sr:
+        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
+        sr = target_sr
+    # Downmix to mono: the processor expects a 1-D array
+    if waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+    # Preprocess the audio
+    inputs = processor(
+        waveform.squeeze(0).numpy(), sampling_rate=sr, return_tensors="pt"
+    ).input_features.to(model.device)
+
+    # Generate the transcription with the model
+    with torch.no_grad():
+        ids = model.generate(inputs, max_new_tokens=440)[0]
+
+    # Decode the transcription, dropping special tokens
+    txt = processor.decode(ids, skip_special_tokens=True)
+
+    # Normalize the text when the tokenizer provides a normalizer
+    norm = getattr(processor.tokenizer, "_normalize", None)
+    return norm(txt) if callable(norm) else txt
+
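+# Illustrative usage (the path is a placeholder): transcribe_wav("mostra.wav")
+# returns the transcript as plain text. When the tokenizer exposes a
+# _normalize hook, it typically lowercases the text and strips punctuation;
+# otherwise the raw decoded text is returned unchanged.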
+def transcribe_long_audio(
+    wav_path: str,
+    chunk_length_s: int = 20,
+    overlap_s: int = 2,
+) -> str:
+    model_name = "projecte-aina/whisper-large-v3-ca-3catparla"
+    # Prefer CUDA, fall back to CPU when no GPU is available
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    processor = WhisperProcessor.from_pretrained(model_name)
+    model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device)
+    # Load the full WAV file
+    waveform, sr = torchaudio.load(wav_path)
+    target_sr = 16000
+    if sr != target_sr:
+        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
+        sr = target_sr
+    # Downmix to mono: the processor expects a 1-D array per example
+    if waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+    total_samples = waveform.shape[1]
+
+    # Chunk size and overlap expressed in samples
+    chunk_size = chunk_length_s * sr
+    overlap_size = overlap_s * sr
+
+    transcriptions = []
+    start = 0
+
+    while start < total_samples:
+        end = min(start + chunk_size, total_samples)
+        chunk = waveform[:, start:end]  # each chunk is transcribed like a short clip
+
+        input_features = processor(
+            chunk.squeeze(0).numpy(),
+            sampling_rate=sr,
+            return_tensors="pt"
+        ).input_features.to(model.device)
+
+        with torch.no_grad():
+            predicted_ids = model.generate(
+                input_features,
+                max_new_tokens=440,
+                num_beams=1,
+            )[0]
+
+        text = processor.decode(predicted_ids, skip_special_tokens=True)
+        transcriptions.append(text.strip())
+
+        # Advance with overlap so words at chunk boundaries are not cut
+        start += chunk_size - overlap_size
+
+    return " ".join(transcriptions).strip()
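+# Stride sanity check with the defaults (illustrative numbers): at sr=16000,
+# chunk_size = 20 * 16000 = 320000 samples and the stride is 320000 - 32000 =
+# 288000 samples, so a 50 s file is covered by windows starting at 0 s, 18 s
+# and 36 s, each sharing 2 s of audio with its predecessor.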
+
+
+"""
+# ==============================================================================
+# UI & Endpoints
+# ==============================================================================
+Collection of Gradio interface elements and API endpoints used by the application.
+
+This section defines the user-facing interface for the Catalan speech
+pipeline, allowing users to transcribe audio, diarize recordings, extract
+voice embeddings, identify speakers, and pull the audio track out of a video.
+
+The components and endpoints in this section typically:
+- Accept audio or video files from the user
+- Apply optional parameters such as language, beam size, temperature, or VAD
+- Preprocess inputs and invoke internal inference or utility functions
+- Return structured outputs, including transcripts, JSON metadata,
+  or generated audio clips
+
+All endpoints are designed to be stateless, safe for concurrent calls,
+and compatible with both interactive UI usage and programmatic API access.
+# ==============================================================================
+"""
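+# A minimal sketch of programmatic access, assuming the gradio_client package;
+# the Space id and audio path below are placeholders, not defined in this file:
+#
+#   from gradio_client import Client, handle_file
+#   client = Client("user/space-id")
+#   result = client.predict(handle_file("mostra.wav"), "ca", True, True,
+#                           api_name="/predict")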
+custom_css = """
+h2 {
+    background: #e3e4e6 !important;
+    padding: 14px 22px !important;
+    border-radius: 14px !important;
+    box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
+    display: block !important;   /* take up the full width */
+    width: 100% !important;      /* ensure 100% */
+    margin: 20px auto !important;
+    text-align: center;
+}
+"""
+with gr.Blocks(title="Aina faster-whisper (Català) · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:
+    # Header
+    gr.Markdown("## Aina faster-whisper (Català) · ZeroGPU\nReconeixement de veu en català finetune projecte-aina.")
+
+    # Main transcription section
     with gr.Row():
         with gr.Column():
+            inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Àudio (WAV/MP3/MP4, etc.)")
+            lang = gr.Textbox(label="Idioma", value="ca")
+            ts = gr.Checkbox(label="Marques de temps", value=True)
+            vad = gr.Checkbox(label="Filtre VAD", value=True)
+            btn = gr.Button("Transcriure (ENGINE /predict)", variant="primary")
         with gr.Column():
+            out = gr.JSON(label="Sortida /predict")

+    # Button callback
     btn.click(predict_for_engine, [inp, lang, ts, vad], out, api_name="predict", concurrency_limit=1)

+    # Advanced transcription section
+    gr.Markdown("---\n### Avançat (/transcribe)")
     with gr.Row():
         with gr.Column():
+            inp2 = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Àudio")
+            lang2 = gr.Textbox(label="Idioma", value="ca")
+            task2 = gr.Dropdown(["transcribe", "translate"], value="transcribe", label="Tasca")
+            vad2 = gr.Checkbox(label="Filtre VAD", value=True)
+            beam2 = gr.Slider(1, 10, value=5, step=1, label="Mida del feix")
+            temp2 = gr.Slider(0.0, 1.5, value=0.0, step=0.1, label="Temperatura")
+            wts2 = gr.Checkbox(label="Marques de temps per paraula", value=False)
+            btn2 = gr.Button("Transcriure (avançat)")
         with gr.Column():
+            out2 = gr.JSON(label="Sortida /transcribe")
+
+    # Advanced button callback
+    btn2.click(
+        transcribe_advanced,
+        [inp2, lang2, task2, vad2, beam2, temp2, wts2],
+        out2,
+        api_name="transcribe",
+        concurrency_limit=1
+    )

+    # Diarization section
+    gr.Markdown('<h2 style="text-align:center">Diarització de l\'àudio</h2>')
     with gr.Row():
         audio_input = gr.Audio(label="Àudio per diaritzar", type="filepath")
         process_btn = gr.Button("Diaritzar àudio", variant="primary")
+        clips_output = gr.File(label="Clips d'àudio generats", file_types=[".wav"], file_count="multiple")
         diarization_output = gr.JSON(label="Resultat de la diarització")

+    process_btn.click(
+        diarize_audio,
+        inputs=[audio_input],
+        outputs=[clips_output, diarization_output],
+        api_name="diaritzar_audio",
+        concurrency_limit=1
+    )

+    # Voice embeddings section
     gr.Markdown('<h2 style="text-align:center">Obtenir l\'embedding d\'un àudio</h2>')
     with gr.Row():
+        audio_input = gr.Audio(label="Àudio per obtenir l'embedding", type="filepath")
         process_btn = gr.Button("Obtenir embedding", variant="primary")
         clip_out = gr.JSON(label="Embedding de veu (vector)")

+    process_btn.click(
+        voice_embedder,
+        [audio_input],
+        clip_out,
+        api_name="voice_embedding",
+        concurrency_limit=1
+    )
+
+    gr.Markdown("---")
 
+    # Speaker identification
     gr.Markdown('<h2 style="text-align:center">Identificació de parlants</h2>')
     with gr.Row():
+        audio_input = gr.Audio(label="Àudio per identificar el parlant", type="filepath")
         voice_col_input = gr.Textbox(
+            label="Llista de diccionaris voice_col (format JSON)",
+            placeholder='[{"nom": "Anna", "embedding": [0.12, 0.88, ...]}, ...]',
             lines=5
         )
         process_btn = gr.Button("Processar àudio (Persones)", variant="primary")
         output_json = gr.JSON(label="Resultat complet")

+    process_btn.click(
+        identify_speaker,
+        inputs=[audio_input, voice_col_input],
+        outputs=output_json,
+        api_name="identificar_veu",
+        concurrency_limit=1
+    )
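+    # A typical flow (illustrative): call /voice_embedding once per known
+    # speaker to build the voice_col list, e.g.
+    #   [{"nom": "Anna", "embedding": [...]}, {"nom": "Pere", "embedding": [...]}]
+    # then pass new audio plus that JSON string to /identificar_veu so the new
+    # voice can be matched against the stored embeddings.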
 
+    # Extract audio from video
     with gr.Row():
+        gr.Markdown('<h2 style="text-align:center">Extreure àudio d\'un vídeo (FFmpeg)</h2>')
+        video_input = gr.Video(label="Puja un vídeo")
+        audio_output = gr.Audio(label="Àudio extret (MP3)", type="filepath")
+        extract_btn = gr.Button("Extreure àudio", variant="primary")
+
     extract_btn.click(
         fn=extract_audio_ffmpeg,
         inputs=video_input,
         outputs=audio_output
     )
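+    # Roughly equivalent ffmpeg invocation (illustrative; the actual flags live
+    # in extract_audio_ffmpeg above):
+    #   ffmpeg -i video.mp4 -vn -ac 1 -ar 16000 sortida.mp3
+    # -vn drops the video stream, -ac 1 downmixes to mono, -ar 16000 resamples
+    # to 16 kHz, matching the function's sr and mono defaults.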
+
+    # Short audio transcription
+    gr.Markdown('<h2 style="text-align:center">Àudio curt → text</h2>')
+    with gr.Row():
+        audio_input = gr.Audio(type="filepath", label="Puja el teu àudio")
+        output_text = gr.Textbox(label="Text transcrit")
+        boton = gr.Button("Transcriure", variant="primary")
+
+    boton.click(
+        fn=transcribe_wav,
+        inputs=audio_input,
+        outputs=output_text
+    )
+
+    # Long audio transcription
+    gr.Markdown('<h2 style="text-align:center">Àudio llarg → text</h2>')
+    with gr.Row():
+        audio_input = gr.Audio(type="filepath", label="Puja el teu àudio")
+        output_text = gr.Textbox(label="Text transcrit")
+        boton = gr.Button("Transcriure", variant="primary")
+
+    boton.click(
+        fn=transcribe_long_audio,
+        inputs=audio_input,
+        outputs=output_text
+    )
+
 demo.queue(max_size=8).launch()