palli23 committed on
Commit
1155b96
·
verified ·
1 Parent(s): 9614b16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -36
app.py CHANGED
@@ -1,73 +1,82 @@
1
  import os
2
  import gradio as gr
3
- import torch
4
  import whisperx
5
 
6
- HF_TOKEN = os.getenv("HF_TOKEN")
7
 
8
- CT2_MODEL = "palli23/whisper-small-sam_spjall-ct2"
9
- DIAR_MODEL = "pyannote/speaker-diarization-3.1"
10
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
-
13
- def load_all_models():
14
- print("Loading ASR model...")
15
- asr_model = whisperx.load_model(
16
- CT2_MODEL,
17
- device=DEVICE,
18
- compute_type="float16" if DEVICE == "cuda" else "int8",
19
- token=HF_TOKEN,
20
  )
21
 
22
  print("Loading alignment model...")
23
  align_model, metadata = whisperx.load_align_model(
24
- language_code="is",
25
- device=DEVICE,
26
- token=HF_TOKEN,
27
  )
28
 
29
  print("Loading diarization model...")
30
- diar_model = whisperx.DiarizationPipeline(
31
- DIAR_MODEL,
32
- device=DEVICE,
33
- token=HF_TOKEN,
34
  )
35
 
36
- return asr_model, align_model, metadata, diar_model
 
37
 
 
38
 
39
- print("Initializing...")
40
- asr_model, align_model, align_metadata, diar_model = load_all_models()
41
 
 
 
 
42
 
43
- def transcribe(audio_file):
44
- audio = whisperx.load_audio(audio_file)
45
- result = asr_model.transcribe(audio, batch_size=16)
46
 
 
47
  aligned = whisperx.align(
48
  result["segments"],
49
  align_model,
50
  align_metadata,
51
  audio,
52
- DEVICE,
53
  )
54
 
55
- diarization = diar_model(audio)
56
- final_segments = whisperx.assign_speakers(aligned["segments"], diarization)
 
 
 
 
 
 
57
 
58
- output_text = ""
59
- for seg in final_segments:
60
  speaker = seg.get("speaker", "Unknown")
61
- output_text += f"[{speaker}] {seg['text']}\n"
62
 
63
- return output_text
64
 
65
 
66
  ui = gr.Interface(
67
  fn=transcribe,
68
  inputs=gr.Audio(type="filepath"),
69
- outputs=gr.Textarea(),
70
- title="WhisperX Icelandic + Diarization",
 
71
  )
72
 
73
- ui.launch()
 
 
1
  import os
2
  import gradio as gr
 
3
  import whisperx
4
 
5
# Hugging Face access token, read from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")  # MUST be set in HF Spaces secrets

# Model identifiers used by the three pipeline stages.
ASR_MODEL = "palli23/whisper-small-sam_spjall-ct2"  # CTranslate2 Whisper checkpoint
DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"  # speaker diarization
# NOTE(review): this name looks like the English torchaudio wav2vec2 bundle,
# while alignment is requested for language "is" -- confirm it is the intended
# alignment model for Icelandic audio.
ALIGN_MODEL = "WAV2VEC2_ASR_LARGE_LV60K_960H"
10
 
11
def load_models():
    """Load the ASR, alignment and diarization models used by the app.

    All three models are placed on the same device: CUDA when torch reports
    it as available, otherwise the CPU.

    Returns:
        A 4-tuple ``(asr, align_model, metadata, diar)``: the WhisperX ASR
        pipeline, the forced-alignment model together with its metadata, and
        the speaker-diarization pipeline.
    """
    # Imported locally because this file does not import torch at module
    # level; torch is already a hard dependency of whisperx.
    import torch

    # whisperx has no is_cuda_available() helper; ask torch directly.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading WhisperX ASR...")
    # load_model takes the model id and device as its first two positional
    # arguments and does not accept model_name= / hf_token= keywords. The
    # private repo is authenticated through the HF_TOKEN environment
    # variable, which huggingface_hub picks up on its own.
    asr = whisperx.load_model(
        ASR_MODEL,
        device,
        compute_type="int8",  # Safe for Spaces
    )

    print("Loading alignment model...")
    # device is a required argument here; there is no token keyword.
    align_model, metadata = whisperx.load_align_model(
        language_code="is",
        device=device,
        model_name=ALIGN_MODEL,
    )

    print("Loading diarization model...")
    # Model id and token are passed positionally (model_name, token) so the
    # call does not depend on what the token keyword is called in the
    # installed whisperx release. Passing the token itself, rather than
    # use_auth_token=True, makes sure the gated pyannote model is fetched
    # with HF_TOKEN.
    diar = whisperx.DiarizationPipeline(
        DIARIZATION_MODEL,
        HF_TOKEN,
        device=device,
    )

    return asr, align_model, metadata, diar


# Load everything once at import time so each request reuses the same models.
asr_model, align_model, align_metadata, diar_pipeline = load_models()
38
 
 
 
39
 
40
def transcribe(audio):
    """Transcribe an audio file and label each segment with its speaker.

    Runs ASR, forced alignment and speaker diarization in sequence using the
    module-level models, then merges the speaker labels into the segments.

    Args:
        audio: Path to the uploaded audio file, or None when nothing was
            uploaded.

    Returns:
        One ``[speaker] text`` line per segment ("Unknown" when a segment has
        no speaker label), or a short notice when no audio was provided.
    """
    if audio is None:
        return "No audio provided."

    # Decode the file once and share the waveform between all three stages
    # instead of letting each stage re-read the file from disk.
    waveform = whisperx.load_audio(audio)

    print("Running ASR...")
    result = asr_model.transcribe(waveform)

    print("Running alignment...")
    # whisperx.align expects the torch device as its fifth argument, not a
    # language code. Use the device the alignment model already lives on so
    # the inputs always match it (assumes align_model is a torch nn.Module,
    # which is what whisperx.load_align_model returns).
    device = next(align_model.parameters()).device
    aligned = whisperx.align(
        result["segments"],
        align_model,
        align_metadata,
        waveform,
        device,
    )

    print("Running diarization...")
    diarization = diar_pipeline(waveform)

    print("Assigning speaker labels...")
    final_result = whisperx.assign_word_speakers(diarization, aligned)

    # Build the output in one pass; every segment ends with a newline.
    return "".join(
        f"[{seg.get('speaker', 'Unknown')}] {seg['text']}\n"
        for seg in final_result["segments"]
    )
71
 
72
 
73
# Gradio front end: one audio upload (handed to transcribe as a file path)
# and one multi-line text box for the speaker-labelled transcript.
audio_input = gr.Audio(type="filepath")
transcript_output = gr.Textbox(label="Transcription + Speakers", lines=20)

ui = gr.Interface(
    fn=transcribe,
    inputs=audio_input,
    outputs=transcript_output,
    title="WhisperX Icelandic CT2 + Diarization",
    description="Uses your private CT2 Whisper Small model + alignment + pyannote diarization.",
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    ui.launch()