ahmad walidurosyad Claude committed on
Commit
b02904a
·
1 Parent(s): 5ea2528

Add user-selectable model UI with DiariZen support

Browse files

- Add dropdown UI for model selection (4 models available)
- Support DiariZen WavLM Large/Base/MLC models (no token required)
- Support Pyannote 3.1 model (requires HF_TOKEN)
- Implement model caching for performance
- Add status messages and model info display
- Improve error handling and user feedback
- Fix: DiariZen models now use correct API (DiariZenPipeline)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +203 -66
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,86 +1,223 @@
1
- import spaces
2
  import gradio as gr
 
 
 
3
  from gryannote_audio import AudioLabeling
4
  from gryannote_rttm import RTTM
5
- from pyannote.audio import Pipeline
6
- import os
7
- import torch
8
-
9
- @spaces.GPU(duration=120)
10
- def apply_pipeline(audio):
11
- """Apply specified pipeline on the indicated audio file"""
12
- pipeline = Pipeline.from_pretrained("BUT-FIT/diarizen-wavlm-large-s80-md", use_auth_token=os.environ["HF_TOKEN"])
13
- pipeline.to(torch.device("cuda"))
14
- annotations = pipeline(audio)
15
-
16
- return ((audio, annotations), annotations)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- def update_annotations(data):
20
- return rttm.on_edit(data)
21
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- with gr.Blocks() as demo:
24
  with gr.Row():
25
- with gr.Column():
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  with gr.Row():
28
- with gr.Column(scale=1):
29
- gr.Markdown(
30
- '<a href="https://github.com/clement-pages/gryannote"><img src="https://github.com/clement-pages/gryannote/blob/main/docs/assets/logo-gryannote.png?raw=true" alt="gryannote logo" width="140"/></a>',
31
- )
32
- with gr.Column(scale=10):
33
- gr.Markdown('<h1 style="font-size: 4em;">gryannote</h1>')
34
- gr.Markdown()
35
- gr.Markdown('<h2 style="font-size: 2em;">Make the audio labeling process easier and faster! </h2>')
36
-
37
- with gr.Tab("application"):
38
- gr.Markdown(
39
- "To use the component, start by loading or recording audio."
40
- "Then apply the diarization pipeline (here [pyannote/speaker-diarization-3.1](https://huggingface.co/pyannote/speaker-diarization-3.1))"
41
- "or double-click directly on the waveform to add an annotations. The annotations produced can be edited."
42
- " You can also use keyboard shortcuts to speed things up! Click on the help button to see all the available shortcuts."
43
- " Finally, annotations can be saved by cliking on the downloading button in the RTTM component."
44
- )
45
- gr.Markdown()
46
- gr.Markdown()
47
- audio_labeling = AudioLabeling(
48
- type="filepath",
49
- interactive=True,
50
- )
51
-
52
- gr.Markdown()
53
- gr.Markdown()
54
-
55
- run_btn = gr.Button("Run pipeline")
56
-
57
- rttm = RTTM()
58
-
59
- with gr.Tab("poster"):
60
- gr.Markdown(
61
- '<p align="center"><img src="https://github.com/clement-pages/gryannote/blob/main/docs/assets/poster-interspeech.jpg?raw=true" alt="gryannote poster" width=700em/></p>'
62
- )
 
 
 
 
 
 
 
 
 
63
 
 
 
 
 
 
 
 
64
  run_btn.click(
65
  fn=apply_pipeline,
66
- inputs=audio_labeling,
67
- outputs=[audio_labeling, rttm],
68
  )
69
 
70
- audio_labeling.edit(
71
- fn=update_annotations,
72
- inputs=audio_labeling,
73
- outputs=rttm,
74
- preprocess=False,
75
- postprocess=False,
76
  )
77
 
78
- rttm.upload(
79
- fn=audio_labeling.load_annotations,
80
- inputs=[audio_labeling, rttm],
81
- outputs=audio_labeling,
 
82
  )
83
 
 
 
 
 
 
84
 
85
  if __name__ == "__main__":
86
  demo.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ import os
4
+ import spaces
5
  from gryannote_audio import AudioLabeling
6
  from gryannote_rttm import RTTM
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# Cache of already-constructed pipelines, keyed by model id, so each model
# is loaded from the Hub at most once per process.
model_cache = {}

