import gradio as gr
import os
import sys
import subprocess
import importlib
from pathlib import Path
import json
import numpy as np  # needed by the NPZ export and context-pooling code below
import pandas as pd


def install_private_package():
    """Install the private package from GitHub using a token."""
    print("Installing private package...")
    gh_token = os.environ.get("GH_TOKEN")
    if not gh_token:
        raise ValueError("GH_TOKEN not found in environment variables")

    package_url = f"git+https://{gh_token}@github.com/tolulope/speech-model-analysis.git"

    # Use subprocess for better error handling
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "--no-cache-dir", package_url],
        capture_output=True,
        text=True
    )

    if result.returncode != 0:
        print("STDOUT:", result.stdout)
        print("STDERR:", result.stderr)
        raise RuntimeError(f"Failed to install package: {result.stderr}")

    print("✓ Package installed successfully!")

    # Clear import caches so Python recognizes the new package
    importlib.invalidate_caches()


# Install the package first
install_private_package()

# # Install private package at startup
# print("Installing private package...")
# gh_token = os.environ.get("GH_TOKEN")
# if not gh_token:
#     raise ValueError("GH_TOKEN not found in environment variables")
# package_url = f"git+https://{gh_token}@github.com/tolulope/speech-model-analysis.git"
# os.system(f"{sys.executable} -m pip install {package_url}")

# Now import from the private package
from speech_model_analysis import (
    VoxCommunisPreprocessor,
    MultiModelAnalyzer,
    create_hubert_configs,
)
from speech_model_analysis.phoneme_manager import PHONEMES, index_to_phoneme
from speech_model_analysis.gradio_viz import ClusterVisualizer
from speech_model_analysis.enhanced_analysis import calculate_all_metrics
from speech_model_analysis.audio_player import ClusterAudioExplorer, create_audio_grid
from speech_model_analysis.embedding_projector_viz import EmbeddingProjectorViz
from speech_model_analysis.context_pooling import ContextConfig, ContextAwarePooler, ContextAwareAnalyzer

print("Private package loaded successfully!")

from huggingface_hub import hf_hub_download, snapshot_download, login

login(os.environ["HF_TOKEN"])

# Download the full repo snapshot to a local dir
OUTPUT_DIR = snapshot_download("tolulope/speech-model-analysis", repo_type="dataset")


def get_top_level_dirs(root):
    root = Path(root)
    return [d for d in root.iterdir() if d.is_dir()]


def load_analyzer_for_subdir(subdir_path):
    return MultiModelAnalyzer(str(subdir_path))


def toggle_tsne_params(method):
    visible = method == "t-SNE"
    return [
        gr.update(visible=visible),
        gr.update(visible=visible),
        gr.update(visible=visible)
    ]
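# The tabs below report clustering-quality numbers (cluster purity, V-measure,
# NMI) via the private package's calculate_all_metrics. As a rough,
# self-contained sketch of what those numbers mean -- assuming the standard
# scikit-learn definitions and non-negative integer labels; the package's
# actual implementation may differ -- something like:
def _metrics_sketch(cluster_labels, phone_labels):
    """Illustrative approximation of the metrics shown in the UI."""
    from sklearn.metrics import v_measure_score, normalized_mutual_info_score

    clusters = np.asarray(cluster_labels)
    phones = np.asarray(phone_labels)
    # Cluster purity: fraction of frames whose phone label matches the
    # majority phone of the cluster they were assigned to.
    majority_hits = sum(
        np.bincount(phones[clusters == c]).max() for c in np.unique(clusters)
    )
    return {
        "cluster_purity": majority_hits / len(phones),
        "v_measure": v_measure_score(phones, clusters),
        "nmi": normalized_mutual_info_score(phones, clusters),
    }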
def create_integrated_gradio_interface(analyzer: MultiModelAnalyzer):
    """
    Create a comprehensive Gradio interface with model comparison.

    Args:
        analyzer: MultiModelAnalyzer instance
    """
    # Extract feature options (same as before)
    all_manners = sorted(set(p.manner.name for p in PHONEMES.values() if p.manner))
    all_places = sorted(set(p.place.name for p in PHONEMES.values() if p.place))
    all_voicings = ['voiced', 'voiceless']
    all_heights = ['high', 'mid', 'low']
    all_backness = ['front', 'central', 'back']

    model_names = analyzer.get_model_names()

    with gr.Blocks(title="Discrete Token Analysis") as demo:
        gr.Markdown("# Discrete Token Phoneme Analysis")
        # gr.Markdown("Compare HuBERT models and analyze discrete representations")

        with gr.Tabs():
            # Tab 1: Model Comparison
            with gr.Tab("Model Comparison"):
                gr.Markdown("### Compare Clustering Quality Across Models")

                with gr.Row():
                    # comparison_plot = gr.Plot(label="Metrics Comparison")
                    metrics_table = gr.Dataframe(label="Detailed Metrics")

                refresh_comparison_btn = gr.Button("Refresh Comparison", variant="primary")

                def update_comparison():
                    # fig = analyzer.create_comparison_plot()
                    df = analyzer.compare_metrics()
                    df = df.round(2)
                    return df

                # refresh_comparison_btn.click(
                #     fn=update_comparison,
                #     outputs=[comparison_plot, metrics_table]
                # )

                # Initialize
                demo.load(
                    fn=update_comparison,
                    # outputs=[comparison_plot, metrics_table]
                    outputs=[metrics_table]
                )

            # Tab 2: Single Model Analysis (disabled)
            """
            with gr.Tab("Single Model Analysis"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Model & Filters")

                        model_selector = gr.Dropdown(
                            model_names,
                            value=model_names[0] if model_names else None,
                            label="Select Model"
                        )

                        color_by = gr.Radio(
                            ['cluster', 'phone'],
                            value='cluster',
                            label="Color by"
                        )

                        gr.Markdown("#### Articulatory Filters")

                        manner_filter = gr.Dropdown(
                            all_manners,
                            multiselect=True,
                            label="Manner"
                        )
                        place_filter = gr.Dropdown(
                            all_places,
                            multiselect=True,
                            label="Place"
                        )
                        voicing_filter = gr.Dropdown(
                            all_voicings,
                            multiselect=True,
                            label="Voicing"
                        )
                        vowel_height_filter = gr.Dropdown(
                            all_heights,
                            multiselect=True,
                            label="Vowel Height"
                        )
                        vowel_backness_filter = gr.Dropdown(
                            all_backness,
                            multiselect=True,
                            label="Vowel Backness"
                        )

                        update_btn = gr.Button("Update Visualization", variant="primary")

                    with gr.Column(scale=2):
                        plot_output = gr.Plot(label="Cluster Visualization")
                        gr.Markdown("💡 **Tip**: Click on points to hear audio in the Audio Explorer tab!")

                with gr.Row():
                    with gr.Column():
                        metrics_output = gr.Markdown()
                    with gr.Column():
                        confusion_output = gr.Plot(label="Confusion Matrix")

                def update_single_model(model_name, color, manner, place, voicing, height, backness):
                    if not model_name:
                        return None, "Select a model", None

                    visualizer = analyzer.visualizers[model_name]

                    # Create scatter plot
                    fig = visualizer.create_scatter_plot(
                        color_by=color,
                        filter_manner=manner if manner else None,
                        filter_place=place if place else None,
                        filter_voicing=voicing if voicing else None,
                        filter_vowel_height=height if height else None,
                        filter_vowel_backness=backness if backness else None
                    )

                    # Calculate metrics
                    metrics = visualizer.calculate_metrics(
                        filter_manner=manner if manner else None,
                        filter_place=place if place else None,
                        filter_voicing=voicing if voicing else None,
                        filter_vowel_height=height if height else None,
                        filter_vowel_backness=backness if backness else None
                    )

                    # Create confusion matrix
                    confusion_fig = analyzer.create_confusion_heatmap(model_name)

                    return fig, metrics, confusion_fig

                update_btn.click(
                    fn=update_single_model,
                    inputs=[model_selector, color_by, manner_filter, place_filter,
                            voicing_filter, vowel_height_filter, vowel_backness_filter],
                    outputs=[plot_output, metrics_output, confusion_output]
                )
            """
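            # NOTE: the triple-quoted string literals above and below keep
            # entire tabs disabled without deleting the code. Python evaluates
            # each string and discards it, so Gradio never registers those
            # components.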
) """ # Tab 3: Audio Explorer """ with gr.Tab("Audio Explorer"): gr.Markdown("### Listen to Cluster Samples") gr.Markdown("Explore audio segments from clusters and phonemes") with gr.Row(): with gr.Column(): audio_model_selector = gr.Dropdown( model_names, value=model_names[0] if model_names else None, label="Select Model" ) exploration_mode = gr.Radio( ['By Cluster', 'By Phoneme', 'Compare Phoneme Across Clusters'], value='By Cluster', label="Exploration Mode" ) # Cluster mode inputs with gr.Group(visible=True) as cluster_inputs: cluster_id_audio = gr.Number( label="Cluster ID", value=0, precision=0 ) n_cluster_samples = gr.Slider( 1, 10, value=5, step=1, label="Number of samples" ) # Phoneme mode inputs with gr.Group(visible=False) as phoneme_inputs: phoneme_select = gr.Dropdown( sorted(list(PHONEMES.keys())), label="Select Phoneme", value="æ" ) n_phoneme_samples = gr.Slider( 1, 10, value=5, step=1, label="Number of samples" ) # Compare mode inputs with gr.Group(visible=False) as compare_inputs: phoneme_compare = gr.Dropdown( sorted(list(PHONEMES.keys())), label="Phoneme to Compare", value="æ" ) n_per_cluster = gr.Slider( 1, 5, value=3, step=1, label="Samples per cluster" ) play_audio_btn = gr.Button("🎵 Load Audio Samples", variant="primary") with gr.Column(scale=2): audio_output = gr.HTML(label="Audio Player") audio_info = gr.Markdown() # Toggle visibility based on mode def update_visibility(mode): return ( gr.update(visible=(mode == 'By Cluster')), gr.update(visible=(mode == 'By Phoneme')), gr.update(visible=(mode == 'Compare Phoneme Across Clusters')) ) exploration_mode.change( fn=update_visibility, inputs=[exploration_mode], outputs=[cluster_inputs, phoneme_inputs, compare_inputs] ) def load_audio_samples(model_name, mode, cluster_id, n_cluster, phoneme, n_phoneme, phoneme_cmp, n_per_clust): if not model_name or model_name not in analyzer.audio_explorers: return "

Audio not available for this model

", "No audio data loaded" explorer = analyzer.audio_explorers[model_name] try: if mode == 'By Cluster': samples = explorer.get_cluster_samples( cluster_id=int(cluster_id), n_samples=int(n_cluster) ) info = f"### Cluster {cluster_id}\n\nShowing {len(samples)} samples" elif mode == 'By Phoneme': samples = explorer.get_phoneme_samples( phoneme=phoneme, n_samples=int(n_phoneme) ) info = f"### Phoneme: {phoneme}\n\nShowing {len(samples)} samples" else: # Compare mode cluster_samples = explorer.compare_phoneme_in_clusters( phoneme=phoneme_cmp, n_per_cluster=int(n_per_clust) ) # Flatten samples and add cluster headers html = "" info_lines = [f"### Phoneme: {phoneme_cmp} across clusters\n"] for cluster_id, samps in sorted(cluster_samples.items()): html += f'

Cluster {cluster_id}

' html += create_audio_grid(samps, columns=3) info_lines.append(f"- Cluster {cluster_id}: {len(samps)} samples") return html, "\n".join(info_lines) if not samples: return "

No samples found

", "No matching samples" html = create_audio_grid(samples, columns=3) return html, info except Exception as e: return f"

Error loading audio: {str(e)}

", f"Error: {str(e)}" play_audio_btn.click( fn=load_audio_samples, inputs=[audio_model_selector, exploration_mode, cluster_id_audio, n_cluster_samples, phoneme_select, n_phoneme_samples, phoneme_compare, n_per_cluster], outputs=[audio_output, audio_info] ) """ # Tab 4: Export & Analysis """ with gr.Tab("Export & Analysis"): gr.Markdown("### Export Results") with gr.Row(): export_model = gr.Dropdown( model_names, label="Select Model to Export" ) export_format = gr.Radio( ['CSV', 'JSON', 'NPZ'], value='CSV', label="Format" ) export_btn = gr.Button("Export Data", variant="primary") export_output = gr.File(label="Download") def export_data(model_name, format_type): if not model_name: return None data = analyzer.models[model_name] output_path = f"{model_name}_export.{format_type.lower()}" if format_type == 'CSV': df = pd.DataFrame({ 'cluster': data['cluster_labels'], 'phoneme': data['phoneme_strings'], 'phone_idx': data['phone_labels'] }) df.to_csv(output_path, index=False) elif format_type == 'JSON': export_dict = { 'clusters': data['cluster_labels'].tolist(), 'phonemes': data['phoneme_strings'].tolist(), 'phone_indices': data['phone_labels'].tolist() } with open(output_path, 'w') as f: json.dump(export_dict, f, indent=2) else: # NPZ np.savez( output_path, features=data['features'], clusters=data['cluster_labels'], phones=data['phone_labels'] ) return output_path export_btn.click( fn=export_data, inputs=[export_model, export_format], outputs=[export_output] ) """ # Tab 6: Context Pooling Analysis """ with gr.Tab("Context Pooling"): gr.Markdown("### Coarticulation Analysis") gr.Markdown("Pool phoneme embeddings by context to account for coarticulation effects") with gr.Row(): with gr.Column(scale=1): gr.Markdown("#### Pooling Configuration") context_model = gr.Dropdown( model_names, value=model_names[0] if model_names else None, label="Select Model" ) enable_pooling = gr.Checkbox( label="Enable Context Pooling", value=False ) left_context = gr.Slider( 0, 3, value=1, step=1, label="Left Context (# phones)", info="How many phones before target" ) right_context = gr.Slider( 0, 3, value=1, step=1, label="Right Context (# phones)", info="How many phones after target" ) pooling_method = gr.Radio( choices=['mean', 'median', 'max'], value='mean', label="Pooling Method" ) min_samples = gr.Slider( 1, 10, value=2, step=1, label="Min Samples per Context", info="Minimum instances to pool" ) compute_pooling_btn = gr.Button("Apply Pooling", variant="primary") pooling_status = gr.Markdown("") gr.Markdown("#### Analyze Specific Phone") phone_to_analyze = gr.Textbox( label="Phoneme", placeholder="æ", value="æ" ) analyze_phone_btn = gr.Button("Analyze Contexts") with gr.Column(scale=2): pooling_comparison = gr.Markdown("*Apply pooling to see comparison*") context_analysis = gr.Markdown("*Analyze a phone to see contexts*") # with gr.Row(): # pooled_plot = gr.Plot(label="Pooled Embeddings (UMAP)") # Context pooling callbacks def apply_context_pooling(model_name, enable, left, right, method, min_samp): if not model_name or model_name not in analyzer.models: return "Model not available", "" data = analyzer.models[model_name] if not enable: # No pooling metrics = calculate_all_metrics( data['cluster_labels'], data['phone_labels'] ) comparison = "### No Pooling (Baseline)\n\n" comparison += f"- **Points**: {len(data['features'])}\n" comparison += f"- **Cluster Purity**: {metrics['cluster_purity']:.3f}\n" comparison += f"- **Phone Purity**: {metrics['phone_purity']:.3f}\n" comparison += f"- **V-Measure**: 
{metrics['v_measure']:.3f}\n" comparison += f"- **NMI**: {metrics.get('nmi', 0):.3f}\n" return "No pooling applied (baseline)", comparison try: # Create context config config = ContextConfig( enabled=True, left_context=int(left), right_context=int(right), pooling_method=method, min_samples=int(min_samp) ) # Create pooler pooler = ContextAwarePooler(config) # Pool embeddings # Note: This assumes sequential data. In practice, you'd need # utterance boundaries from preprocessing phone_sequence = data['phone_labels'] # Simplified pooled_embeddings, context_info = pooler.create_context_clusters( data['features'], data['phone_labels'], phone_sequence, utterance_boundaries=None # Would come from data ) # Calculate metrics on pooled space # Need to re-cluster or map clusters from sklearn.cluster import KMeans n_clusters = len(np.unique(data['cluster_labels'])) kmeans = KMeans(n_clusters=n_clusters, random_state=42) pooled_clusters = kmeans.fit_predict(pooled_embeddings) metrics = calculate_all_metrics( pooled_clusters, context_info['labels'] ) # Create comparison comparison = f"### Context Pooling Results\n\n" comparison += f"**Configuration**: L{left}R{right} ({method})\n\n" comparison += f"- **Original Points**: {context_info['n_original']}\n" comparison += f"- **Pooled Points**: {context_info['n_pooled']}\n" comparison += f"- **Reduction**: {(1 - context_info['reduction_ratio'])*100:.1f}%\n\n" comparison += f"**Metrics**:\n" comparison += f"- **Cluster Purity**: {metrics['cluster_purity']:.3f}\n" comparison += f"- **Phone Purity**: {metrics['phone_purity']:.3f}\n" comparison += f"- **V-Measure**: {metrics['v_measure']:.3f}\n" comparison += f"- **NMI**: {metrics.get('nmi', 0):.3f}\n" status = f"Pooled {context_info['n_original']} → {context_info['n_pooled']} points" return status, comparison except Exception as e: return f"Error: {str(e)}", "" def analyze_phone_contexts(model_name, phone, left, right): if not model_name or not phone: return "*Enter phone to analyze*" if model_name not in analyzer.models: return "Model not available" try: data = analyzer.models[model_name] # Create analyzer ctx_analyzer = ContextAwareAnalyzer( embeddings=data['features'], phone_labels=data['phone_labels'], phone_sequence=data['phone_labels'], cluster_labels=data['cluster_labels'] ) # Analyze phone analysis = ctx_analyzer.analyze_context_effects(phone, PHONEMES) if 'error' in analysis: return f"{analysis['error']}" # Format output output = f"### Analysis of /{phone}/\n\n" output += f"- **Total occurrences**: {analysis['total_occurrences']}\n" output += f"- **Unique contexts**: {analysis['unique_contexts']}\n\n" output += f"**Most Common Contexts**:\n\n" # Sort by count contexts_sorted = sorted( analysis['contexts'].items(), key=lambda x: x[1]['count'], reverse=True ) for ctx_str, info in contexts_sorted[:10]: output += f"- **{ctx_str}**: {info['count']} times" if info['cluster_distribution']: clusters = ", ".join(f"C{c}({cnt})" for c, cnt in info['cluster_distribution'].items()) output += f" → {clusters}" output += "\n" if len(contexts_sorted) > 10: output += f"\n*... 
                        return output

                    except Exception as e:
                        return f"Error: {str(e)}"

                # Connect callbacks
                compute_pooling_btn.click(
                    fn=apply_context_pooling,
                    inputs=[context_model, enable_pooling, left_context,
                            right_context, pooling_method, min_samples],
                    outputs=[pooling_status, pooling_comparison]
                )

                analyze_phone_btn.click(
                    fn=analyze_phone_contexts,
                    inputs=[context_model, phone_to_analyze, left_context, right_context],
                    outputs=[context_analysis]
                )
            """

            # def get_choices(model_name, label_type):
            #     viz = analyzer.projector_vizs[model_name]
            #     df = pd.DataFrame(viz.labels)
            #     choices = [str(x) for x in df[label_type].unique()]
            #     print(choices)
            #     value = choices[0] if choices else None
            #     return choices, value

            def get_choices(model_name, label_type):
                viz = analyzer.projector_vizs[model_name]
                df = pd.DataFrame(viz.labels)
                if label_type == "phone":
                    choices = df["phone"].unique()
                elif label_type == "cluster":
                    choices = df["cluster"].unique()
                else:
                    choices = df["language"].unique()

                return gr.update(
                    choices=[str(x) for x in choices],  # MUST be a Python list of strings
                    value=str(choices[0])  # MUST be one of the choices
                )

            with gr.Tab("Embedding Projector"):
                gr.Markdown("### TensorFlow Projector-Style 3D Visualization")
                gr.Markdown("Interactive exploration similar to TensorFlow's Embedding Projector")

                with gr.Row():
                    # Left sidebar
                    with gr.Column(scale=1):
                        gr.Markdown("#### Model & Projection")

                        projector_model = gr.Dropdown(
                            model_names,
                            value=model_names[0] if model_names else None,
                            label="Select Model"
                        )

                        projection_method = gr.Radio(
                            choices=['PCA', 't-SNE', 'UMAP'],
                            # choices=['PCA', 'UMAP'],
                            value='UMAP',
                            label="Projection Method"
                        )

                        tsne_perplexity = gr.Slider(5, 50, value=30, step=1,
                                                    label="t-SNE Perplexity", visible=False)
                        tsne_lr = gr.Slider(10, 1000, value=200, step=10,
                                            label="t-SNE Learning Rate", visible=False)
                        tsne_iters = gr.Slider(250, 5000, value=1000, step=250,
                                               label="t-SNE Iterations", visible=False)

                        projection_method.change(
                            fn=toggle_tsne_params,
                            inputs=[projection_method],
                            outputs=[tsne_perplexity, tsne_lr, tsne_iters]
                        )

                        dimension = gr.Radio(
                            choices=['3D', '2D'],
                            value='3D',
                            label="Dimensions"
                        )

                        projector_color_by = gr.Radio(
                            # choices=['cluster', 'phone', 'language'],
                            choices=['cluster', 'language'],
                            value='cluster',
                            label="Color by"
                        )

                        compute_btn = gr.Button("Compute Projections", variant="primary")
                        compute_status = gr.Markdown("*Click to compute projections*")

                        gr.Markdown("#### Search & Highlight")

                        search_mode = gr.Radio(
                            choices=['By Label', 'By Features'],
                            value='By Label',
                            label="Search Mode"
                        )

                        phones = ["æ", "ɑ", "ə", "i", "u"]
                        clusters = [0, 1, 2, 3]
                        languages = ["hi", "pa"]

                        # Label search (simple)
                        with gr.Group(visible=True) as label_search_group:
                            # search_label_type = gr.Radio(
                            #     choices=['phone', 'cluster', 'language'],
                            #     value='phone',
                            #     label="Search in"
                            # )
                            # search_term = gr.Textbox(
                            #     label="Search term",
                            #     placeholder="e.g., 'æ' or '5'"
                            # )
                            # search_term = gr.Dropdown(
                            #     choices=list(phones),  # initial choices
                            #     value=phones[0],       # initial value
                            #     label="Search term",
                            #     allow_custom_value=True
                            # )
                            # # Update dropdown choices when the label type changes
                            # search_label_type.change(
                            #     fn=get_choices,
                            #     inputs=[projector_model, search_label_type],
                            #     outputs=[search_term, search_term]  # first = choices, second = value
                            # )

                            search_label_type = gr.Radio(
                                choices=["phone", "cluster", "language"],
                                value="phone",
                                label="Search in"
                            )

                            search_term = gr.Dropdown(
                                choices=[str(x) for x in phones],
                                value=str(phones[0]),
                                label="Search term"
                            )
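                            # Repopulate the dropdown whenever the label type
                            # flips: get_choices returns a gr.update whose
                            # choices/value replace the seed phone list above.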
                            search_label_type.change(
                                fn=get_choices,
                                inputs=[projector_model, search_label_type],
                                outputs=search_term
                            )

                        # Feature search (advanced)
                        with gr.Group(visible=False) as feature_search_group:
                            search_manner = gr.Dropdown(
                                choices=['stop', 'fricative', 'nasal', 'approximant',
                                         'affricate', 'tap/flap'],
                                multiselect=True,
                                label="Manner"
                            )
                            search_place = gr.Dropdown(
                                choices=['bilabial', 'labiodental', 'dental', 'alveolar',
                                         'postalveolar', 'palatal', 'velar', 'uvular',
                                         'pharyngeal', 'glottal'],
                                multiselect=True,
                                label="Place"
                            )
                            search_voicing = gr.Dropdown(
                                choices=['voiced', 'voiceless'],
                                multiselect=True,
                                label="Voicing"
                            )
                            search_vowel_height = gr.Dropdown(
                                choices=['high', 'mid', 'low'],
                                multiselect=True,
                                label="Vowel Height"
                            )
                            search_vowel_backness = gr.Dropdown(
                                choices=['front', 'central', 'back'],
                                multiselect=True,
                                label="Vowel Backness"
                            )

                        search_btn = gr.Button("🔍 Search")

                        # gr.Markdown("#### Nearest Neighbors")
                        # point_idx = gr.Number(
                        #     label="Point index",
                        #     value=0,
                        #     precision=0
                        # )
                        # n_neighbors = gr.Slider(
                        #     1, 50, value=10,
                        #     step=1,
                        #     label="Number of neighbors"
                        # )
                        # show_nn_btn = gr.Button("Show Neighbors")

                        info_display = gr.Markdown("*Select a point or search*")

                    # Main visualization area
                    with gr.Column(scale=3):
                        projector_plot = gr.Plot(label="Embedding Space")

                        # with gr.Row():
                        #     comparison_btn = gr.Button("Show Comparison View (PCA | t-SNE | UMAP)")
                        # comparison_plot = gr.Plot(label="Comparison", visible=False)

                # Projector callbacks
                def compute_projections(model_name, method, tsne_perplexity, tsne_lr, tsne_iters):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return "Model not available", None

                    viz = analyzer.projector_vizs[model_name]

                    try:
                        method_lower = method.lower()
                        viz.compute_projections(method_lower, tsne_perplexity, tsne_lr, tsne_iters)

                        # Create initial plot
                        proj_key = f"{method_lower}_3d"
                        fig = viz.create_3d_scatter(
                            projection=proj_key,
                            color_by='cluster'
                        )

                        return f"{method} projections computed!", fig
                    except Exception as e:
                        return f"Error: {str(e)}", None

                def toggle_search_mode(mode):
                    """Toggle between label and feature search."""
                    if mode == 'By Label':
                        return gr.update(visible=True), gr.update(visible=False)
                    else:
                        return gr.update(visible=False), gr.update(visible=True)

                def update_projector_plot(model_name, method, dim, color_by_val, highlight_indices=None):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return None

                    viz = analyzer.projector_vizs[model_name]
                    proj_key = f"{method.lower()}_{dim.lower()}"

                    # Check if projection exists
                    if proj_key not in viz.projections:
                        return None

                    try:
                        if dim == '3D':
                            fig = viz.create_3d_scatter(
                                projection=proj_key,
                                color_by=color_by_val.lower(),
                                highlight_indices=highlight_indices
                            )
                        else:
                            fig = viz.create_2d_scatter(
                                projection=proj_key,
                                color_by=color_by_val.lower(),
                                highlight_indices=highlight_indices
                            )
                        return fig
                    except Exception as e:
                        print(f"Error creating plot: {e}")
                        return None

                def search_points(model_name, search_mode, search_type, term, method, dim,
                                  color_by_val, manner, place, voicing, vheight, vbackness):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return None, "Model not available"

                    viz = analyzer.projector_vizs[model_name]

                    if search_mode == 'By Label':
                        if not term:
                            fig = update_projector_plot(model_name, method, dim, color_by_val)
                            return fig, "No search term provided"

                        matches = viz.search_by_label(term, search_type.lower())
                        info = f"Found {len(matches)} matches for '{term}' in {search_type}"

                    else:  # By Features
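                        # Feature search ignores the label term and instead
                        # matches points by the articulatory features of their
                        # phoneme (manner, place, voicing, vowel height/backness).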
                        matches = viz.search_by_articulatory_features(
                            PHONEMES,
                            manner=manner if manner else None,
                            place=place if place else None,
                            voicing=voicing if voicing else None,
                            vowel_height=vheight if vheight else None,
                            vowel_backness=vbackness if vbackness else None
                        )

                        # Get summary
                        summary = viz.get_articulatory_summary(matches, PHONEMES)

                        info = f"Found {len(matches)} points matching features:\n\n"
                        if manner:
                            info += f"**Manner**: {', '.join(manner)}\n"
                        if place:
                            info += f"**Place**: {', '.join(place)}\n"
                        if voicing:
                            info += f"**Voicing**: {', '.join(voicing)}\n"
                        if vheight:
                            info += f"**Vowel Height**: {', '.join(vheight)}\n"
                        if vbackness:
                            info += f"**Vowel Backness**: {', '.join(vbackness)}\n"

                        if summary and len(matches) > 0:
                            info += f"\n**Distribution**:\n"
                            if summary.get('manner'):
                                info += "- Manner: " + ", ".join(
                                    f"{k}({v})" for k, v in sorted(summary['manner'].items())
                                ) + "\n"
                            if summary.get('place'):
                                info += "- Place: " + ", ".join(
                                    f"{k}({v})" for k, v in sorted(summary['place'].items())
                                ) + "\n"

                    fig = update_projector_plot(model_name, method, dim, color_by_val,
                                                highlight_indices=matches)

                    if matches:
                        if len(matches) <= 10:
                            info += f"\n\nIndices: {matches}"
                        else:
                            info += f"\n\nSample indices: {matches[:10]}... (+{len(matches)-10} more)"

                    return fig, info

                def show_neighbors(model_name, idx, n, method, dim, color_by_val):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return None, "Model not available"

                    viz = analyzer.projector_vizs[model_name]

                    if viz.nn_model is None:
                        viz.build_nn_index()

                    neighbors, distances = viz.find_nearest_neighbors(int(idx), int(n))

                    # Show with lines to neighbors
                    line_pairs = [(int(idx), int(nn)) for nn in neighbors]

                    proj_key = f"{method.lower()}_{dim.lower()}"
                    if proj_key not in viz.projections:
                        return None, "Projections not computed"

                    if dim == '3D':
                        fig = viz.create_3d_scatter(
                            projection=proj_key,
                            color_by=color_by_val.lower(),
                            highlight_indices=[int(idx)] + list(neighbors),
                            show_lines=True,
                            line_pairs=line_pairs
                        )
                    else:
                        fig = viz.create_2d_scatter(
                            projection=proj_key,
                            color_by=color_by_val.lower(),
                            highlight_indices=[int(idx)] + list(neighbors)
                        )

                    info = f"Point {idx} - Nearest {n} neighbors:\n\n"
                    for i, (nn_idx, dist) in enumerate(zip(neighbors, distances), 1):
                        info += f"{i}. Index {nn_idx} (distance: {dist:.3f})\n"
                    return fig, info

                def show_comparison_view(model_name, color_by_val):
                    if not model_name or model_name not in analyzer.projector_vizs:
                        return gr.update(visible=False), None

                    viz = analyzer.projector_vizs[model_name]

                    # Ensure all projections exist
                    for method in ['pca', 'tsne', 'umap']:
                        if f'{method}_3d' not in viz.projections:
                            return gr.update(visible=False), None

                    fig = viz.create_comparison_view(color_by=color_by_val.lower())
                    return gr.update(visible=True), fig

                # Connect callbacks
                # compute_btn.click(
                #     fn=compute_projections,
                #     inputs=[projector_model, projection_method],
                #     outputs=[compute_status, projector_plot]
                # )
                compute_btn.click(
                    fn=compute_projections,
                    inputs=[projector_model, projection_method,
                            tsne_perplexity, tsne_lr, tsne_iters],
                    outputs=[compute_status, projector_plot]
                )

                search_mode.change(
                    fn=toggle_search_mode,
                    inputs=[search_mode],
                    outputs=[label_search_group, feature_search_group]
                )

                for component in [projection_method, dimension, projector_color_by]:
                    component.change(
                        fn=lambda m, meth, d, c: update_projector_plot(m, meth, d, c),
                        inputs=[projector_model, projection_method, dimension, projector_color_by],
                        outputs=[projector_plot]
                    )

                search_btn.click(
                    fn=search_points,
                    inputs=[projector_model, search_mode, search_label_type, search_term,
                            projection_method, dimension, projector_color_by,
                            search_manner, search_place, search_voicing,
                            search_vowel_height, search_vowel_backness],
                    outputs=[projector_plot, info_display]
                )

                # show_nn_btn.click(
                #     fn=show_neighbors,
                #     inputs=[projector_model, point_idx, n_neighbors,
                #             projection_method, dimension, projector_color_by],
                #     outputs=[projector_plot, info_display]
                # )

                # comparison_btn.click(
                #     fn=lambda m, c: show_comparison_view(m, c),
                #     inputs=[projector_model, projector_color_by],
                #     outputs=[comparison_plot, comparison_plot]
                # )

    return demo


def create_root_interface(output_dir):
    subdirs = get_top_level_dirs(output_dir)

    # Load config
    try:
        with open("config.json") as f:
            config = json.load(f)
        selected = config.get("selected_dirs", [])
        if selected:
            subdirs = [d for d in subdirs if d.name in selected]
    except FileNotFoundError:
        pass  # Load all if no config

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Discrete Token Phoneme Analysis")
        with gr.Tabs():
            for subdir in subdirs:
                with gr.Tab(subdir.name):
                    analyzer = load_analyzer_for_subdir(subdir)
                    create_integrated_gradio_interface(analyzer)
    return demo


if __name__ == "__main__":
    # # Create analyzer
    # analyzer = MultiModelAnalyzer(OUTPUT_DIR)
    # # Create and launch interface
    # demo = create_integrated_gradio_interface(analyzer)
    demo = create_root_interface(OUTPUT_DIR)

    demo.launch(
        # server_port=args.port,
        # share=True  # Creates a public link
    )
    # demo = create_interface()
    # demo.launch()
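# An optional config.json next to this script restricts which snapshot
# subdirectories are loaded as tabs. A hypothetical example (the directory
# names here are placeholders, not actual dataset contents):
#
#     {"selected_dirs": ["hubert_base", "mhubert_147"]}
#
# Without the file, every top-level directory in the dataset snapshot gets a tab.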