Tolulope Ogunremi
adding uk and rounding metrics
7e2d7f4
import gradio as gr
import os
import sys
import subprocess
import importlib
from pathlib import Path
import json
import pandas as pd
def install_private_package():
    """Install the private ``speech-model-analysis`` package from GitHub.

    Reads a GitHub access token from the ``GH_TOKEN`` environment variable,
    runs ``pip install`` on the token-authenticated repository URL in a
    subprocess, and then invalidates the import caches so the freshly
    installed package can be imported by the current process.

    Raises:
        ValueError: If ``GH_TOKEN`` is unset or empty.
        RuntimeError: If pip exits with a non-zero status. The message
            carries pip's stderr with the token redacted.
    """
    print("Installing private package...")
    gh_token = os.environ.get("GH_TOKEN")
    if not gh_token:
        raise ValueError("GH_TOKEN not found in environment variables")

    package_url = f"git+https://{gh_token}@github.com/tolulope/speech-model-analysis.git"

    # Run pip through the current interpreter so the package lands in the
    # environment this process is using; capture output for error reporting.
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "--no-cache-dir", package_url],
        capture_output=True,
        text=True,
    )

    if result.returncode != 0:
        # The install URL embeds the token and may be echoed back by pip or
        # git, so scrub it before the output reaches logs or the exception.
        stdout = result.stdout.replace(gh_token, "***")
        stderr = result.stderr.replace(gh_token, "***")
        print("STDOUT:", stdout)
        print("STDERR:", stderr)
        raise RuntimeError(f"Failed to install package: {stderr}")

    print("✓ Package installed successfully!")
    # Clear import caches so Python recognizes the new package
    importlib.invalidate_caches()
# Install the package first: the speech_model_analysis imports below can only
# succeed once the installation has completed in this process.
install_private_package()
# Now import from your private package
from speech_model_analysis import (
    VoxCommunisPreprocessor,
    MultiModelAnalyzer,
    create_hubert_configs,
)
from speech_model_analysis.phoneme_manager import PHONEMES, index_to_phoneme
# NOTE(review): VoxCommunisPreprocessor and create_hubert_configs are already
# imported from the package root above; this line rebinds the same two names.
from speech_model_analysis.voxcommunis_preprocessing import VoxCommunisPreprocessor, create_hubert_configs
from speech_model_analysis.gradio_viz import ClusterVisualizer
from speech_model_analysis.enhanced_analysis import calculate_all_metrics
from speech_model_analysis.audio_player import ClusterAudioExplorer, create_audio_grid
from speech_model_analysis.embedding_projector_viz import EmbeddingProjectorViz
from speech_model_analysis.context_pooling import ContextConfig, ContextAwarePooler, ContextAwareAnalyzer
print("Private package loaded successfully!")
from huggingface_hub import hf_hub_download, snapshot_download, login
# Authenticate with the Hugging Face Hub. The subscript raises KeyError at
# import time when HF_TOKEN is not set (unlike GH_TOKEN, which gets an
# explicit ValueError in install_private_package).
login(os.environ["HF_TOKEN"])
# Download the full repo snapshot to a local dir
# OUTPUT_DIR is the root that create_root_interface() scans for subdirectories.
OUTPUT_DIR = snapshot_download("tolulope/speech-model-analysis", repo_type="dataset")
def get_top_level_dirs(root):
    """Return the immediate subdirectories of *root*, sorted by path.

    ``Path.iterdir()`` yields entries in arbitrary, filesystem-dependent
    order; sorting keeps the order of the tabs built from this list stable
    across runs and machines.

    Args:
        root: Directory to scan, as a ``str`` or ``Path``.

    Returns:
        A list of ``Path`` objects, one per direct child directory of
        *root*. Regular files are skipped.
    """
    return sorted(entry for entry in Path(root).iterdir() if entry.is_dir())
def load_analyzer_for_subdir(subdir_path):
    """Build a ``MultiModelAnalyzer`` for one directory.

    Args:
        subdir_path: Directory to analyze; converted with ``str()`` before
            being handed to the analyzer.

    Returns:
        The ``MultiModelAnalyzer`` constructed for that directory.
    """
    analyzer_root = str(subdir_path)
    return MultiModelAnalyzer(analyzer_root)
