File size: 2,065 Bytes
3830be6
 
 
 
1155b96
3830be6
1155b96
 
 
3830be6
1155b96
 
 
 
 
 
 
3830be6
 
 
 
1155b96
 
 
3830be6
 
049d47a
1155b96
 
 
 
3830be6
 
1155b96
 
3830be6
1155b96
3830be6
 
1155b96
 
 
3830be6
1155b96
 
3830be6
1155b96
3830be6
 
 
 
049d47a
1155b96
3830be6
 
1155b96
 
 
 
 
 
 
 
3830be6
1155b96
 
049d47a
1155b96
3830be6
1155b96
3830be6
 
049d47a
 
 
1155b96
 
 
049d47a
3830be6
1155b96
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os

import gradio as gr
import torch
import whisperx

# Hugging Face access token; required to pull the private ASR repo and the
# gated pyannote diarization model.
HF_TOKEN = os.getenv("HF_TOKEN")  # MUST be set in HF Spaces secrets

# Private CTranslate2-converted Whisper-small Icelandic ASR model repo id.
ASR_MODEL = "palli23/whisper-small-sam_spjall-ct2"
# Gated pyannote pipeline used for speaker diarization.
DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
# torchaudio wav2vec2 bundle name used for word-level forced alignment.
ALIGN_MODEL = "WAV2VEC2_ASR_LARGE_LV60K_960H"

def load_models():
    """Load the ASR, forced-alignment and diarization models once.

    Returns:
        tuple: (asr_model, align_model, align_metadata, diarization_pipeline)
            ready for use by transcribe().
    """
    # Decide the device once so all three models agree.
    # whisperx has no is_cuda_available(); torch provides the real check.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading WhisperX ASR...")
    # whisperx.load_model takes the model id and device positionally; it has
    # no hf_token parameter — huggingface_hub reads the HF_TOKEN env var when
    # downloading the private repo.
    asr = whisperx.load_model(
        ASR_MODEL,
        device,
        compute_type="int8",  # Safe for Spaces (fits low-memory CPU/GPU)
    )

    print("Loading alignment model...")
    # load_align_model requires the target device; it takes no token kwarg
    # (the torchaudio wav2vec2 bundle is public).
    align_model, metadata = whisperx.load_align_model(
        language_code="is",
        device=device,
        model_name=ALIGN_MODEL,
    )

    print("Loading diarization model...")
    # DiarizationPipeline authenticates via use_auth_token, which must be the
    # token string itself (not True) for the gated pyannote model.
    diar = whisperx.DiarizationPipeline(
        model_name=DIARIZATION_MODEL,
        use_auth_token=HF_TOKEN,
        device=device,
    )

    return asr, align_model, metadata, diar


asr_model, align_model, align_metadata, diar_pipeline = load_models()


def transcribe(audio):
    """Transcribe an audio file, word-align it, and label speakers.

    Args:
        audio: Filesystem path to the uploaded/recorded audio (Gradio
            ``type="filepath"``), or None when nothing was provided.

    Returns:
        str: One line per segment formatted as ``[SPEAKER] text``, or a
            short message when no audio was supplied.
    """
    if audio is None:
        return "No audio provided."

    # Same device selection as load_models(). whisperx.align()'s fifth
    # positional parameter is the torch device — the original passed the
    # language code "is" there, which crashes alignment (the language is
    # already carried in align_metadata).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Running ASR...")
    result = asr_model.transcribe(audio)

    print("Running alignment...")
    aligned = whisperx.align(
        result["segments"],
        align_model,
        align_metadata,
        audio,
        device,
    )

    print("Running diarization...")
    diarization = diar_pipeline(audio)

    print("Assigning speaker labels...")
    final_result = whisperx.assign_word_speakers(diarization, aligned)

    # Render "[SPEAKER] text" per segment; segments diarization could not
    # attribute to anyone fall back to "Unknown". join() avoids the
    # quadratic string += pattern.
    lines = [
        f"[{seg.get('speaker', 'Unknown')}] {seg['text']}\n"
        for seg in final_result["segments"]
    ]
    return "".join(lines)


# Gradio front-end: one audio input mapped to the speaker-labelled
# transcript produced by transcribe().
_audio_in = gr.Audio(type="filepath")
_text_out = gr.Textbox(label="Transcription + Speakers", lines=20)

ui = gr.Interface(
    fn=transcribe,
    inputs=_audio_in,
    outputs=_text_out,
    title="WhisperX Icelandic CT2 + Diarization",
    description="Uses your private CT2 Whisper Small model + alignment + pyannote diarization.",
)

# Start the server only when this module is executed directly.
if __name__ == "__main__":
    ui.launch()