# Hugging Face Space app — discrete-token phoneme analysis demo.
| import gradio as gr | |
| import os | |
| import sys | |
| import subprocess | |
| import importlib | |
| from pathlib import Path | |
| import json | |
| import pandas as pd | |
| def install_private_package(): | |
| """Install private package from GitHub using token""" | |
| print("Installing private package...") | |
| gh_token = os.environ.get("GH_TOKEN") | |
| if not gh_token: | |
| raise ValueError("GH_TOKEN not found in environment variables") | |
| package_url = f"git+https://{gh_token}@github.com/tolulope/speech-model-analysis.git" | |
| # Use subprocess for better error handling | |
| result = subprocess.run( | |
| [sys.executable, "-m", "pip", "install", "--no-cache-dir", package_url], | |
| capture_output=True, | |
| text=True | |
| ) | |
| if result.returncode != 0: | |
| print("STDOUT:", result.stdout) | |
| print("STDERR:", result.stderr) | |
| raise RuntimeError(f"Failed to install package: {result.stderr}") | |
| print("✓ Package installed successfully!") | |
| # Clear import caches so Python recognizes the new package | |
| importlib.invalidate_caches() | |
# Install the private package before the imports below; this must run at
# module import time so the package is importable on the Space.
install_private_package()
| # # Install private package at startup | |
| # print("Installing private package...") | |
| # gh_token = os.environ.get("GH_TOKEN") | |
| # if not gh_token: | |
| # raise ValueError("GH_TOKEN not found in environment variables") | |
| # package_url = f"git+https://{gh_token}@github.com/tolulope/speech-model-analysis.git" | |
| # os.system(f"{sys.executable} -m pip install {package_url}") | |
| # Now import from your private package | |
| from speech_model_analysis import ( | |
| VoxCommunisPreprocessor, | |
| MultiModelAnalyzer, | |
| create_hubert_configs, | |
| ) | |
| from speech_model_analysis.phoneme_manager import PHONEMES, index_to_phoneme | |
| from speech_model_analysis.voxcommunis_preprocessing import VoxCommunisPreprocessor, create_hubert_configs | |
| from speech_model_analysis.gradio_viz import ClusterVisualizer | |
| from speech_model_analysis.enhanced_analysis import calculate_all_metrics | |
| from speech_model_analysis.audio_player import ClusterAudioExplorer, create_audio_grid | |
| from speech_model_analysis.embedding_projector_viz import EmbeddingProjectorViz | |
| from speech_model_analysis.context_pooling import ContextConfig, ContextAwarePooler, ContextAwareAnalyzer | |
print("Private package loaded successfully!")
from huggingface_hub import hf_hub_download, snapshot_download, login
# Authenticate with the Hugging Face Hub. NOTE(review): raises KeyError if
# HF_TOKEN is unset — presumably always configured as a Space secret.
login(os.environ["HF_TOKEN"])
# Download the full dataset repo snapshot; returns the local cache path
# that the analyzers below read from.
OUTPUT_DIR = snapshot_download("tolulope/speech-model-analysis", repo_type="dataset")
def get_top_level_dirs(root):
    """Return the immediate subdirectories of *root* as a list of Paths."""
    subdirs = []
    for entry in Path(root).iterdir():
        if entry.is_dir():
            subdirs.append(entry)
    return subdirs
def load_analyzer_for_subdir(subdir_path):
    """Construct a MultiModelAnalyzer rooted at one model-output subdirectory."""
    path_text = str(subdir_path)
    return MultiModelAnalyzer(path_text)
def toggle_tsne_params(method):
    """Show or hide the three t-SNE sliders based on the projection method."""
    show = (method == "t-SNE")
    # One gr.update per slider: perplexity, learning rate, iterations.
    return [gr.update(visible=show) for _ in range(3)]
| def create_integrated_gradio_interface(analyzer: MultiModelAnalyzer): | |
| """ | |
| Create comprehensive Gradio interface with model comparison. | |
| Args: | |
| analyzer: MultiModelAnalyzer instance | |
| """ | |
| # Extract feature options (same as before) | |
| all_manners = sorted(set(p.manner.name for p in PHONEMES.values() | |
| if p.manner)) | |
| all_places = sorted(set(p.place.name for p in PHONEMES.values() | |
| if p.place)) | |
| all_voicings = ['voiced', 'voiceless'] | |
| all_heights = ['high', 'mid', 'low'] | |
| all_backness = ['front', 'central', 'back'] | |
| model_names = analyzer.get_model_names() | |
| with gr.Blocks(title="Discrete Token Analysis") as demo: | |
| gr.Markdown("# Discrete Token Phoneme Analysis") | |
| # gr.Markdown("Compare HuBERT models and analyze discrete representations") | |
| with gr.Tabs(): | |
| # Tab 1: Model Comparison | |
| with gr.Tab("Model Comparison"): | |
| gr.Markdown("### Compare Clustering Quality Across Models") | |
| with gr.Row(): | |
| # comparison_plot = gr.Plot(label="Metrics Comparison") | |
| metrics_table = gr.Dataframe(label="Detailed Metrics") | |
| refresh_comparison_btn = gr.Button("Refresh Comparison", variant="primary") | |
                def update_comparison():
                    """Build the per-model metrics table for the comparison tab."""
                    # fig = analyzer.create_comparison_plot()
                    df = analyzer.compare_metrics()
                    df = df.round(2)  # 2-dp for display
                    return df