# Registry of user-selectable diarization models. Each entry carries the
# Hub id, which loader to use ("diarizen" vs "pyannote"), whether an
# HF token is required, and metadata shown in the UI info panel.
AVAILABLE_MODELS = {
    "DiariZen WavLM Large (Recommended)": {
        "id": "BUT-FIT/diarizen-wavlm-large-s80-md",
        "type": "diarizen",
        "requires_token": False,
        "speed": "Fast",
        "quality": "High",
        "description": "Optimized 63M parameter model with excellent performance",
    },
    "DiariZen WavLM Base": {
        "id": "BUT-FIT/diarizen-wavlm-base-s80-md",
        "type": "diarizen",
        "requires_token": False,
        "speed": "Very Fast",
        "quality": "Good",
        "description": "Lighter model for faster inference",
    },
    "DiariZen WavLM Large MLC": {
        "id": "BUT-FIT/diarizen-wavlm-large-s80-mlc",
        "type": "diarizen",
        "requires_token": False,
        "speed": "Fast",
        "quality": "High",
        "description": "Multi-language optimized variant",
    },
    "Pyannote 3.1": {
        "id": "pyannote/speaker-diarization-3.1",
        "type": "pyannote",
        "requires_token": True,
        "speed": "Medium",
        "quality": "High",
        "description": "Original pyannote model (requires HF token)",
    },
}
45
+
46
def load_pipeline(model_name):
    """Load (or fetch from cache) the diarization pipeline for *model_name*.

    Parameters
    ----------
    model_name : str
        A key of ``AVAILABLE_MODELS``.

    Returns
    -------
    tuple
        ``(pipeline, message)`` where ``pipeline`` is the loaded pipeline
        or ``None`` on failure, and ``message`` is a human-readable status
        string (``None`` on a cache hit).
    """
    model_config = AVAILABLE_MODELS[model_name]
    model_id = model_config["id"]

    # Reuse a previously loaded pipeline if available.
    if model_id in model_cache:
        return model_cache[model_id], None

    try:
        if model_config["type"] == "diarizen":
            # DiariZen models are public on the Hub: no auth token needed.
            from diarizen.pipelines.inference import DiariZenPipeline
            pipeline = DiariZenPipeline.from_pretrained(model_id)

        elif model_config["type"] == "pyannote":
            from pyannote.audio import Pipeline

            # Pyannote models are gated behind a HuggingFace token.
            if "HF_TOKEN" not in os.environ:
                return None, "⚠️ Pyannote requires HF_TOKEN in Space secrets"

            pipeline = Pipeline.from_pretrained(
                model_id,
                use_auth_token=os.environ["HF_TOKEN"]
            )

        else:
            # FIX: guard against a malformed registry entry. Previously an
            # unknown type left `pipeline` unbound, so the `pipeline.to(...)`
            # below raised UnboundLocalError, swallowed by the broad except
            # into a misleading "Error loading" message.
            return None, f"❌ Unknown model type: {model_config['type']}"

        # Move to GPU if available (no-op on CPU-only hosts).
        if torch.cuda.is_available():
            pipeline.to(torch.device("cuda"))

        # Cache so subsequent calls skip the Hub download entirely.
        model_cache[model_id] = pipeline

        return pipeline, f"✅ {model_name} loaded successfully"

    except Exception as e:
        # Surface the failure to the UI rather than crashing the event handler.
        return None, f"❌ Error loading {model_name}: {str(e)}"
83
 
84
@spaces.GPU(duration=120)
def apply_pipeline(audio, model_name):
    """Run the selected diarization model on *audio*.

    Returns a 3-tuple: the ``(audio, annotations)`` pair for the audio
    labeling widget, the raw annotations for the RTTM widget, and a
    status string. The first two entries are ``None`` whenever a step
    fails, leaving only the status message populated.
    """
    # Guard clause: nothing to diarize without audio.
    if audio is None:
        return None, None, "⚠️ Please upload or record audio first"

    # Obtain the (possibly cached) pipeline for the chosen model.
    pipeline, message = load_pipeline(model_name)
    if pipeline is None:
        # Loading failed; propagate the loader's error message to the UI.
        return None, None, message

    # Run diarization, converting any runtime failure into a status string.
    try:
        annotations = pipeline(audio)
    except Exception as e:
        return None, None, f"❌ Error during diarization: {str(e)}"

    return (audio, annotations), annotations, f"✅ Diarization complete with {model_name}"
