# app.py — veureu/asr (Aina faster-whisper Catalan · ZeroGPU) — compatible with ENGINE
from __future__ import annotations
import os, json, tempfile
from typing import Dict, Any, List, Tuple, Optional

import gradio as gr
import spaces
import torch

# faster-whisper (CTranslate2)
from faster_whisper import WhisperModel

# =========================
# Config and lazy loading
# =========================
# By default we use the Catalan finetune from projecte-aina on HF.
# Change MODEL_ID to the exact repo you are using (e.g.: "projecte-aina/faster-whisper-large-v3-ca-3catparla")
MODEL_ID = os.environ.get("MODEL_ID", "projecte-aina/faster-whisper-large-v3-ca-3catparla")

# Detect if there is a GPU (ZeroGPU) -> fp16, otherwise INT8
HAS_CUDA = os.environ.get("CUDA_VISIBLE_DEVICES") not in (None, "", "-1")
DEVICE = "cuda" if HAS_CUDA else "cpu"
COMPUTE_TYPE = "float16" if HAS_CUDA else "int8"  # "int8_float16" also works on low-end GPUs

_model: Optional[WhisperModel] = None

def _lazy_model() -> WhisperModel:
    global _model
    if _model is None:
        _model = WhisperModel(
            MODEL_ID,
            device=DEVICE,
            compute_type=COMPUTE_TYPE,
            download_root=os.environ.get("HF_HOME") or None,  # optional
        )
    return _model

_model_whis = None
_processor_whis = None

def _lazy_load_whisper():
    """
    Lazily load the HF Transformers Whisper model on Hugging Face Spaces
    (compatible with stateless ZeroGPU workers): avoid initialising CUDA in the
    main process. WhisperProcessor and WhisperForConditionalGeneration are
    imported further down this module, but those module-level imports have run
    by the time this function is first called.
    """
    global _model_whis, _processor_whis
    if _model_whis is None or _processor_whis is None:
        model_name = "projecte-aina/whisper-large-v3-ca-3catparla"
        
        # processor
        _processor_whis = WhisperProcessor.from_pretrained(model_name)

        # model
        m = WhisperForConditionalGeneration.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )

        m = m.to(DEVICE)

        _model_whis = m

    return _processor_whis, _model_whis

# ==================================
# Transcription core (Catalan)
# ==================================
@spaces.GPU
def _transcribe_core(
    audio_path: str,
    language: str = "ca",
    task: str = "transcribe",
    vad_filter: bool = True,
    beam_size: int = 5,
    temperature: float = 0.0,
    word_timestamps: bool = False,
) -> Dict[str, Any]:
    """
    Returns:
      {
        "text": "transcription…",
        "segments": [
            {"start": 0.10, "end": 1.92, "text": "…"},
            ...
        ],
        "language": "ca",
        "info": { "duration": ..., "device": "cuda/cpu", "compute_type": "float16/int8" }
      }
    """
    model = _lazy_model()

    # faster-whisper produces a generator of segments + info
    segments, info = model.transcribe(
        audio_path,
        language=language or "ca",
        task=task,
        vad_filter=vad_filter,
        beam_size=int(beam_size),
        temperature=float(temperature),
        word_timestamps=bool(word_timestamps),
    )

    segs: List[Dict[str, Any]] = []
    full_text_parts: List[str] = []
    for seg in segments:
        text = (seg.text or "").strip()
        full_text_parts.append(text)
        segs.append({
            "start": round(float(seg.start), 3) if seg.start is not None else None,
            "end": round(float(seg.end), 3) if seg.end is not None else None,
            "text": text,
        })

    out = {
        "text": " ".join([t for t in full_text_parts if t]),
        "segments": segs,
        "language": language or "ca",
        "info": {
            "duration": getattr(info, "duration", None),
            "device": DEVICE,
            "compute_type": COMPUTE_TYPE,
        },
    }
    return out

# ==========================
# Endpoints Gradio (API/UI)
# ==========================

# 1) /predict — the endpoint the ENGINE uses via gradio_client.
#    Minimal signature: only the audio; everything else uses defaults.
def predict_for_engine(
    audio_file,              # gr.Audio o gr.File
    language: str = "ca",
    timestamps: bool = True,
    vad_filter: bool = True,
) -> Dict[str, Any]:
    """
    The ENGINE normally calls: client.predict(<audio_path>, api_name="/predict")
    We return a dict with 'text' and 'segments'.
    """
    # Gradio may hand us a dict {'name', 'data'} or a plain file path
    path = None
    if isinstance(audio_file, dict) and audio_file.get("name"):
        path = audio_file["name"]
    elif isinstance(audio_file, str):
        path = audio_file
    elif hasattr(audio_file, "name"):
        path = audio_file.name

    if not path:
        return {"text": "", "segments": [], "language": language, "info": {"error": "no_audio"}}

    return _transcribe_core(
        path,
        language=language or "ca",
        task="transcribe",
        vad_filter=bool(vad_filter),
        beam_size=5,
        temperature=0.0,
        word_timestamps=bool(timestamps),
    )
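
# Illustrative only: a minimal sketch of how an external ENGINE process might call
# the /predict endpoint with gradio_client. The Space id "veureu/asr" is taken from
# the header comment above; adapt it to the actual deployment.
#
#     from gradio_client import Client, handle_file
#
#     client = Client("veureu/asr")
#     result = client.predict(
#         handle_file("sample.wav"),  # audio_file
#         "ca",                       # language
#         True,                       # timestamps
#         True,                       # vad_filter
#         api_name="/predict",
#     )
#     print(result["text"], len(result["segments"]))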

# 2) /transcribe — alternative endpoint with more controls (useful for manual/HTTP testing)
def transcribe_advanced(
    audio_file,
    language: str = "ca",
    task: str = "transcribe",         # "transcribe" | "translate"
    vad_filter: bool = True,
    beam_size: int = 5,
    temperature: float = 0.0,
    word_timestamps: bool = False,
) -> Dict[str, Any]:
    path = None
    if isinstance(audio_file, dict) and audio_file.get("name"):
        path = audio_file["name"]
    elif isinstance(audio_file, str):
        path = audio_file
    elif hasattr(audio_file, "name"):
        path = audio_file.name
    if not path:
        return {"text": "", "segments": [], "language": language, "info": {"error": "no_audio"}}

    return _transcribe_core(
        path,
        language=language or "ca",
        task=task or "transcribe",
        vad_filter=bool(vad_filter),
        beam_size=int(beam_size),
        temperature=float(temperature),
        word_timestamps=bool(word_timestamps),
    )

import math
from typing import Any, Dict, List, Tuple
from pydub import AudioSegment
from pyannote.audio import Pipeline
from io import BytesIO
import base64
import soundfile as sf

def diarize_audio(
    wav_file: str,
    min_segment_duration: float = 0.5,
    max_segment_duration: float = 50.0,
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """
    Audio diarization that:
    - Reads a WAV file
    - Returns clips in memory as dicts for Gradio (without saving files)
    - Returns the list of segments [{'start','end','speaker'}]
    """
    # Load audio and calculate duration
    audio = AudioSegment.from_wav(wav_file)
    duration = len(audio) / 1000.0

    # Diarization pipeline
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=os.getenv('HF_TOKEN')
    )
    diarization = pipeline(wav_file)

    clip_buffers: List[Tuple[str, BytesIO]] = []
    segments: List[Dict[str, Any]] = []
    spk_map: Dict[str, int] = {}
    prev_end = 0.0

    # Process each segment
    for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
        start, end = max(0.0, float(turn.start)), min(duration, float(turn.end))

        if start < prev_end: 
            start = prev_end

        if end <= start: 
            continue

        seg_dur = end - start

        if seg_dur < min_segment_duration: 
            continue

        # Split very long segments
        if seg_dur > max_segment_duration:
            n = int(math.ceil(seg_dur / max_segment_duration))
            sub_d = seg_dur / n
            for j in range(n):
                s = start + j * sub_d
                e = min(end, start + (j + 1) * sub_d)
                clip = audio[int(s*1000):int(e*1000)]
                print(f"Creating clip from {s} to {e} seconds")
                buf = BytesIO()
                clip.export(buf, format="wav")
                buf.seek(0)
                clip_buffers.append((f"segment_{i:03d}_{j:02d}.wav", buf))

                if speaker not in spk_map:
                    spk_map[speaker] = len(spk_map)
                segments.append({"start": s, "end": e, "speaker": f"SPEAKER_{spk_map[speaker]:02d}"})
                prev_end = e

        else:
            clip = audio[int(start*1000):int(end*1000)]
            buf = BytesIO()
            clip.export(buf, format="wav")
            buf.seek(0)
            clip_buffers.append((f"segment_{i:03d}.wav", buf))

            if speaker not in spk_map:
                spk_map[speaker] = len(spk_map)
            segments.append({"start": start, "end": end, "speaker": f"SPEAKER_{spk_map[speaker]:02d}"})
            prev_end = end

    # If no segments, use the entire audio
    if not segments:
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp_file.close()
        audio.export(tmp_file.name, format="wav")
        return [tmp_file.name], [{"start": 0.0, "end": duration, "speaker": "SPEAKER_00"}]

    # Convert all clips to dicts for Gradio
    print("Clip buffers:")
    print(clip_buffers)

    gr_clips = []
    for i, (name, buf) in enumerate(clip_buffers, start=1):
        buf.seek(0)
        # Create temporary file but with friendly name
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp_file.write(buf.read())
        tmp_file.close()
        
        # Rename to something like "clip1.wav", "clip2.wav", ...
        new_name = f"clip{i}.wav"
        new_path = os.path.join(tempfile.gettempdir(), new_name)
        os.rename(tmp_file.name, new_path)
        
        gr_clips.append(new_path)

    print("Gradio clips prepared.")
    print(gr_clips)
    return gr_clips, segments
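
# Illustrative only (values made up): diarize_audio returns a pair of
#   - clip file paths, e.g. ["/tmp/clip1.wav", "/tmp/clip2.wav"]
#   - segment metadata, e.g.
#       [{"start": 0.0, "end": 4.2, "speaker": "SPEAKER_00"},
#        {"start": 4.2, "end": 9.75, "speaker": "SPEAKER_01"}]
# which feed the gr.File and gr.JSON outputs of the "diaritzar_audio" endpoint below.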

import numpy as np
import torchaudio.transforms as T
from speechbrain.inference import SpeakerRecognition
from typing import List
import torchaudio
import torch

def voice_embedder(wav_file: str) -> List[float]:
    print("======================================================")
    model = SpeakerRecognition.from_hparams(
        source="pretrained_models/spkrec-ecapa-voxceleb",
        savedir="pretrained_models/spkrec-ecapa-voxceleb"
    )
    model.eval()
    print("======================================================")
    
    # Audio preprocessing
    waveform, sr = torchaudio.load(wav_file)
    target_sr = 16000

    # Resample if needed
    if sr != target_sr:
        waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Minimum duration of 0.2 seconds
    min_samples = int(0.2 * target_sr)
    if waveform.shape[1] < min_samples:
        pad = min_samples - waveform.shape[1]
        waveform = torch.cat([waveform, torch.zeros((1, pad))], dim=1)

    # Compute speaker embedding
    with torch.no_grad():
        emb = (
            model.encode_batch(waveform)
            .squeeze()
            .cpu()
            .numpy()
            .astype(float)
        )

    # Normalize embedding
    emb = emb / np.linalg.norm(emb)
    print(len(emb))
    print(emb.tolist())
    return emb.tolist()
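
# Note: for speechbrain's spkrec-ecapa-voxceleb the speaker embedding is typically a
# 192-dimensional vector; it is L2-normalised above so that downstream distance
# comparisons are scale-independent.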

def identify_speaker(wav_file: str, voice_col: str) -> Dict[str, Any]:
    voice_embedding = voice_embedder(wav_file)
    # voice_col arrives as a JSON-encoded list of dicts from the UI textbox
    voice_col = json.loads(voice_col)

    identity = "Desconegut"
    knn = []

    if voice_col and voice_embedding is not None:
        try:
            num_embeddings = len(voice_col)

            if num_embeddings < 1:
                knn = []
                identity = "Desconegut"

            else:
                n_results = min(3, num_embeddings)

                voice_embedding = np.array(voice_embedding)

                distances_embedding = []

                # Compute Euclidean distance between the detected voice and each stored embedding
                for voice_base_datos in voice_col:
                    voice_base_datos_embedding = np.array(voice_base_datos["embedding"])
                    distance = np.linalg.norm(voice_embedding - voice_base_datos_embedding)
                    distances_embedding.append({
                        "identity": voice_base_datos["nombre"],
                        "distance": float(distance)
                    })

                # Sort by distance and keep the top N matches
                distances_embedding = sorted(distances_embedding, key=lambda x: x["distance"])
                knn = distances_embedding[:n_results]

                # Assign identity if closest match exists
                if knn:
                    identity = knn[0]["identity"]
                else:
                    identity = "Desconegut"

        except Exception as e:
            print(f"Voice KNN failed: {e}")
            knn = []
            identity = "Desconegut"
    
    return {"knn": knn, "identity": identity}
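
# Illustrative only: per the lookups above, voice_col is expected to be a JSON list of
# entries like {"nombre": "Anna", "embedding": [0.12, 0.88, ...]}. Since the embeddings
# are L2-normalised, the Euclidean distance used here is a monotone function of cosine
# similarity (||a - b||^2 = 2 - 2*cos(a, b)), so both metrics rank neighbours identically.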

import subprocess
from pathlib import Path
from audio_extract import extract_audio
import os
import shutil
import tempfile

def convert_to_temporary(original_file):
    """
    Converts a file to a temporary file, deletes the original, and returns
    the path of the temporary file.
    """
    if not os.path.exists(original_file):
        raise FileNotFoundError(f"{original_file} does not exist")

    # Create a temporary file in persistent mode
    temp_fd, temp_path = tempfile.mkstemp(suffix=os.path.splitext(original_file)[1])
    os.close(temp_fd)  # Close the file descriptor; we'll use it as a normal file

    # Copy the content to the temporary file
    shutil.copy2(original_file, temp_path)

    # Delete the original file
    os.remove(original_file)

    return temp_path

def extract_audio_ffmpeg(video_file, sr: int = 16000, mono: bool = True):
    """
    Extracts audio from a video file using the audio_extract library (FFmpeg under
    the hood) and returns the path to a temporary copy of the extracted audio.

    Parameters
    ----------
    video_file : str
        The temporary file path provided by Gradio for the uploaded video.
    sr : int
        Target audio sample rate (note: not currently forwarded to audio_extract).
    mono : bool
        Whether to convert the audio to mono (note: not currently forwarded to audio_extract).

    Returns
    -------
    str
        Filepath to the extracted audio file.
    """
    if video_file is None: 
        return None 
    
    # Extract the file name without extension
    base_name = os.path.splitext(os.path.basename(video_file))[0] 
    
    # Build the output path with .wav extension
    audio_out = f"./{base_name}.wav" 
    
    # If the extracted file already exists, reuse it (moved to a temporary file)
    if os.path.exists(audio_out + ".mp3"):
        return convert_to_temporary(audio_out + ".mp3")
    
    # Call the function that performs the extraction
    extract_audio(input_path=video_file, output_path=audio_out) 
    
    return convert_to_temporary(audio_out+".mp3")
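
# Illustrative only: audio_extract drives FFmpeg under the hood; a rough command-line
# equivalent of this step (with the sr/mono defaults above, which are not currently
# forwarded to audio_extract) would be something like:
#     ffmpeg -i input_video.mp4 -vn -ac 1 -ar 16000 output_audio.wav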

import torch
import torchaudio
from dataclasses import dataclass
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import logging

def load_audio(path, target_sr=16000):
    waveform, sr = torchaudio.load(path)
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    return waveform.squeeze().numpy()

def transcribe_wav(wav_path: str) -> str:
    # Lazy-load the Whisper processor and model (projecte-aina/whisper-large-v3-ca-3catparla);
    # the model is already on DEVICE and model.device is used for the inputs below.
    processor, model = _lazy_load_whisper()

    # Load the WAV file
    waveform, sr = torchaudio.load(wav_path)
    
    target_sr = 16000 
    if sr != target_sr: 
        # Resample audio if sample rate differs
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform) 
        sr = target_sr

    # Preprocess the audio
    inputs = processor(
        waveform.numpy(), sampling_rate=sr, return_tensors="pt"
    ).input_features.to(model.device)
    
    # Generate transcription with the model
    with torch.no_grad():
        ids = model.generate(inputs, max_new_tokens=440)[0]
    
    # Decode the transcription
    txt = processor.decode(ids, skip_special_tokens=True)

    # Normalize text if necessary
    norm = getattr(processor.tokenizer, "_normalize", None)
    return norm(txt) if callable(norm) else txt

def transcribe_long_audio(
        wav_path: str,
        chunk_length_s: int = 20,
        overlap_s: int = 2,
) -> str:
    # Lazy-load the Whisper processor and model (projecte-aina/whisper-large-v3-ca-3catparla);
    # the model is already on DEVICE and model.device is used for the inputs below.
    processor, model = _lazy_load_whisper()

    # Load the full WAV file
    waveform, sr = torchaudio.load(wav_path)
    target_sr = 16000 
    if sr != target_sr: 
        # Resample if sample rate differs
        waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform) 
        sr = target_sr
    total_samples = waveform.shape[1]
    
    # Calculate chunk size and overlap in samples
    chunk_size = chunk_length_s * sr
    overlap_size = overlap_s * sr

    transcriptions = []
    start = 0

    while start < total_samples:
        end = min(start + chunk_size, total_samples)
        chunk = waveform[:, start:end]  # Transcribe in small fragments

        # Preprocess the chunk
        input_features = processor(
            chunk.numpy(),
            sampling_rate=sr,
            return_tensors="pt"
        ).input_features.to(model.device)

        # Generate transcription for the chunk
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                max_new_tokens=440,
                num_beams=1,   
            )[0]

        # Decode and store the chunk transcription
        text = processor.decode(predicted_ids, skip_special_tokens=True)
        transcriptions.append(text.strip())

        # Move to the next chunk with overlap
        start += chunk_size - overlap_size

    # Join all chunks into a single string
    return " ".join(transcriptions).strip()
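
# Worked example of the chunking arithmetic above (illustrative): at sr = 16000,
# chunk_length_s = 20 and overlap_s = 2 give chunk_size = 320_000 samples and a step
# of 288_000 samples, so successive chunks start at 0 s, 18 s, 36 s, ... and each
# chunk re-reads the last 2 s of the previous one.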

"""
# ==============================================================================
# UI & Endpoints
# ==============================================================================
Collection of Gradio interface elements and API endpoints used by the application.

This section defines the user-facing interface for the Aina faster-whisper
Catalan ASR Space, letting users extract audio from video, diarize audio into
per-speaker clips, compute voice embeddings, identify speakers against a stored
collection, and transcribe short or long audio in Catalan.

The components and endpoints in this section typically:
- Accept audio or video files from the user
- Apply optional parameters such as language, VAD filtering, beam size, or temperature
- Preprocess inputs and invoke the transcription, diarization, or embedding functions
- Return structured outputs such as plain text, JSON metadata, or audio clips

All endpoints are designed to be stateless and are exposed both to the interactive
UI and to programmatic API access (each with concurrency_limit=1).
# ==============================================================================
"""
custom_css = """
h2 {
    background: #e3e4e6 !important;
    padding: 14px 22px !important;
    border-radius: 14px !important;
    box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
    display: block !important;       /* take up the full width */
    width: 100% !important;          /* ensure 100% width */
    margin: 20px auto !important;
    text-align:center;
}
"""
with gr.Blocks(title="Aina faster-whisper (Català) · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:

    # Extract audio from video
    gr.Markdown('<h2 style="text-align:center">Extreure àudio d\'un vídeo</h2>')
    with gr.Row():
        video_input = gr.Video(label="Puja un vídeo")
    with gr.Row():
        extract_btn = gr.Button("Extreure àudio", variant="primary")
    with gr.Row():
        audio_output = gr.Audio(label="Àudio extret (WAV)", type="filepath")

    extract_btn.click(
        fn=extract_audio_ffmpeg,
        inputs=video_input,
        outputs=audio_output
    )

    # Diarization section
    gr.Markdown('<h2 style="text-align:center">Diarització de l\'àudio</h2>')
    with gr.Row():
        audio_input = gr.Audio(label="Àudio per diaritzar", type="filepath")
        process_btn = gr.Button("Diaritzar àudio", variant="primary")
        clips_output = gr.File(label="Clips d\'àudio generats", file_types=[".wav"], file_count="multiple")
        diarization_output = gr.JSON(label="Resultat de la diarització")

    process_btn.click(
        diarize_audio,
        inputs=[audio_input],
        outputs=[clips_output, diarization_output],
        api_name="diaritzar_audio",
        concurrency_limit=1
    )

    # Voice embeddings section
    gr.Markdown('<h2 style="text-align:center">Obtenir l\'embedding d\'un àudio</h2>')
    with gr.Row():
        audio_input = gr.Audio(label="Àudio per obtenir l\'embedding", type="filepath")
    with gr.Row():
        process_btn = gr.Button("Obtenir embedding", variant="primary")
    with gr.Row():
        clip_out = gr.JSON(label="Embedding de veu (vector)")

    process_btn.click(
        voice_embedder,
        [audio_input],
        clip_out,
        api_name="voice_embedding",
        concurrency_limit=1
    )

    gr.Markdown("---")

    # Speaker identification
    gr.Markdown('<h2 style="text-align:center">Identificació de parlants</h2>')
    with gr.Row():
        audio_input = gr.Audio(label="Àudio per identificar el parlant", type="filepath")
    with gr.Row():
        voice_col_input = gr.Textbox(
            label="Llista de diccionaris voice_col (format JSON)",
            placeholder='[{"nombre": "Anna", "embedding": [0.12, 0.88, ...]}, ...]',
            lines=5
        )
    with gr.Row():
        process_btn = gr.Button("Processar àudio (Persones)", variant="primary")
    with gr.Row():
        output_json = gr.JSON(label="Resultat complet")
    
    process_btn.click(
        identify_speaker,
        inputs=[audio_input, voice_col_input],
        outputs=output_json,
        api_name="identificar_veu",
        concurrency_limit=1
    )

    # Short audio transcription
    gr.Markdown('<h2 style="text-align:center">Aina faster-whisper (Català) Àudio curt → text</h2>')
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Puja el teu àudio")
    with gr.Row():
        boton = gr.Button("Transcriure", variant="primary")
    with gr.Row():
        output_text = gr.Textbox(label="Text transcrit")

    boton.click(
        fn=transcribe_wav,
        inputs=audio_input,
        outputs=output_text
    )
    
    # Long audio transcription
    gr.Markdown('<h2 style="text-align:center">Aina faster-whisper (Català) Àudio llarg → text</h2>')
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Puja el teu àudio")
    with gr.Row():
        boton2 = gr.Button("Transcriure", variant="primary")
    with gr.Row():
        output_text = gr.Textbox(label="Text transcrit")

    boton2.click(
        fn=transcribe_long_audio,
        inputs=audio_input,
        outputs=output_text
    )
    
    # Main transcription section
    gr.Markdown('<h2 style="text-align:center">Aina faster-whisper (Català) · ZeroGPU - Reconeixement de veu en català finetune projecte-aina</h2>')
    with gr.Row():
        with gr.Column():
            inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Àudio (WAV/MP3/MP4, etc.)")
            lang = gr.Textbox(label="Idioma", value="ca")
            ts = gr.Checkbox(label="Marques de temps", value=True)
            vad = gr.Checkbox(label="Filtre VAD", value=True)
        with gr.Column():
            out = gr.JSON(label="Sortida /predict")
    with gr.Row():
        btn = gr.Button("Transcriure (ENGINE /predict)", variant="primary")

    # Button callback
    btn.click(predict_for_engine, [inp, lang, ts, vad], out, api_name="predict", concurrency_limit=1)

    # Advanced transcription section
    gr.Markdown('<h2 style="text-align:center">Avançat (/transcribe)</h2>')
    with gr.Row():
        with gr.Column():
            inp2 = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Àudio")
            lang2 = gr.Textbox(label="Idioma", value="ca")
            task2 = gr.Dropdown(["transcribe", "translate"], value="transcribe", label="Tasques")
            vad2 = gr.Checkbox(label="Filtre VAD", value=True)
            beam2 = gr.Slider(1, 10, value=5, step=1, label="Mida del feix")
            temp2 = gr.Slider(0.0, 1.5, value=0.0, step=0.1, label="Temperatura")
            wts2 = gr.Checkbox(label="Marques de temps per paraula", value=False)
        with gr.Column():
            out2 = gr.JSON(label="Sortida /transcribe")
    with gr.Row():
        btn2 = gr.Button("Transcriure (avançat)", variant="primary")

    # Button callback advanced
    btn2.click(
        transcribe_advanced,
        [inp2, lang2, task2, vad2, beam2, temp2, wts2],
        out2,
        api_name="transcribe",
        concurrency_limit=1
    )

demo.queue(max_size=8).launch(share=True, show_error=True)