| # refresh_comparison_btn.click( | |
| # fn=update_comparison, | |
| # outputs=[comparison_plot, metrics_table] | |
| # ) | |
| # Initialize | |
| demo.load( | |
| fn=update_comparison, | |
| # outputs=[comparison_plot, metrics_table] | |
| outputs=[metrics_table] | |
| ) | |
| # Tab 2: Single Model Analysis | |
| """ | |
| with gr.Tab("Single Model Analysis"): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Model & Filters") | |
| model_selector = gr.Dropdown( | |
| model_names, | |
| value=model_names[0] if model_names else None, | |
| label="Select Model" | |
| ) | |
| color_by = gr.Radio( | |
| ['cluster', 'phone'], | |
| value='cluster', | |
| label="Color by" | |
| ) | |
| gr.Markdown("#### Articulatory Filters") | |
| manner_filter = gr.Dropdown( | |
| all_manners, | |
| multiselect=True, | |
| label="Manner" | |
| ) | |
| place_filter = gr.Dropdown( | |
| all_places, | |
| multiselect=True, | |
| label="Place" | |
| ) | |
| voicing_filter = gr.Dropdown( | |
| all_voicings, | |
| multiselect=True, | |
| label="Voicing" | |
| ) | |
| vowel_height_filter = gr.Dropdown( | |
| all_heights, | |
| multiselect=True, | |
| label="Vowel Height" | |
| ) | |
| vowel_backness_filter = gr.Dropdown( | |
| all_backness, | |
| multiselect=True, | |
| label="Vowel Backness" | |
| ) | |
| update_btn = gr.Button("Update Visualization", variant="primary") | |
| with gr.Column(scale=2): | |
| plot_output = gr.Plot(label="Cluster Visualization") | |
| gr.Markdown("💡 **Tip**: Click on points to hear audio in the Audio Explorer tab!") | |
| with gr.Row(): | |
| with gr.Column(): | |
| metrics_output = gr.Markdown() | |
| with gr.Column(): | |
| confusion_output = gr.Plot(label="Confusion Matrix") | |
| def update_single_model(model_name, color, manner, place, voicing, height, backness): | |
| if not model_name: | |
| return None, "Select a model", None | |
| visualizer = analyzer.visualizers[model_name] | |
| # Create scatter plot | |
| fig = visualizer.create_scatter_plot( | |
| color_by=color, | |
| filter_manner=manner if manner else None, | |
| filter_place=place if place else None, | |
| filter_voicing=voicing if voicing else None, | |
| filter_vowel_height=height if height else None, | |
| filter_vowel_backness=backness if backness else None | |
| ) | |
| # Calculate metrics | |
| metrics = visualizer.calculate_metrics( | |
| filter_manner=manner if manner else None, | |
| filter_place=place if place else None, | |
| filter_voicing=voicing if voicing else None, | |
| filter_vowel_height=height if height else None, | |
| filter_vowel_backness=backness if backness else None | |
| ) | |
| # Create confusion matrix | |
| confusion_fig = analyzer.create_confusion_heatmap(model_name) | |
| return fig, metrics, confusion_fig | |
| update_btn.click( | |
| fn=update_single_model, | |
| inputs=[model_selector, color_by, manner_filter, place_filter, | |
| voicing_filter, vowel_height_filter, vowel_backness_filter], | |
| outputs=[plot_output, metrics_output, confusion_output] | |
| ) | |
| """ | |
| # Tab 3: Audio Explorer | |
| """ | |
| with gr.Tab("Audio Explorer"): | |
| gr.Markdown("### Listen to Cluster Samples") | |
| gr.Markdown("Explore audio segments from clusters and phonemes") | |
| with gr.Row(): | |
| with gr.Column(): | |
| audio_model_selector = gr.Dropdown( | |
| model_names, | |
| value=model_names[0] if model_names else None, | |
| label="Select Model" | |
| ) | |
| exploration_mode = gr.Radio( | |
| ['By Cluster', 'By Phoneme', 'Compare Phoneme Across Clusters'], | |
| value='By Cluster', | |
| label="Exploration Mode" | |
| ) | |
| # Cluster mode inputs | |
| with gr.Group(visible=True) as cluster_inputs: | |
| cluster_id_audio = gr.Number( | |
| label="Cluster ID", | |
| value=0, | |
| precision=0 | |
| ) | |
| n_cluster_samples = gr.Slider( | |
| 1, 10, value=5, | |
| step=1, | |
| label="Number of samples" | |
| ) | |
| # Phoneme mode inputs | |
| with gr.Group(visible=False) as phoneme_inputs: | |
| phoneme_select = gr.Dropdown( | |
| sorted(list(PHONEMES.keys())), | |
| label="Select Phoneme", | |
| value="æ" | |
| ) | |
| n_phoneme_samples = gr.Slider( | |
| 1, 10, value=5, | |
| step=1, | |
| label="Number of samples" | |
| ) | |
| # Compare mode inputs | |
| with gr.Group(visible=False) as compare_inputs: | |
| phoneme_compare = gr.Dropdown( | |
| sorted(list(PHONEMES.keys())), | |
| label="Phoneme to Compare", | |
| value="æ" | |
| ) | |
| n_per_cluster = gr.Slider( | |
| 1, 5, value=3, | |
| step=1, | |
| label="Samples per cluster" | |
| ) | |
| play_audio_btn = gr.Button("🎵 Load Audio Samples", variant="primary") | |
| with gr.Column(scale=2): | |
| audio_output = gr.HTML(label="Audio Player") | |
| audio_info = gr.Markdown() | |
| # Toggle visibility based on mode | |
| def update_visibility(mode): | |
| return ( | |
| gr.update(visible=(mode == 'By Cluster')), | |
| gr.update(visible=(mode == 'By Phoneme')), | |
| gr.update(visible=(mode == 'Compare Phoneme Across Clusters')) | |
| ) | |
| exploration_mode.change( | |
| fn=update_visibility, | |
| inputs=[exploration_mode], | |
| outputs=[cluster_inputs, phoneme_inputs, compare_inputs] | |
| ) | |
| def load_audio_samples(model_name, mode, cluster_id, n_cluster, | |
| phoneme, n_phoneme, phoneme_cmp, n_per_clust): | |
| if not model_name or model_name not in analyzer.audio_explorers: | |
| return "<p>Audio not available for this model</p>", "No audio data loaded" | |
| explorer = analyzer.audio_explorers[model_name] | |
| try: | |
| if mode == 'By Cluster': | |
| samples = explorer.get_cluster_samples( | |
| cluster_id=int(cluster_id), | |
| n_samples=int(n_cluster) | |
| ) | |
| info = f"### Cluster {cluster_id}\n\nShowing {len(samples)} samples" | |
| elif mode == 'By Phoneme': | |
| samples = explorer.get_phoneme_samples( | |
| phoneme=phoneme, | |
| n_samples=int(n_phoneme) | |
| ) | |
| info = f"### Phoneme: {phoneme}\n\nShowing {len(samples)} samples" | |
| else: # Compare mode | |
| cluster_samples = explorer.compare_phoneme_in_clusters( | |
| phoneme=phoneme_cmp, | |
| n_per_cluster=int(n_per_clust) | |
| ) | |
| # Flatten samples and add cluster headers | |
| html = "" | |
| info_lines = [f"### Phoneme: {phoneme_cmp} across clusters\n"] | |
| for cluster_id, samps in sorted(cluster_samples.items()): | |
| html += f'<h4>Cluster {cluster_id}</h4>' | |
| html += create_audio_grid(samps, columns=3) | |
| info_lines.append(f"- Cluster {cluster_id}: {len(samps)} samples") | |
| return html, "\n".join(info_lines) | |
| if not samples: | |
| return "<p>No samples found</p>", "No matching samples" | |
| html = create_audio_grid(samples, columns=3) | |
| return html, info | |
| except Exception as e: | |
| return f"<p>Error loading audio: {str(e)}</p>", f"Error: {str(e)}" | |
| play_audio_btn.click( | |
| fn=load_audio_samples, | |
| inputs=[audio_model_selector, exploration_mode, | |
| cluster_id_audio, n_cluster_samples, | |
| phoneme_select, n_phoneme_samples, | |
| phoneme_compare, n_per_cluster], | |
| outputs=[audio_output, audio_info] | |
| ) | |
| """ | |
| # Tab 4: Export & Analysis | |
| """ | |
| with gr.Tab("Export & Analysis"): | |
| gr.Markdown("### Export Results") | |
| with gr.Row(): | |
| export_model = gr.Dropdown( | |
| model_names, | |
| label="Select Model to Export" | |
| ) | |
| export_format = gr.Radio( | |
| ['CSV', 'JSON', 'NPZ'], | |
| value='CSV', | |
| label="Format" | |
| ) | |
| export_btn = gr.Button("Export Data", variant="primary") | |
| export_output = gr.File(label="Download") | |
| def export_data(model_name, format_type): | |
| if not model_name: | |
| return None | |
| data = analyzer.models[model_name] | |
| output_path = f"{model_name}_export.{format_type.lower()}" | |
| if format_type == 'CSV': | |
| df = pd.DataFrame({ | |
| 'cluster': data['cluster_labels'], | |
| 'phoneme': data['phoneme_strings'], | |
| 'phone_idx': data['phone_labels'] | |
| }) | |
| df.to_csv(output_path, index=False) | |
| elif format_type == 'JSON': | |
| export_dict = { | |
| 'clusters': data['cluster_labels'].tolist(), | |
| 'phonemes': data['phoneme_strings'].tolist(), | |
| 'phone_indices': data['phone_labels'].tolist() | |
| } | |
| with open(output_path, 'w') as f: | |
| json.dump(export_dict, f, indent=2) | |
| else: # NPZ | |
| np.savez( | |
| output_path, | |
| features=data['features'], | |
| clusters=data['cluster_labels'], | |
| phones=data['phone_labels'] | |
| ) | |
| return output_path | |
| export_btn.click( | |
| fn=export_data, | |
| inputs=[export_model, export_format], | |
| outputs=[export_output] | |
| ) | |
| """ | |
| # Tab 6: Context Pooling Analysis | |
| """ | |
| with gr.Tab("Context Pooling"): | |
| gr.Markdown("### Coarticulation Analysis") | |
| gr.Markdown("Pool phoneme embeddings by context to account for coarticulation effects") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### Pooling Configuration") | |
| context_model = gr.Dropdown( | |
| model_names, | |
| value=model_names[0] if model_names else None, | |
| label="Select Model" | |
| ) | |
| enable_pooling = gr.Checkbox( | |
| label="Enable Context Pooling", | |
| value=False | |
| ) | |
| left_context = gr.Slider( | |
| 0, 3, value=1, step=1, | |
| label="Left Context (# phones)", | |
| info="How many phones before target" | |
| ) | |
| right_context = gr.Slider( | |
| 0, 3, value=1, step=1, | |
| label="Right Context (# phones)", | |
| info="How many phones after target" | |
| ) | |
| pooling_method = gr.Radio( | |
| choices=['mean', 'median', 'max'], | |
| value='mean', | |
| label="Pooling Method" | |
| ) | |
| min_samples = gr.Slider( | |
| 1, 10, value=2, step=1, | |
| label="Min Samples per Context", | |
| info="Minimum instances to pool" | |
| ) | |
| compute_pooling_btn = gr.Button("Apply Pooling", variant="primary") | |
| pooling_status = gr.Markdown("") | |
| gr.Markdown("#### Analyze Specific Phone") | |
| phone_to_analyze = gr.Textbox( | |
| label="Phoneme", | |
| placeholder="æ", | |
| value="æ" | |
| ) | |
| analyze_phone_btn = gr.Button("Analyze Contexts") | |
| with gr.Column(scale=2): | |
| pooling_comparison = gr.Markdown("*Apply pooling to see comparison*") | |
| context_analysis = gr.Markdown("*Analyze a phone to see contexts*") | |
| # with gr.Row(): | |
| # pooled_plot = gr.Plot(label="Pooled Embeddings (UMAP)") | |
| # Context pooling callbacks | |
| def apply_context_pooling(model_name, enable, left, right, method, min_samp): | |
| if not model_name or model_name not in analyzer.models: | |
| return "Model not available", "" | |
| data = analyzer.models[model_name] | |
| if not enable: | |
| # No pooling | |
| metrics = calculate_all_metrics( | |
| data['cluster_labels'], | |
| data['phone_labels'] | |
| ) | |
| comparison = "### No Pooling (Baseline)\n\n" | |
| comparison += f"- **Points**: {len(data['features'])}\n" | |
| comparison += f"- **Cluster Purity**: {metrics['cluster_purity']:.3f}\n" | |
| comparison += f"- **Phone Purity**: {metrics['phone_purity']:.3f}\n" | |
| comparison += f"- **V-Measure**: {metrics['v_measure']:.3f}\n" | |
| comparison += f"- **NMI**: {metrics.get('nmi', 0):.3f}\n" | |
| return "No pooling applied (baseline)", comparison | |
| try: | |
| # Create context config | |
| config = ContextConfig( | |
| enabled=True, | |
| left_context=int(left), | |
| right_context=int(right), | |
| pooling_method=method, | |
| min_samples=int(min_samp) | |
| ) | |
| # Create pooler | |
| pooler = ContextAwarePooler(config) | |
| # Pool embeddings | |
| # Note: This assumes sequential data. In practice, you'd need | |
| # utterance boundaries from preprocessing | |
| phone_sequence = data['phone_labels'] # Simplified | |
| pooled_embeddings, context_info = pooler.create_context_clusters( | |
| data['features'], | |
| data['phone_labels'], | |
| phone_sequence, | |
| utterance_boundaries=None # Would come from data | |
| ) | |
| # Calculate metrics on pooled space | |
| # Need to re-cluster or map clusters | |
| from sklearn.cluster import KMeans | |
| n_clusters = len(np.unique(data['cluster_labels'])) | |
| kmeans = KMeans(n_clusters=n_clusters, random_state=42) | |
| pooled_clusters = kmeans.fit_predict(pooled_embeddings) | |
| metrics = calculate_all_metrics( | |
| pooled_clusters, | |
| context_info['labels'] | |
| ) | |
| # Create comparison | |
| comparison = f"### Context Pooling Results\n\n" | |
| comparison += f"**Configuration**: L{left}R{right} ({method})\n\n" | |
| comparison += f"- **Original Points**: {context_info['n_original']}\n" | |
| comparison += f"- **Pooled Points**: {context_info['n_pooled']}\n" | |
| comparison += f"- **Reduction**: {(1 - context_info['reduction_ratio'])*100:.1f}%\n\n" | |
| comparison += f"**Metrics**:\n" | |
| comparison += f"- **Cluster Purity**: {metrics['cluster_purity']:.3f}\n" | |
| comparison += f"- **Phone Purity**: {metrics['phone_purity']:.3f}\n" | |
| comparison += f"- **V-Measure**: {metrics['v_measure']:.3f}\n" | |
| comparison += f"- **NMI**: {metrics.get('nmi', 0):.3f}\n" | |
| status = f"Pooled {context_info['n_original']} → {context_info['n_pooled']} points" | |
| return status, comparison | |
| except Exception as e: | |
| return f"Error: {str(e)}", "" | |
| def analyze_phone_contexts(model_name, phone, left, right): | |
| if not model_name or not phone: | |
| return "*Enter phone to analyze*" | |
| if model_name not in analyzer.models: | |
| return "Model not available" | |
| try: | |
| data = analyzer.models[model_name] | |
| # Create analyzer | |
| ctx_analyzer = ContextAwareAnalyzer( | |
| embeddings=data['features'], | |
| phone_labels=data['phone_labels'], | |
| phone_sequence=data['phone_labels'], | |
| cluster_labels=data['cluster_labels'] | |
| ) | |
| # Analyze phone | |
| analysis = ctx_analyzer.analyze_context_effects(phone, PHONEMES) | |
| if 'error' in analysis: | |
| return f"{analysis['error']}" | |
| # Format output | |
| output = f"### Analysis of /{phone}/\n\n" | |
| output += f"- **Total occurrences**: {analysis['total_occurrences']}\n" | |
| output += f"- **Unique contexts**: {analysis['unique_contexts']}\n\n" | |
| output += f"**Most Common Contexts**:\n\n" | |
| # Sort by count | |
| contexts_sorted = sorted( | |
| analysis['contexts'].items(), | |
| key=lambda x: x[1]['count'], | |
| reverse=True | |
| ) | |
| for ctx_str, info in contexts_sorted[:10]: | |
| output += f"- **{ctx_str}**: {info['count']} times" | |
| if info['cluster_distribution']: | |
| clusters = ", ".join(f"C{c}({cnt})" | |
| for c, cnt in info['cluster_distribution'].items()) | |
| output += f" → {clusters}" | |
| output += "\n" | |
| if len(contexts_sorted) > 10: | |
| output += f"\n*... and {len(contexts_sorted) - 10} more contexts*" | |
| return output | |
| except Exception as e: | |
| return f"Error: {str(e)}" | |
| # Connect callbacks | |
| compute_pooling_btn.click( | |
| fn=apply_context_pooling, | |
| inputs=[context_model, enable_pooling, left_context, right_context, | |
| pooling_method, min_samples], | |
| outputs=[pooling_status, pooling_comparison] | |
| ) | |
| analyze_phone_btn.click( | |
| fn=analyze_phone_contexts, | |
| inputs=[context_model, phone_to_analyze, left_context, right_context], | |
| outputs=[context_analysis] | |
| ) | |
| """ | |
| # def get_choices(model_name, label_type): | |
| # viz = analyzer.projector_vizs[model_name] | |
| # df = pd.DataFrame(viz.labels) | |
| # choices = [str(x) for x in df[label_type].unique()] | |
| # print(choices) | |
| # value = choices[0] if choices else None | |
| # return choices, value | |
| def get_choices(model_name, label_type): | |
| viz = analyzer.projector_vizs[model_name] | |
| df = pd.DataFrame(viz.labels) | |
| if label_type == "phone": | |
| choices = df["phone"].unique() | |
| elif label_type == "cluster": | |
| choices = df["cluster"].unique() | |
| else: | |
| choices = df["language"].unique() | |
| return gr.update( | |
| choices=[str(x) for x in choices], # MUST be a Python list of strings | |
| value=str(choices[0]) # MUST be one of the choices | |
| ) | |
| with gr.Tab("Embedding Projector"): | |
| gr.Markdown("### TensorFlow Projector-Style 3D Visualization") | |
| gr.Markdown("Interactive exploration similar to TensorFlow's Embedding Projector") | |
| with gr.Row(): | |
| # Left sidebar | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### Model & Projection") | |
| projector_model = gr.Dropdown( | |
| model_names, | |
| value=model_names[0] if model_names else None, | |
| label="Select Model" | |
| ) | |
| projection_method = gr.Radio( | |
| choices=['PCA', 't-SNE', 'UMAP'], | |
| # choices=['PCA', 'UMAP'], | |
| value='UMAP', | |
| label="Projection Method" | |
| ) | |
| tsne_perplexity = gr.Slider(5, 50, value=30, step=1, label="t-SNE Perplexity", visible=False) | |
| tsne_lr = gr.Slider(10, 1000, value=200, step=10, label="t-SNE Learning Rate", visible=False) | |
| tsne_iters = gr.Slider(250, 5000, value=1000, step=250, label="t-SNE Iterations", visible=False) | |
| projection_method.change( | |
| fn=toggle_tsne_params, | |
| inputs=[projection_method], | |
| outputs=[tsne_perplexity, tsne_lr, tsne_iters] | |
| ) | |
| dimension = gr.Radio( | |
| choices=['3D', '2D'], | |
| value='3D', | |
| label="Dimensions" | |
| ) | |
| projector_color_by = gr.Radio( | |
| # choices=['cluster', 'phone', 'language'], | |
| choices=['cluster', 'language'], | |
| value='cluster', | |
| label="Color by" | |
| ) | |
| compute_btn = gr.Button("Compute Projections", variant="primary") | |
| compute_status = gr.Markdown("*Click to compute projections*") | |
| gr.Markdown("#### Search & Highlight") | |
| search_mode = gr.Radio( | |
| choices=['By Label', 'By Features'], | |
| value='By Label', | |
| label="Search Mode" | |
| ) | |
| phones = ["æ", "ɑ", "ə", "i", "u"] | |
| clusters = [0, 1, 2, 3] | |
| languages = ["hi", "pa"] | |
| # Label search (simple) | |
| with gr.Group(visible=True) as label_search_group: | |
| # search_label_type = gr.Radio( | |
| # choices=['phone', 'cluster', 'language'], | |
| # value='phone', | |
| # label="Search in" | |
| # ) | |
| # search_term = gr.Textbox( | |
| # label="Search term", | |
| # placeholder="e.g., 'æ' or '5'" | |
| # ) | |
| # search_term = gr.Dropdown( | |
| # choices=list(phones), # initial choices | |
| # value=phones[0], # initial value | |
| # label="Search term", | |
| # allow_custom_value=True | |
| # ) | |
| # # Update dropdown choices when the label type changes | |
| # # Update search_term whenever the label type changes | |
| # search_label_type.change( | |
| # fn=get_choices, | |
| # inputs=[projector_model, search_label_type], | |
| # outputs=[search_term, search_term] # first = choices, second = value | |
| # ) | |
| search_label_type = gr.Radio( | |
| choices=["phone", "cluster", "language"], | |
| value="phone", | |
| label="Search in" | |
| ) | |
| search_term = gr.Dropdown( | |
| choices=[str(x) for x in phones], | |
| value=str(phones[0]), | |
| label="Search term" | |
| ) | |
| search_label_type.change( | |
| fn=get_choices, | |
| inputs=[projector_model, search_label_type], | |
| outputs=search_term | |
| ) | |
| # Feature search (advanced) | |
| with gr.Group(visible=False) as feature_search_group: | |
| search_manner = gr.Dropdown( | |
| choices=['stop', 'fricative', 'nasal', 'approximant', | |
| 'affricate', 'tap/flap'], | |
| multiselect=True, | |
| label="Manner" | |
| ) | |
| search_place = gr.Dropdown( | |
| choices=['bilabial', 'labiodental', 'dental', 'alveolar', | |
| 'postalveolar', 'palatal', 'velar', 'uvular', | |
| 'pharyngeal', 'glottal'], | |
| multiselect=True, | |
| label="Place" | |
| ) | |
| search_voicing = gr.Dropdown( | |
| choices=['voiced', 'voiceless'], | |
| multiselect=True, | |
| label="Voicing" | |
| ) | |
| search_vowel_height = gr.Dropdown( | |
| choices=['high', 'mid', 'low'], | |
| multiselect=True, | |
| label="Vowel Height" | |
| ) | |
| search_vowel_backness = gr.Dropdown( | |
| choices=['front', 'central', 'back'], | |
| multiselect=True, | |
| label="Vowel Backness" | |
| ) | |
| search_btn = gr.Button("🔍 Search") | |
| # gr.Markdown("#### Nearest Neighbors") | |
| # point_idx = gr.Number( | |
| # label="Point index", | |
| # value=0, | |
| # precision=0 | |
| # ) | |
| # n_neighbors = gr.Slider( | |
| # 1, 50, value=10, | |
| # step=1, | |
| # label="Number of neighbors" | |
| # ) | |
| # show_nn_btn = gr.Button("Show Neighbors") | |
| info_display = gr.Markdown("*Select a point or search*") | |
| # Main visualization area | |
| with gr.Column(scale=3): | |
| projector_plot = gr.Plot(label="Embedding Space") | |
| # with gr.Row(): | |
| # comparison_btn = gr.Button("Show Comparison View (PCA | t-SNE | UMAP)") | |
| # comparison_plot = gr.Plot(label="Comparison", visible=False) | |
| # Projector callbacks | |
| def compute_projections(model_name, method, tsne_perplexity, tsne_lr, tsne_iters): | |
| if not model_name or model_name not in analyzer.projector_vizs: | |
| return "Model not available", None | |
| viz = analyzer.projector_vizs[model_name] | |
| try: | |
| method_lower = method.lower() | |
| viz.compute_projections(method_lower, tsne_perplexity, tsne_lr, tsne_iters) | |
| # Create initial plot | |
| proj_key = f"{method_lower}_3d" | |
| fig = viz.create_3d_scatter( | |
| projection=proj_key, | |
| color_by='cluster' | |
| ) | |
| return f"{method} projections computed!", fig | |
| except Exception as e: | |
| return f"Error: {str(e)}", None | |
| def toggle_search_mode(mode): | |
| """Toggle between label and feature search.""" | |
| if mode == 'By Label': | |
| return gr.update(visible=True), gr.update(visible=False) | |
| else: | |
| return gr.update(visible=False), gr.update(visible=True) | |
| def update_projector_plot(model_name, method, dim, color_by_val, highlight_indices=None): | |
| if not model_name or model_name not in analyzer.projector_vizs: | |
| return None | |
| viz = analyzer.projector_vizs[model_name] | |
| proj_key = f"{method.lower()}_{dim.lower()}" | |
| # Check if projection exists | |
| if proj_key not in viz.projections: | |
| return None | |
| try: | |
| if dim == '3D': | |
| fig = viz.create_3d_scatter( | |
| projection=proj_key, | |
| color_by=color_by_val.lower(), | |
| highlight_indices=highlight_indices | |
| ) | |
| else: | |
| fig = viz.create_2d_scatter( | |
| projection=proj_key, | |
| color_by=color_by_val.lower(), | |
| highlight_indices=highlight_indices | |
| ) | |
| return fig | |
| except Exception as e: | |
| print(f"Error creating plot: {e}") | |
| return None | |
| def search_points(model_name, search_mode, search_type, term, method, dim, | |
| color_by_val, manner, place, voicing, vheight, vbackness): | |
| if not model_name or model_name not in analyzer.projector_vizs: | |
| return None, "Model not available" | |
| viz = analyzer.projector_vizs[model_name] | |
| if search_mode == 'By Label': | |
| if not term: | |
| fig = update_projector_plot(model_name, method, dim, color_by_val) | |
| return fig, "No search term provided" | |
| matches = viz.search_by_label(term, search_type.lower()) | |
| info = f"Found {len(matches)} matches for '{term}' in {search_type}" | |
| else: # By Features | |
| matches = viz.search_by_articulatory_features( | |
| PHONEMES, | |
| manner=manner if manner else None, | |
| place=place if place else None, | |
| voicing=voicing if voicing else None, | |
| vowel_height=vheight if vheight else None, | |
| vowel_backness=vbackness if vbackness else None | |
| ) | |
| # Get summary | |
| summary = viz.get_articulatory_summary(matches, PHONEMES) | |
| info = f"Found {len(matches)} points matching features:\n\n" | |
| if manner: | |
| info += f"**Manner**: {', '.join(manner)}\n" | |
| if place: | |
| info += f"**Place**: {', '.join(place)}\n" | |
| if voicing: | |
| info += f"**Voicing**: {', '.join(voicing)}\n" | |
| if vheight: | |
| info += f"**Vowel Height**: {', '.join(vheight)}\n" | |
| if vbackness: | |
| info += f"**Vowel Backness**: {', '.join(vbackness)}\n" | |
| if summary and len(matches) > 0: | |
| info += f"\n**Distribution**:\n" | |
| if summary.get('manner'): | |
| info += "- Manner: " + ", ".join( | |
| f"{k}({v})" for k, v in sorted(summary['manner'].items()) | |
| ) + "\n" | |
| if summary.get('place'): | |
| info += "- Place: " + ", ".join( | |
| f"{k}({v})" for k, v in sorted(summary['place'].items()) | |
| ) + "\n" | |
| fig = update_projector_plot(model_name, method, dim, color_by_val, | |
| highlight_indices=matches) | |
| if matches: | |
| if len(matches) <= 10: | |
| info += f"\n\nIndices: {matches}" | |
| else: | |
| info += f"\n\nSample indices: {matches[:10]}... (+{len(matches)-10} more)" | |
| return fig, info | |
| def show_neighbors(model_name, idx, n, method, dim, color_by_val): | |
| if not model_name or model_name not in analyzer.projector_vizs: | |
| return None, "Model not available" | |
| viz = analyzer.projector_vizs[model_name] | |
| if viz.nn_model is None: | |
| viz.build_nn_index() | |
| neighbors, distances = viz.find_nearest_neighbors(int(idx), int(n)) | |
| # Show with lines to neighbors | |
| line_pairs = [(int(idx), int(nn)) for nn in neighbors] | |
| proj_key = f"{method.lower()}_{dim.lower()}" | |
| if proj_key not in viz.projections: | |
| return None, "Projections not computed" | |
| if dim == '3D': | |
| fig = viz.create_3d_scatter( | |
| projection=proj_key, | |
| color_by=color_by_val.lower(), | |
| highlight_indices=[int(idx)] + list(neighbors), | |
| show_lines=True, | |
| line_pairs=line_pairs | |
| ) | |
| else: | |
| fig = viz.create_2d_scatter( | |
| projection=proj_key, | |
| color_by=color_by_val.lower(), | |
| highlight_indices=[int(idx)] + list(neighbors) | |
| ) | |
| info = f"Point {idx} - Nearest {n} neighbors:\n\n" | |
| for i, (nn_idx, dist) in enumerate(zip(neighbors, distances), 1): | |
| info += f"{i}. Index {nn_idx} (distance: {dist:.3f})\n" | |
| return fig, info | |
| def show_comparison_view(model_name, color_by_val): | |
| if not model_name or model_name not in analyzer.projector_vizs: | |
| return gr.update(visible=False), None | |
| viz = analyzer.projector_vizs[model_name] | |
| # Ensure all projections exist | |
| for method in ['pca', 'tsne', 'umap']: | |
| if f'{method}_3d' not in viz.projections: | |
| return gr.update(visible=False), None | |
| fig = viz.create_comparison_view(color_by=color_by_val.lower()) | |
| return gr.update(visible=True), fig | |
| # Connect callbacks | |
| # compute_btn.click( | |
| # fn=compute_projections, | |
| # inputs=[projector_model, projection_method], | |
| # outputs=[compute_status, projector_plot] | |
| # ) | |
| compute_btn.click( | |
| fn=compute_projections, | |
| inputs=[projector_model, projection_method, | |
| tsne_perplexity, tsne_lr, tsne_iters], | |
| outputs=[compute_status, projector_plot] | |
| ) | |
| search_mode.change( | |
| fn=toggle_search_mode, | |
| inputs=[search_mode], | |
| outputs=[label_search_group, feature_search_group] | |
| ) | |
| for component in [projection_method, dimension, projector_color_by]: | |
| component.change( | |
| fn=lambda m, meth, d, c: update_projector_plot(m, meth, d, c), | |
| inputs=[projector_model, projection_method, dimension, projector_color_by], | |
| outputs=[projector_plot] | |
| ) | |
| search_btn.click( | |
| fn=search_points, | |
| inputs=[projector_model, search_mode, search_label_type, search_term, | |
| projection_method, dimension, projector_color_by, | |
| search_manner, search_place, search_voicing, | |
| search_vowel_height, search_vowel_backness], | |
| outputs=[projector_plot, info_display] | |
| ) | |
| # show_nn_btn.click( | |
| # fn=show_neighbors, | |
| # inputs=[projector_model, point_idx, n_neighbors, | |
| # projection_method, dimension, projector_color_by], | |
| # outputs=[projector_plot, info_display] | |
| # ) | |
| # comparison_btn.click( | |
| # fn=lambda m, c: show_comparison_view(m, c), | |
| # inputs=[projector_model, projector_color_by], | |
| # outputs=[comparison_plot, comparison_plot] | |
| # ) | |
| return demo | |
def create_root_interface(output_dir):
    """Build the top-level Gradio app: one tab per dataset subdirectory.

    An optional config.json with a "selected_dirs" list restricts which
    subdirectories get a tab; with no config every subdirectory is shown.

    Args:
        output_dir: root directory containing one subdirectory per dataset.

    Returns:
        The assembled gr.Blocks demo.
    """
    subdirs = get_top_level_dirs(output_dir)
    # Load optional config; a missing or malformed file falls back to "all".
    # (The original caught only FileNotFoundError, so invalid JSON crashed
    # the app at startup.)
    try:
        with open("config.json") as f:
            config = json.load(f)
        selected = config.get("selected_dirs", [])
        if selected:
            subdirs = [d for d in subdirs if d.name in selected]
    except (FileNotFoundError, json.JSONDecodeError):
        pass  # Load all subdirectories if no usable config
    with gr.Blocks() as demo:
        gr.Markdown("## Discrete Token Phoneme Analysis")
        with gr.Tabs():
            for subdir in subdirs:
                with gr.Tab(subdir.name):
                    analyzer = load_analyzer_for_subdir(subdir)
                    create_integrated_gradio_interface(analyzer)
    return demo
if __name__ == "__main__":
    # # Create analyzer
    # analyzer = MultiModelAnalyzer(OUTPUT_DIR)
    # # Create and launch interface
    # demo = create_integrated_gradio_interface(analyzer)
    # Build the root app (one tab per dataset subdirectory) and serve it.
    demo = create_root_interface(OUTPUT_DIR)
    demo.launch(
        theme=gr.themes.Soft()
        # server_port=args.port,
        # share=True # Creates public link
    )
    # # demo = create_interface()
    # # demo.launch()