103
+
104
def update_annotations(new_annotations):
    """Mirror annotations edited in the audio labeling widget onto the
    RTTM component, passing them through unchanged as the event output.

    NOTE(review): assigning to a component attribute outside an event
    return value may not refresh the Gradio UI — verify this sync
    actually takes effect.
    """
    # Keep the RTTM component's state aligned with the latest edits.
    rttm_obj.annotations = new_annotations
    return new_annotations
108
+
109
def load_rttm_to_audio(rttm_annotations):
    """Push annotations from an uploaded RTTM file into the audio
    labeling widget, then return the widget's resulting value."""
    # Delegate loading to the component itself and echo its value back
    # so Gradio updates the displayed waveform labels.
    audio_labeling.load_annotations(rttm_annotations)
    return audio_labeling.value
113
+
114
# Components are created un-rendered here so the helper functions above can
# reference them; they are placed into the layout via .render() below.
audio_labeling = AudioLabeling(type="filepath")
rttm_obj = RTTM()

# Build Gradio Interface
with gr.Blocks(title="GryanNote - Speaker Diarization") as demo:
    gr.Markdown("""
    # 🎙️ GryanNote - Speaker Diarization
    Label speakers in audio recordings using state-of-the-art diarization models
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Model selection dropdown
            model_selector = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="DiariZen WavLM Large (Recommended)",
                label="🤖 Select Diarization Model",
                info="Choose the model for speaker diarization"
            )

            # Model info display
            with gr.Accordion("ℹ️ Model Information", open=False):
                model_info = gr.Markdown()

            # Audio input
            audio_labeling.render()

            # Action buttons
            with gr.Row():
                run_btn = gr.Button("▶️ Run Diarization", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear", size="lg")

        with gr.Column(scale=1):
            # Status message
            status_msg = gr.Textbox(
                label="📊 Status",
                interactive=False,
                lines=3
            )

            # RTTM output
            gr.Markdown("### 📝 RTTM Output")
            rttm_obj.render()

    # Footer
    gr.Markdown("""
    ---
    **Models:**
    - **DiariZen**: Optimized models by BUT-FIT, no token required
    - **Pyannote**: Original model, requires HF token in Space secrets

    **Usage:** Upload audio → Select model → Run diarization → Download/Edit annotations
    """)

    # Update model info when selection changes
    def update_model_info(model_name):
        """Render the selected model's registry metadata as markdown."""
        config = AVAILABLE_MODELS[model_name]
        info = f"""
        **Model ID:** `{config['id']}`
        **Type:** {config['type'].upper()}
        **Speed:** {config['speed']} | **Quality:** {config['quality']}
        **Token Required:** {'Yes ⚠️ (Add HF_TOKEN to Space secrets)' if config['requires_token'] else 'No ✅'}

        {config['description']}
        """
        return info

    # Initialize model info on page load.
    demo.load(
        fn=update_model_info,
        inputs=[model_selector],
        outputs=[model_info]
    )

    model_selector.change(
        fn=update_model_info,
        inputs=[model_selector],
        outputs=[model_info]
    )

    # Run pipeline button.
    # FIX: event inputs/outputs must be the component instances themselves;
    # the original passed `component.value`, which is merely the component's
    # build-time default (None), not a live reference — the wiring above
    # (demo.load / model_selector.change) already uses the correct form.
    run_btn.click(
        fn=apply_pipeline,
        inputs=[audio_labeling, model_selector],
        outputs=[audio_labeling, rttm_obj, status_msg]
    )

    # Clear button: reset both widgets and report the action.
    clear_btn.click(
        fn=lambda: (None, None, "Cleared"),
        inputs=[],
        outputs=[audio_labeling, rttm_obj, status_msg]
    )

    # Sync annotations edited on the waveform into the RTTM view.
    audio_labeling.change(
        fn=update_annotations,
        inputs=[audio_labeling],
        outputs=[rttm_obj]
    )

    # Load annotations from an uploaded RTTM file onto the waveform.
    rttm_obj.upload(
        fn=load_rttm_to_audio,
        inputs=[rttm_obj],
        outputs=[audio_labeling]
    )

if __name__ == "__main__":
    demo.launch()
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  gryannote==0.3.3
2
  pyannote-audio==3.3.2
 
3
  spaces==0.30.2
 
 
 
1
  gryannote==0.3.3
2
  pyannote-audio==3.3.2
3
+ diarizen
4
  spaces==0.30.2
5
+ torch
6
+ gradio