Tolulope Ogunremi committed
Commit 207e539 · 1 Parent(s): 2e8ad24

Add application file

Files changed (1):
  1. app.py +970 -0
app.py ADDED
@@ -0,0 +1,970 @@
import gradio as gr
import os
import sys

# Install private package at startup
print("Installing private package...")
gh_token = os.environ.get("GH_TOKEN")
if not gh_token:
    raise ValueError("GH_TOKEN not found in environment variables")

package_url = f"git+https://{gh_token}@github.com/tolulope/speech-model-analysis.git"
os.system(f"{sys.executable} -m pip install {package_url}")
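
# NOTE (a sketch, not part of the original app): os.system ignores pip's exit
# status, so a failed install only surfaces later as an ImportError. A stricter
# variant would fail fast at startup:
#
#   import subprocess
#   subprocess.run(
#       [sys.executable, "-m", "pip", "install", package_url],
#       check=True,  # raises CalledProcessError if pip fails
#   )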

# Now import from your private package
from voxcommunis_core import (
    VoxCommunisPreprocessor,
    MultiModelAnalyzer,
    create_hubert_configs,
    PHONEMES,  # used throughout but missing from the original import; assumed to live in voxcommunis_core
)

print("Private package loaded successfully!")

# Initialize your analyzer
OUTPUT_DIR = "tolulope/speech-model-analysis"
# analyzer = MultiModelAnalyzer(OUTPUT_DIR)

# def analyze_audio(audio_file, analysis_type):
#     """Wrapper for audio analysis"""
#     try:
#         # Your analysis logic using the analyzer
#         results = analyzer.analyze(audio_file, analysis_type)
#         return results
#     except Exception as e:
#         return f"Error: {str(e)}"

# def run_preprocessing(voxcommunis_root, output_dir):
#     """Wrapper for preprocessing"""
#     try:
#         preprocessor = VoxCommunisPreprocessor(
#             voxcommunis_root=voxcommunis_root,
#             output_dir=output_dir
#         )

#         # Your preprocessing logic
#         hubert_configs = create_hubert_configs()
#         # ... rest of preprocessing

#         return "Preprocessing completed successfully!"
#     except Exception as e:
#         return f"Error: {str(e)}"


def create_integrated_gradio_interface(analyzer: MultiModelAnalyzer):
    """
    Create a comprehensive Gradio interface with model comparison.

    Args:
        analyzer: MultiModelAnalyzer instance
    """

    # Extract feature options (same as before)
    all_manners = sorted(set(p.manner.name for p in PHONEMES.values()
                             if p.manner))
    all_places = sorted(set(p.place.name for p in PHONEMES.values()
                            if p.place))
    all_voicings = ['voiced', 'voiceless']
    all_heights = ['high', 'mid', 'low']
    all_backness = ['front', 'central', 'back']

    model_names = analyzer.get_model_names()

    with gr.Blocks(title="Discrete Token Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Discrete Token Phoneme Analysis")
        # gr.Markdown("Compare HuBERT models and analyze discrete representations")

        with gr.Tabs():
            # Tab 1: Model Comparison
            with gr.Tab("Model Comparison"):
                gr.Markdown("### Compare Clustering Quality Across Models")

                with gr.Row():
                    comparison_plot = gr.Plot(label="Metrics Comparison")
                    metrics_table = gr.Dataframe(label="Detailed Metrics")

                refresh_comparison_btn = gr.Button("Refresh Comparison", variant="primary")

                def update_comparison():
                    fig = analyzer.create_comparison_plot()
                    df = analyzer.compare_metrics()
                    return fig, df

                refresh_comparison_btn.click(
                    fn=update_comparison,
                    outputs=[comparison_plot, metrics_table]
                )

                # Initialize: demo.load runs the comparison once when the page opens
                demo.load(
                    fn=update_comparison,
                    outputs=[comparison_plot, metrics_table]
                )

            # Tab 2: Single Model Analysis
            """
            with gr.Tab("Single Model Analysis"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Model & Filters")

                        model_selector = gr.Dropdown(
                            model_names,
                            value=model_names[0] if model_names else None,
                            label="Select Model"
                        )

                        color_by = gr.Radio(
                            ['cluster', 'phone'],
                            value='cluster',
                            label="Color by"
                        )

                        gr.Markdown("#### Articulatory Filters")

                        manner_filter = gr.Dropdown(
                            all_manners,
                            multiselect=True,
                            label="Manner"
                        )

                        place_filter = gr.Dropdown(
                            all_places,
                            multiselect=True,
                            label="Place"
                        )

                        voicing_filter = gr.Dropdown(
                            all_voicings,
                            multiselect=True,
                            label="Voicing"
                        )

                        vowel_height_filter = gr.Dropdown(
                            all_heights,
                            multiselect=True,
                            label="Vowel Height"
                        )

                        vowel_backness_filter = gr.Dropdown(
                            all_backness,
                            multiselect=True,
                            label="Vowel Backness"
                        )

                        update_btn = gr.Button("Update Visualization", variant="primary")

                    with gr.Column(scale=2):
                        plot_output = gr.Plot(label="Cluster Visualization")
                        gr.Markdown("💡 **Tip**: Click on points to hear audio in the Audio Explorer tab!")

                with gr.Row():
                    with gr.Column():
                        metrics_output = gr.Markdown()

                    with gr.Column():
                        confusion_output = gr.Plot(label="Confusion Matrix")

                def update_single_model(model_name, color, manner, place, voicing, height, backness):
                    if not model_name:
                        return None, "Select a model", None

                    visualizer = analyzer.visualizers[model_name]

                    # Create scatter plot
                    fig = visualizer.create_scatter_plot(
                        color_by=color,
                        filter_manner=manner if manner else None,
                        filter_place=place if place else None,
                        filter_voicing=voicing if voicing else None,
                        filter_vowel_height=height if height else None,
                        filter_vowel_backness=backness if backness else None
                    )

                    # Calculate metrics
                    metrics = visualizer.calculate_metrics(
                        filter_manner=manner if manner else None,
                        filter_place=place if place else None,
                        filter_voicing=voicing if voicing else None,
                        filter_vowel_height=height if height else None,
                        filter_vowel_backness=backness if backness else None
                    )

                    # Create confusion matrix
                    confusion_fig = analyzer.create_confusion_heatmap(model_name)

                    return fig, metrics, confusion_fig

                update_btn.click(
                    fn=update_single_model,
                    inputs=[model_selector, color_by, manner_filter, place_filter,
                            voicing_filter, vowel_height_filter, vowel_backness_filter],
                    outputs=[plot_output, metrics_output, confusion_output]
                )
            """
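
            # This tab and the three below are disabled by wrapping them in
            # bare string literals, which Python evaluates and discards. A more
            # explicit toggle (a sketch; SHOW_EXPERIMENTAL_TABS is a
            # hypothetical flag, not part of the original app) could be:
            #
            #   SHOW_EXPERIMENTAL_TABS = os.environ.get("SHOW_EXPERIMENTAL_TABS") == "1"
            #   if SHOW_EXPERIMENTAL_TABS:
            #       with gr.Tab("Single Model Analysis"):
            #           ...  # build the tab exactly as in the block above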

            # Tab 3: Audio Explorer
            """
            with gr.Tab("Audio Explorer"):
                gr.Markdown("### Listen to Cluster Samples")
                gr.Markdown("Explore audio segments from clusters and phonemes")

                with gr.Row():
                    with gr.Column():
                        audio_model_selector = gr.Dropdown(
                            model_names,
                            value=model_names[0] if model_names else None,
                            label="Select Model"
                        )

                        exploration_mode = gr.Radio(
                            ['By Cluster', 'By Phoneme', 'Compare Phoneme Across Clusters'],
                            value='By Cluster',
                            label="Exploration Mode"
                        )

                        # Cluster mode inputs
                        with gr.Group(visible=True) as cluster_inputs:
                            cluster_id_audio = gr.Number(
                                label="Cluster ID",
                                value=0,
                                precision=0
                            )
                            n_cluster_samples = gr.Slider(
                                1, 10, value=5,
                                step=1,
                                label="Number of samples"
                            )

                        # Phoneme mode inputs
                        with gr.Group(visible=False) as phoneme_inputs:
                            phoneme_select = gr.Dropdown(
                                sorted(list(PHONEMES.keys())),
                                label="Select Phoneme",
                                value="æ"
                            )
                            n_phoneme_samples = gr.Slider(
                                1, 10, value=5,
                                step=1,
                                label="Number of samples"
                            )

                        # Compare mode inputs
                        with gr.Group(visible=False) as compare_inputs:
                            phoneme_compare = gr.Dropdown(
                                sorted(list(PHONEMES.keys())),
                                label="Phoneme to Compare",
                                value="æ"
                            )
                            n_per_cluster = gr.Slider(
                                1, 5, value=3,
                                step=1,
                                label="Samples per cluster"
                            )

                        play_audio_btn = gr.Button("🎵 Load Audio Samples", variant="primary")

                    with gr.Column(scale=2):
                        audio_output = gr.HTML(label="Audio Player")
                        audio_info = gr.Markdown()

                # Toggle visibility based on mode
                def update_visibility(mode):
                    return (
                        gr.update(visible=(mode == 'By Cluster')),
                        gr.update(visible=(mode == 'By Phoneme')),
                        gr.update(visible=(mode == 'Compare Phoneme Across Clusters'))
                    )

                exploration_mode.change(
                    fn=update_visibility,
                    inputs=[exploration_mode],
                    outputs=[cluster_inputs, phoneme_inputs, compare_inputs]
                )

                def load_audio_samples(model_name, mode, cluster_id, n_cluster,
                                       phoneme, n_phoneme, phoneme_cmp, n_per_clust):
                    if not model_name or model_name not in analyzer.audio_explorers:
                        return "<p>Audio not available for this model</p>", "No audio data loaded"

                    explorer = analyzer.audio_explorers[model_name]

                    try:
                        if mode == 'By Cluster':
                            samples = explorer.get_cluster_samples(
                                cluster_id=int(cluster_id),
                                n_samples=int(n_cluster)
                            )
                            info = f"### Cluster {cluster_id}\n\nShowing {len(samples)} samples"

                        elif mode == 'By Phoneme':
                            samples = explorer.get_phoneme_samples(
                                phoneme=phoneme,
                                n_samples=int(n_phoneme)
                            )
                            info = f"### Phoneme: {phoneme}\n\nShowing {len(samples)} samples"

                        else:  # Compare mode
                            cluster_samples = explorer.compare_phoneme_in_clusters(
                                phoneme=phoneme_cmp,
                                n_per_cluster=int(n_per_clust)
                            )

                            # Flatten samples and add cluster headers
                            html = ""
                            info_lines = [f"### Phoneme: {phoneme_cmp} across clusters\n"]

                            for cluster_id, samps in sorted(cluster_samples.items()):
                                html += f'<h4>Cluster {cluster_id}</h4>'
                                html += create_audio_grid(samps, columns=3)
                                info_lines.append(f"- Cluster {cluster_id}: {len(samps)} samples")

                            return html, "\n".join(info_lines)

                        if not samples:
                            return "<p>No samples found</p>", "No matching samples"

                        html = create_audio_grid(samples, columns=3)
                        return html, info

                    except Exception as e:
                        return f"<p>Error loading audio: {str(e)}</p>", f"Error: {str(e)}"

                play_audio_btn.click(
                    fn=load_audio_samples,
                    inputs=[audio_model_selector, exploration_mode,
                            cluster_id_audio, n_cluster_samples,
                            phoneme_select, n_phoneme_samples,
                            phoneme_compare, n_per_cluster],
                    outputs=[audio_output, audio_info]
                )
            """

            # Tab 4: Export & Analysis
            """
            with gr.Tab("Export & Analysis"):
                gr.Markdown("### Export Results")

                with gr.Row():
                    export_model = gr.Dropdown(
                        model_names,
                        label="Select Model to Export"
                    )

                    export_format = gr.Radio(
                        ['CSV', 'JSON', 'NPZ'],
                        value='CSV',
                        label="Format"
                    )

                export_btn = gr.Button("Export Data", variant="primary")
                export_output = gr.File(label="Download")

                def export_data(model_name, format_type):
                    if not model_name:
                        return None

                    data = analyzer.models[model_name]
                    output_path = f"{model_name}_export.{format_type.lower()}"

                    if format_type == 'CSV':
                        df = pd.DataFrame({
                            'cluster': data['cluster_labels'],
                            'phoneme': data['phoneme_strings'],
                            'phone_idx': data['phone_labels']
                        })
                        df.to_csv(output_path, index=False)

                    elif format_type == 'JSON':
                        export_dict = {
                            'clusters': data['cluster_labels'].tolist(),
                            'phonemes': data['phoneme_strings'].tolist(),
                            'phone_indices': data['phone_labels'].tolist()
                        }
                        with open(output_path, 'w') as f:
                            json.dump(export_dict, f, indent=2)

                    else:  # NPZ
                        np.savez(
                            output_path,
                            features=data['features'],
                            clusters=data['cluster_labels'],
                            phones=data['phone_labels']
                        )

                    return output_path

                export_btn.click(
                    fn=export_data,
                    inputs=[export_model, export_format],
                    outputs=[export_output]
                )
            """

            # Tab 6: Context Pooling Analysis
            """
            with gr.Tab("Context Pooling"):
                gr.Markdown("### Coarticulation Analysis")
                gr.Markdown("Pool phoneme embeddings by context to account for coarticulation effects")

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("#### Pooling Configuration")

                        context_model = gr.Dropdown(
                            model_names,
                            value=model_names[0] if model_names else None,
                            label="Select Model"
                        )

                        enable_pooling = gr.Checkbox(
                            label="Enable Context Pooling",
                            value=False
                        )

                        left_context = gr.Slider(
                            0, 3, value=1, step=1,
                            label="Left Context (# phones)",
                            info="How many phones before target"
                        )

                        right_context = gr.Slider(
                            0, 3, value=1, step=1,
                            label="Right Context (# phones)",
                            info="How many phones after target"
                        )

                        pooling_method = gr.Radio(
                            choices=['mean', 'median', 'max'],
                            value='mean',
                            label="Pooling Method"
                        )

                        min_samples = gr.Slider(
                            1, 10, value=2, step=1,
                            label="Min Samples per Context",
                            info="Minimum instances to pool"
                        )

                        compute_pooling_btn = gr.Button("Apply Pooling", variant="primary")
                        pooling_status = gr.Markdown("")

                        gr.Markdown("#### Analyze Specific Phone")

                        phone_to_analyze = gr.Textbox(
                            label="Phoneme",
                            placeholder="æ",
                            value="æ"
                        )

                        analyze_phone_btn = gr.Button("Analyze Contexts")

                    with gr.Column(scale=2):
                        pooling_comparison = gr.Markdown("*Apply pooling to see comparison*")

                        context_analysis = gr.Markdown("*Analyze a phone to see contexts*")

                        # with gr.Row():
                        #     pooled_plot = gr.Plot(label="Pooled Embeddings (UMAP)")

                # Context pooling callbacks
                def apply_context_pooling(model_name, enable, left, right, method, min_samp):
                    if not model_name or model_name not in analyzer.models:
                        return "Model not available", ""

                    data = analyzer.models[model_name]

                    if not enable:
                        # No pooling
                        metrics = calculate_all_metrics(
                            data['cluster_labels'],
                            data['phone_labels']
                        )

                        comparison = "### No Pooling (Baseline)\n\n"
                        comparison += f"- **Points**: {len(data['features'])}\n"
                        comparison += f"- **Cluster Purity**: {metrics['cluster_purity']:.3f}\n"
                        comparison += f"- **Phone Purity**: {metrics['phone_purity']:.3f}\n"
                        comparison += f"- **V-Measure**: {metrics['v_measure']:.3f}\n"
                        comparison += f"- **NMI**: {metrics.get('nmi', 0):.3f}\n"

                        return "No pooling applied (baseline)", comparison

                    try:
                        # Create context config
                        config = ContextConfig(
                            enabled=True,
                            left_context=int(left),
                            right_context=int(right),
                            pooling_method=method,
                            min_samples=int(min_samp)
                        )

                        # Create pooler
                        pooler = ContextAwarePooler(config)

                        # Pool embeddings
                        # Note: This assumes sequential data. In practice, you'd need
                        # utterance boundaries from preprocessing
                        phone_sequence = data['phone_labels']  # Simplified

                        pooled_embeddings, context_info = pooler.create_context_clusters(
                            data['features'],
                            data['phone_labels'],
                            phone_sequence,
                            utterance_boundaries=None  # Would come from data
                        )

                        # Calculate metrics on pooled space
                        # Need to re-cluster or map clusters
                        from sklearn.cluster import KMeans
                        n_clusters = len(np.unique(data['cluster_labels']))
                        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
                        pooled_clusters = kmeans.fit_predict(pooled_embeddings)

                        metrics = calculate_all_metrics(
                            pooled_clusters,
                            context_info['labels']
                        )

                        # Create comparison
                        comparison = f"### Context Pooling Results\n\n"
                        comparison += f"**Configuration**: L{left}R{right} ({method})\n\n"
                        comparison += f"- **Original Points**: {context_info['n_original']}\n"
                        comparison += f"- **Pooled Points**: {context_info['n_pooled']}\n"
                        comparison += f"- **Reduction**: {(1 - context_info['reduction_ratio'])*100:.1f}%\n\n"
                        comparison += f"**Metrics**:\n"
                        comparison += f"- **Cluster Purity**: {metrics['cluster_purity']:.3f}\n"
                        comparison += f"- **Phone Purity**: {metrics['phone_purity']:.3f}\n"
                        comparison += f"- **V-Measure**: {metrics['v_measure']:.3f}\n"
                        comparison += f"- **NMI**: {metrics.get('nmi', 0):.3f}\n"

                        status = f"Pooled {context_info['n_original']} → {context_info['n_pooled']} points"

                        return status, comparison

                    except Exception as e:
                        return f"Error: {str(e)}", ""

                def analyze_phone_contexts(model_name, phone, left, right):
                    if not model_name or not phone:
                        return "*Enter phone to analyze*"

                    if model_name not in analyzer.models:
                        return "Model not available"

                    try:
                        data = analyzer.models[model_name]

                        # Create analyzer
                        ctx_analyzer = ContextAwareAnalyzer(
                            embeddings=data['features'],
                            phone_labels=data['phone_labels'],
                            phone_sequence=data['phone_labels'],
                            cluster_labels=data['cluster_labels']
                        )

                        # Analyze phone
                        analysis = ctx_analyzer.analyze_context_effects(phone, PHONEMES)

                        if 'error' in analysis:
                            return f"{analysis['error']}"

                        # Format output
                        output = f"### Analysis of /{phone}/\n\n"
                        output += f"- **Total occurrences**: {analysis['total_occurrences']}\n"
                        output += f"- **Unique contexts**: {analysis['unique_contexts']}\n\n"
                        output += f"**Most Common Contexts**:\n\n"

                        # Sort by count
                        contexts_sorted = sorted(
                            analysis['contexts'].items(),
                            key=lambda x: x[1]['count'],
                            reverse=True
                        )

                        for ctx_str, info in contexts_sorted[:10]:
                            output += f"- **{ctx_str}**: {info['count']} times"

                            if info['cluster_distribution']:
                                clusters = ", ".join(f"C{c}({cnt})"
                                                     for c, cnt in info['cluster_distribution'].items())
                                output += f" → {clusters}"

                            output += "\n"

                        if len(contexts_sorted) > 10:
                            output += f"\n*... and {len(contexts_sorted) - 10} more contexts*"

                        return output

                    except Exception as e:
                        return f"Error: {str(e)}"

                # Connect callbacks
                compute_pooling_btn.click(
                    fn=apply_context_pooling,
                    inputs=[context_model, enable_pooling, left_context, right_context,
                            pooling_method, min_samples],
                    outputs=[pooling_status, pooling_comparison]
                )

                analyze_phone_btn.click(
                    fn=analyze_phone_contexts,
                    inputs=[context_model, phone_to_analyze, left_context, right_context],
                    outputs=[context_analysis]
                )
            """
            with gr.Tab("Embedding Projector"):
                gr.Markdown("### TensorFlow Projector-Style 3D Visualization")
                gr.Markdown("Interactive exploration similar to TensorFlow's Embedding Projector")

                with gr.Row():
                    # Left sidebar
                    with gr.Column(scale=1):
                        gr.Markdown("#### Model & Projection")

                        projector_model = gr.Dropdown(
                            model_names,
                            value=model_names[0] if model_names else None,
                            label="Select Model"
                        )

                        projection_method = gr.Radio(
                            # choices=['PCA', 't-SNE', 'UMAP'],
                            choices=['PCA', 'UMAP'],
                            value='UMAP',
                            label="Projection Method"
                        )

                        dimension = gr.Radio(
                            choices=['3D', '2D'],
                            value='3D',
                            label="Dimensions"
                        )

                        projector_color_by = gr.Radio(
                            # choices=['cluster', 'phone', 'language'],
                            choices=['cluster', 'language'],
                            value='cluster',
                            label="Color by"
                        )

                        compute_btn = gr.Button("Compute Projections", variant="primary")
                        compute_status = gr.Markdown("*Click to compute projections*")

                        gr.Markdown("#### Search & Highlight")

                        search_mode = gr.Radio(
                            choices=['By Label', 'By Features'],
                            value='By Label',
                            label="Search Mode"
                        )

                        # Label search (simple)
                        with gr.Group(visible=True) as label_search_group:
                            search_label_type = gr.Radio(
                                choices=['phone', 'cluster', 'language'],
                                value='phone',
                                label="Search in"
                            )

                            search_term = gr.Textbox(
                                label="Search term",
                                placeholder="e.g., 'æ' or '5'"
                            )

                        # Feature search (advanced)
                        with gr.Group(visible=False) as feature_search_group:
                            search_manner = gr.Dropdown(
                                choices=['stop', 'fricative', 'nasal', 'approximant',
                                         'affricate', 'tap/flap'],
                                multiselect=True,
                                label="Manner"
                            )

                            search_place = gr.Dropdown(
                                choices=['bilabial', 'labiodental', 'dental', 'alveolar',
                                         'postalveolar', 'palatal', 'velar', 'uvular',
                                         'pharyngeal', 'glottal'],
                                multiselect=True,
                                label="Place"
                            )

                            search_voicing = gr.Dropdown(
                                choices=['voiced', 'voiceless'],
                                multiselect=True,
                                label="Voicing"
                            )

                            search_vowel_height = gr.Dropdown(
                                choices=['high', 'mid', 'low'],
                                multiselect=True,
                                label="Vowel Height"
                            )

                            search_vowel_backness = gr.Dropdown(
                                choices=['front', 'central', 'back'],
                                multiselect=True,
                                label="Vowel Backness"
                            )

                        search_btn = gr.Button("🔍 Search")

                        gr.Markdown("#### Nearest Neighbors")

                        point_idx = gr.Number(
                            label="Point index",
                            value=0,
                            precision=0
                        )

                        n_neighbors = gr.Slider(
                            1, 50, value=10,
                            step=1,
                            label="Number of neighbors"
                        )

                        show_nn_btn = gr.Button("Show Neighbors")

                        info_display = gr.Markdown("*Select a point or search*")

                    # Main visualization area
                    with gr.Column(scale=3):
                        projector_plot = gr.Plot(label="Embedding Space")

                        # with gr.Row():
                        #     comparison_btn = gr.Button("Show Comparison View (PCA | t-SNE | UMAP)")

                        # comparison_plot = gr.Plot(label="Comparison", visible=False)

                # Projector callbacks
                def compute_projections(model_name, method):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return "Model not available", None

                    viz = analyzer.projector_vizs[model_name]

                    try:
                        method_lower = method.lower()
                        viz.compute_projections(method_lower)

                        # Create initial plot
                        proj_key = f"{method_lower}_3d"
                        fig = viz.create_3d_scatter(
                            projection=proj_key,
                            color_by='cluster'
                        )

                        return f"{method} projections computed!", fig
                    except Exception as e:
                        return f"Error: {str(e)}", None

                def toggle_search_mode(mode):
                    """Toggle between label and feature search."""
                    if mode == 'By Label':
                        return gr.update(visible=True), gr.update(visible=False)
                    else:
                        return gr.update(visible=False), gr.update(visible=True)

                def update_projector_plot(model_name, method, dim, color_by_val, highlight_indices=None):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return None

                    viz = analyzer.projector_vizs[model_name]
                    proj_key = f"{method.lower()}_{dim.lower()}"

                    # Check if projection exists
                    if proj_key not in viz.projections:
                        return None

                    try:
                        if dim == '3D':
                            fig = viz.create_3d_scatter(
                                projection=proj_key,
                                color_by=color_by_val.lower(),
                                highlight_indices=highlight_indices
                            )
                        else:
                            fig = viz.create_2d_scatter(
                                projection=proj_key,
                                color_by=color_by_val.lower(),
                                highlight_indices=highlight_indices
                            )
                        return fig
                    except Exception as e:
                        print(f"Error creating plot: {e}")
                        return None

                def search_points(model_name, search_mode, search_type, term, method, dim,
                                  color_by_val, manner, place, voicing, vheight, vbackness):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return None, "Model not available"

                    viz = analyzer.projector_vizs[model_name]

                    if search_mode == 'By Label':
                        if not term:
                            fig = update_projector_plot(model_name, method, dim, color_by_val)
                            return fig, "No search term provided"

                        matches = viz.search_by_label(term, search_type.lower())
                        info = f"Found {len(matches)} matches for '{term}' in {search_type}"

                    else:  # By Features
                        matches = viz.search_by_articulatory_features(
                            PHONEMES,
                            manner=manner if manner else None,
                            place=place if place else None,
                            voicing=voicing if voicing else None,
                            vowel_height=vheight if vheight else None,
                            vowel_backness=vbackness if vbackness else None
                        )

                        # Get summary
                        summary = viz.get_articulatory_summary(matches, PHONEMES)

                        info = f"Found {len(matches)} points matching features:\n\n"

                        if manner:
                            info += f"**Manner**: {', '.join(manner)}\n"
                        if place:
                            info += f"**Place**: {', '.join(place)}\n"
                        if voicing:
                            info += f"**Voicing**: {', '.join(voicing)}\n"
                        if vheight:
                            info += f"**Vowel Height**: {', '.join(vheight)}\n"
                        if vbackness:
                            info += f"**Vowel Backness**: {', '.join(vbackness)}\n"

                        if summary and len(matches) > 0:
                            info += f"\n**Distribution**:\n"
                            if summary.get('manner'):
                                info += "- Manner: " + ", ".join(
                                    f"{k}({v})" for k, v in sorted(summary['manner'].items())
                                ) + "\n"
                            if summary.get('place'):
                                info += "- Place: " + ", ".join(
                                    f"{k}({v})" for k, v in sorted(summary['place'].items())
                                ) + "\n"

                    fig = update_projector_plot(model_name, method, dim, color_by_val,
                                                highlight_indices=matches)

                    if matches:
                        if len(matches) <= 10:
                            info += f"\n\nIndices: {matches}"
                        else:
                            info += f"\n\nSample indices: {matches[:10]}... (+{len(matches)-10} more)"

                    return fig, info

                def show_neighbors(model_name, idx, n, method, dim, color_by_val):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return None, "Model not available"

                    viz = analyzer.projector_vizs[model_name]

                    if viz.nn_model is None:
                        viz.build_nn_index()

                    neighbors, distances = viz.find_nearest_neighbors(int(idx), int(n))

                    # Show with lines to neighbors
                    line_pairs = [(int(idx), int(nn)) for nn in neighbors]

                    proj_key = f"{method.lower()}_{dim.lower()}"

                    if proj_key not in viz.projections:
                        return None, "Projections not computed"

                    if dim == '3D':
                        fig = viz.create_3d_scatter(
                            projection=proj_key,
                            color_by=color_by_val.lower(),
                            highlight_indices=[int(idx)] + list(neighbors),
                            show_lines=True,
                            line_pairs=line_pairs
                        )
                    else:
                        fig = viz.create_2d_scatter(
                            projection=proj_key,
                            color_by=color_by_val.lower(),
                            highlight_indices=[int(idx)] + list(neighbors)
                        )

                    info = f"Point {idx} - Nearest {n} neighbors:\n\n"
                    for i, (nn_idx, dist) in enumerate(zip(neighbors, distances), 1):
                        info += f"{i}. Index {nn_idx} (distance: {dist:.3f})\n"

                    return fig, info

                def show_comparison_view(model_name, color_by_val):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return gr.update(visible=False), None

                    viz = analyzer.projector_vizs[model_name]

                    # Ensure all projections exist
                    for method in ['pca', 'tsne', 'umap']:
                        if f'{method}_3d' not in viz.projections:
                            return gr.update(visible=False), None

                    fig = viz.create_comparison_view(color_by=color_by_val.lower())
                    return gr.update(visible=True), fig

                # Connect callbacks
                compute_btn.click(
                    fn=compute_projections,
                    inputs=[projector_model, projection_method],
                    outputs=[compute_status, projector_plot]
                )

                search_mode.change(
                    fn=toggle_search_mode,
                    inputs=[search_mode],
                    outputs=[label_search_group, feature_search_group]
                )

                for component in [projection_method, dimension, projector_color_by]:
                    component.change(
                        fn=lambda m, meth, d, c: update_projector_plot(m, meth, d, c),
                        inputs=[projector_model, projection_method, dimension, projector_color_by],
                        outputs=[projector_plot]
                    )

                search_btn.click(
                    fn=search_points,
                    inputs=[projector_model, search_mode, search_label_type, search_term,
                            projection_method, dimension, projector_color_by,
                            search_manner, search_place, search_voicing,
                            search_vowel_height, search_vowel_backness],
                    outputs=[projector_plot, info_display]
                )

                show_nn_btn.click(
                    fn=show_neighbors,
                    inputs=[projector_model, point_idx, n_neighbors,
                            projection_method, dimension, projector_color_by],
                    outputs=[projector_plot, info_display]
                )

                # comparison_btn.click(
                #     fn=lambda m, c: show_comparison_view(m, c),
                #     inputs=[projector_model, projector_color_by],
                #     outputs=[comparison_plot, comparison_plot]
                # )

    return demo

if __name__ == "__main__":
    # Create analyzer
    analyzer = MultiModelAnalyzer(OUTPUT_DIR)

    # Create and launch interface
    demo = create_integrated_gradio_interface(analyzer)
    demo.launch(
        # server_port=args.port,
        # share=True  # Creates public link
    )
    # demo = create_interface()
    # demo.launch()
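
# To run locally (a sketch; assumes the standard Hugging Face Space entrypoint,
# which this commit does not state), export the GitHub token first:
#
#   export GH_TOKEN=<token with read access to tolulope/speech-model-analysis>
#   python app.py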