def toggle_tsne_params(method):
    """Show the three t-SNE sliders only while t-SNE is the chosen method.

    Args:
        method: Current value of the projection-method radio.

    Returns:
        A list of three ``gr.update`` objects whose ``visible`` flag is True
        exactly when *method* equals ``"t-SNE"``.
    """
    show_sliders = method == "t-SNE"
    return [gr.update(visible=show_sliders) for _ in range(3)]
def create_integrated_gradio_interface(analyzer: MultiModelAnalyzer):
    """
    Create the Gradio interface for one analyzer.

    The app has two tabs:

    * "Model Comparison" shows ``analyzer.compare_metrics()`` rounded to two
      decimals; it is filled on page load and by the refresh button.
    * "Embedding Projector" computes PCA / t-SNE / UMAP projections through
      ``analyzer.projector_vizs`` and lets the user highlight points by
      label or by articulatory features.

    Args:
        analyzer: MultiModelAnalyzer instance

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    model_names = analyzer.get_model_names()
    default_model = model_names[0] if model_names else None
    # Seed entries for the label-search dropdown; get_choices() replaces them
    # with values taken from the selected model's labels.
    phones = ["æ", "ɑ", "ə", "i", "u"]

    # ------------------------------------------------------------------
    # Callbacks (closures over `analyzer`)
    # ------------------------------------------------------------------
    def update_comparison():
        """Return the cross-model metrics table rounded to two decimals."""
        return analyzer.compare_metrics().round(2)

    def get_choices(model_name, label_type):
        """Return a dropdown update listing the distinct values of a label.

        ``label_type`` selects the "phone" or "cluster" column; any other
        value falls back to "language". An unknown model or an empty column
        yields an empty dropdown instead of raising.
        """
        if not model_name or model_name not in analyzer.projector_vizs:
            return gr.update(choices=[], value=None)
        labels = pd.DataFrame(analyzer.projector_vizs[model_name].labels)
        column = label_type if label_type in ("phone", "cluster") else "language"
        # Dropdown choices must be a plain list of strings, and the value
        # must be one of them.
        choices = [str(x) for x in labels[column].unique()]
        return gr.update(choices=choices, value=choices[0] if choices else None)

    def compute_projections(model_name, method, tsne_perplexity, tsne_lr, tsne_iters):
        """Compute projections for a model and return ``(status, figure)``."""
        if not model_name or model_name not in analyzer.projector_vizs:
            return "Model not available", None
        viz = analyzer.projector_vizs[model_name]
        try:
            method_lower = method.lower()
            viz.compute_projections(method_lower, tsne_perplexity, tsne_lr, tsne_iters)
            # The first render is always the 3D, cluster-coloured view,
            # whatever the "Dimensions" / "Color by" radios currently show.
            fig = viz.create_3d_scatter(
                projection=f"{method_lower}_3d",
                color_by='cluster'
            )
            return f"{method} projections computed!", fig
        except Exception as e:
            # UI boundary: report the failure in the status line.
            return f"Error: {str(e)}", None

    def toggle_search_mode(mode):
        """Toggle between label and feature search."""
        by_label = mode == 'By Label'
        return gr.update(visible=by_label), gr.update(visible=not by_label)

    def update_projector_plot(model_name, method, dim, color_by_val, highlight_indices=None):
        """Render the stored projection; return None when it is unavailable."""
        if not model_name or model_name not in analyzer.projector_vizs:
            return None
        viz = analyzer.projector_vizs[model_name]
        proj_key = f"{method.lower()}_{dim.lower()}"
        # Nothing to draw until "Compute Projections" has produced this key.
        if proj_key not in viz.projections:
            return None
        render = viz.create_3d_scatter if dim == '3D' else viz.create_2d_scatter
        try:
            return render(
                projection=proj_key,
                color_by=color_by_val.lower(),
                highlight_indices=highlight_indices
            )
        except Exception as e:
            print(f"Error creating plot: {e}")
            return None

    def search_points(model_name, mode, search_type, term, method, dim,
                      color_by_val, manner, place, voicing, vheight, vbackness):
        """Highlight points matching a label or a set of articulatory features.

        Returns ``(figure, markdown_info)``.
        """
        if not model_name or model_name not in analyzer.projector_vizs:
            return None, "Model not available"
        viz = analyzer.projector_vizs[model_name]

        if mode == 'By Label':
            if not term:
                fig = update_projector_plot(model_name, method, dim, color_by_val)
                return fig, "No search term provided"
            matches = viz.search_by_label(term, search_type.lower())
            info = f"Found {len(matches)} matches for '{term}' in {search_type}"
        else:  # By Features
            matches = viz.search_by_articulatory_features(
                PHONEMES,
                manner=manner or None,
                place=place or None,
                voicing=voicing or None,
                vowel_height=vheight or None,
                vowel_backness=vbackness or None
            )
            summary = viz.get_articulatory_summary(matches, PHONEMES)
            info = f"Found {len(matches)} points matching features:\n\n"
            if manner:
                info += f"**Manner**: {', '.join(manner)}\n"
            if place:
                info += f"**Place**: {', '.join(place)}\n"
            if voicing:
                info += f"**Voicing**: {', '.join(voicing)}\n"
            if vheight:
                info += f"**Vowel Height**: {', '.join(vheight)}\n"
            if vbackness:
                info += f"**Vowel Backness**: {', '.join(vbackness)}\n"
            if summary and len(matches) > 0:
                info += f"\n**Distribution**:\n"
                if summary.get('manner'):
                    info += "- Manner: " + ", ".join(
                        f"{k}({v})" for k, v in sorted(summary['manner'].items())
                    ) + "\n"
                if summary.get('place'):
                    info += "- Place: " + ", ".join(
                        f"{k}({v})" for k, v in sorted(summary['place'].items())
                    ) + "\n"

        fig = update_projector_plot(model_name, method, dim, color_by_val,
                                    highlight_indices=matches)
        # Use len() rather than truthiness: a multi-element NumPy array
        # raises on bool(), and len() is what the checks above rely on.
        n_matches = len(matches)
        if n_matches > 0:
            if n_matches <= 10:
                info += f"\n\nIndices: {matches}"
            else:
                info += f"\n\nSample indices: {matches[:10]}... (+{n_matches-10} more)"
        return fig, info

    def show_neighbors(model_name, idx, n, method, dim, color_by_val):
        """Highlight a point and its nearest neighbours.

        Not wired to any control at the moment; kept for the nearest-neighbour
        panel.
        """
        if not model_name or model_name not in analyzer.projector_vizs:
            return None, "Model not available"
        viz = analyzer.projector_vizs[model_name]
        if viz.nn_model is None:
            viz.build_nn_index()
        neighbors, distances = viz.find_nearest_neighbors(int(idx), int(n))
        # Line segments from the query point to each neighbour (3D view only).
        line_pairs = [(int(idx), int(nn)) for nn in neighbors]
        proj_key = f"{method.lower()}_{dim.lower()}"
        if proj_key not in viz.projections:
            return None, "Projections not computed"
        highlighted = [int(idx)] + list(neighbors)
        if dim == '3D':
            fig = viz.create_3d_scatter(
                projection=proj_key,
                color_by=color_by_val.lower(),
                highlight_indices=highlighted,
                show_lines=True,
                line_pairs=line_pairs
            )
        else:
            fig = viz.create_2d_scatter(
                projection=proj_key,
                color_by=color_by_val.lower(),
                highlight_indices=highlighted
            )
        info = f"Point {idx} - Nearest {n} neighbors:\n\n"
        for i, (nn_idx, dist) in enumerate(zip(neighbors, distances), 1):
            info += f"{i}. Index {nn_idx} (distance: {dist:.3f})\n"
        return fig, info

    def show_comparison_view(model_name, color_by_val):
        """Return the side-by-side PCA / t-SNE / UMAP figure when all exist.

        Not wired to any control at the moment.
        NOTE(review): this checks the key 'tsne_3d', while the other callbacks
        build keys from "t-SNE".lower() ('t-sne_3d') - confirm which key the
        visualizer actually stores before re-enabling.
        """
        if not model_name or model_name not in analyzer.projector_vizs:
            return gr.update(visible=False), None
        viz = analyzer.projector_vizs[model_name]
        # Ensure all projections exist
        required = ('pca_3d', 'tsne_3d', 'umap_3d')
        if any(key not in viz.projections for key in required):
            return gr.update(visible=False), None
        fig = viz.create_comparison_view(color_by=color_by_val.lower())
        return gr.update(visible=True), fig

    # ------------------------------------------------------------------
    # Layout and event wiring
    # ------------------------------------------------------------------
    with gr.Blocks(title="Discrete Token Analysis") as demo:
        gr.Markdown("# Discrete Token Phoneme Analysis")

        with gr.Tabs():
            # Tab: Model Comparison
            with gr.Tab("Model Comparison"):
                gr.Markdown("### Compare Clustering Quality Across Models")
                with gr.Row():
                    metrics_table = gr.Dataframe(label="Detailed Metrics")
                refresh_comparison_btn = gr.Button("Refresh Comparison", variant="primary")

                # The button and the initial page load share one handler.
                refresh_comparison_btn.click(
                    fn=update_comparison,
                    outputs=[metrics_table]
                )
                demo.load(
                    fn=update_comparison,
                    outputs=[metrics_table]
                )

            # NOTE: the "Single Model Analysis", "Audio Explorer",
            # "Export & Analysis" and "Context Pooling" tabs were disabled
            # (wrapped in bare string literals) and are not part of this
            # function; restore them from version control if they are to be
            # re-enabled.

            # Tab: Embedding Projector
            with gr.Tab("Embedding Projector"):
                gr.Markdown("### TensorFlow Projector-Style 3D Visualization")
                gr.Markdown("Interactive exploration similar to TensorFlow's Embedding Projector")

                with gr.Row():
                    # Left sidebar
                    with gr.Column(scale=1):
                        gr.Markdown("#### Model & Projection")
                        projector_model = gr.Dropdown(
                            model_names,
                            value=default_model,
                            label="Select Model"
                        )
                        projection_method = gr.Radio(
                            choices=['PCA', 't-SNE', 'UMAP'],
                            value='UMAP',
                            label="Projection Method"
                        )
                        # t-SNE hyper-parameters stay hidden unless t-SNE is selected.
                        tsne_perplexity = gr.Slider(5, 50, value=30, step=1, label="t-SNE Perplexity", visible=False)
                        tsne_lr = gr.Slider(10, 1000, value=200, step=10, label="t-SNE Learning Rate", visible=False)
                        tsne_iters = gr.Slider(250, 5000, value=1000, step=250, label="t-SNE Iterations", visible=False)
                        projection_method.change(
                            fn=toggle_tsne_params,
                            inputs=[projection_method],
                            outputs=[tsne_perplexity, tsne_lr, tsne_iters]
                        )
                        dimension = gr.Radio(
                            choices=['3D', '2D'],
                            value='3D',
                            label="Dimensions"
                        )
                        projector_color_by = gr.Radio(
                            choices=['cluster', 'language'],
                            value='cluster',
                            label="Color by"
                        )
                        compute_btn = gr.Button("Compute Projections", variant="primary")
                        compute_status = gr.Markdown("*Click to compute projections*")

                        gr.Markdown("#### Search & Highlight")
                        search_mode = gr.Radio(
                            choices=['By Label', 'By Features'],
                            value='By Label',
                            label="Search Mode"
                        )

                        # Label search (simple)
                        with gr.Group(visible=True) as label_search_group:
                            search_label_type = gr.Radio(
                                choices=["phone", "cluster", "language"],
                                value="phone",
                                label="Search in"
                            )
                            search_term = gr.Dropdown(
                                choices=list(phones),
                                value=phones[0],
                                label="Search term"
                            )

                        # Feature search (advanced)
                        with gr.Group(visible=False) as feature_search_group:
                            search_manner = gr.Dropdown(
                                choices=['stop', 'fricative', 'nasal', 'approximant',
                                         'affricate', 'tap/flap'],
                                multiselect=True,
                                label="Manner"
                            )
                            search_place = gr.Dropdown(
                                choices=['bilabial', 'labiodental', 'dental', 'alveolar',
                                         'postalveolar', 'palatal', 'velar', 'uvular',
                                         'pharyngeal', 'glottal'],
                                multiselect=True,
                                label="Place"
                            )
                            search_voicing = gr.Dropdown(
                                choices=['voiced', 'voiceless'],
                                multiselect=True,
                                label="Voicing"
                            )
                            search_vowel_height = gr.Dropdown(
                                choices=['high', 'mid', 'low'],
                                multiselect=True,
                                label="Vowel Height"
                            )
                            search_vowel_backness = gr.Dropdown(
                                choices=['front', 'central', 'back'],
                                multiselect=True,
                                label="Vowel Backness"
                            )

                        search_btn = gr.Button("🔍 Search")
                        info_display = gr.Markdown("*Select a point or search*")

                    # Main visualization area
                    with gr.Column(scale=3):
                        projector_plot = gr.Plot(label="Embedding Space")

                # Connect callbacks
                plot_inputs = [projector_model, projection_method, dimension, projector_color_by]

                # Refresh the search dropdown when the label type OR the
                # model changes, so its entries always come from the model
                # that is currently selected.
                for trigger in (search_label_type, projector_model):
                    trigger.change(
                        fn=get_choices,
                        inputs=[projector_model, search_label_type],
                        outputs=search_term
                    )

                compute_btn.click(
                    fn=compute_projections,
                    inputs=[projector_model, projection_method,
                            tsne_perplexity, tsne_lr, tsne_iters],
                    outputs=[compute_status, projector_plot]
                )
                search_mode.change(
                    fn=toggle_search_mode,
                    inputs=[search_mode],
                    outputs=[label_search_group, feature_search_group]
                )
                # Any change of method, dimension or colouring redraws the plot.
                for component in (projection_method, dimension, projector_color_by):
                    component.change(
                        fn=update_projector_plot,
                        inputs=plot_inputs,
                        outputs=[projector_plot]
                    )
                search_btn.click(
                    fn=search_points,
                    inputs=[projector_model, search_mode, search_label_type, search_term,
                            projection_method, dimension, projector_color_by,
                            search_manner, search_place, search_voicing,
                            search_vowel_height, search_vowel_backness],
                    outputs=[projector_plot, info_display]
                )

    return demo
def create_root_interface(output_dir):
    """Build the top-level app with one tab per analysis directory.

    Every direct subdirectory of *output_dir* becomes a tab holding the
    interface produced by ``create_integrated_gradio_interface``. When a
    ``config.json`` file exists in the working directory and has a non-empty
    ``"selected_dirs"`` list, only subdirectories whose names appear in that
    list are kept; a missing file means all subdirectories are shown.

    Args:
        output_dir: Root directory containing one subdirectory per analysis.

    Returns:
        The assembled ``gr.Blocks`` app (not launched).
    """
    subdirs = get_top_level_dirs(output_dir)

    # Optional filter from config.json
    try:
        with open("config.json") as config_file:
            settings = json.load(config_file)
    except FileNotFoundError:
        pass  # Load all if no config
    else:
        wanted = settings.get("selected_dirs", [])
        if wanted:
            subdirs = [d for d in subdirs if d.name in wanted]

    with gr.Blocks() as demo:
        gr.Markdown("## Discrete Token Phoneme Analysis")
        with gr.Tabs():
            for subdir in subdirs:
                with gr.Tab(subdir.name):
                    # The analyzer is built inside the tab context, right
                    # before the interface that consumes it.
                    create_integrated_gradio_interface(
                        load_analyzer_for_subdir(subdir)
                    )
    return demo
if __name__ == "__main__":
    # Build one tab per analysis directory found in the downloaded snapshot,
    # then serve the app with the Soft theme.
    demo = create_root_interface(OUTPUT_DIR)
    demo.launch(theme=gr.themes.Soft())