"""
Gradio App for Pro-TeVA Yoruba Tone Recognition
Hugging Face Spaces deployment
"""
import gradio as gr
from speechbrain.inference.interfaces import foreign_class
import numpy as np
import matplotlib.pyplot as plt
import torch
import config
# ============ CONFIGURATION ============
# Import tone info from config
TONE_INFO = config.TONE_INFO
# ============ MODEL LOADING ============
print("Loading Pro-TeVA tone recognition model...")
print(f"Checkpoint folder: {config.CHECKPOINT_FOLDER}")
try:
    tone_recognizer = foreign_class(
        source="./",
        pymodule_file="custom_interface.py",
        classname="ProTeVaToneRecognizer",
        hparams_file="inference.yaml",
        savedir=config.PRETRAINED_MODEL_DIR
    )
    print("βœ“ Model loaded successfully!")

    # Validate configuration
    if config.validate_config():
        print(f"βœ“ Space detection: {'ENABLED' if config.ENABLE_SPACE_DETECTION else 'DISABLED'}")
        if config.ENABLE_SPACE_DETECTION:
            print(f"  Method: {config.SPACE_DETECTION_METHOD}")
except Exception as e:
    print(f"βœ— Error loading model: {e}")
    tone_recognizer = None
# ============ HELPER FUNCTIONS ============
def format_tone_sequence(tone_indices, tone_names):
    """Format a tone sequence as readable tone names and symbols"""
    if not tone_indices:
        return "No tones detected"
    formatted = []
    for idx, name in zip(tone_indices, tone_names):
        info = config.get_tone_info(idx)
        formatted.append(f"{info['name']} ({info['symbol']})")
    return " β†’ ".join(formatted)
def create_f0_comparison_plot(f0_extracted, f0_predicted):
    """Create F0 comparison plot showing both extracted and predicted contours"""
    if f0_extracted is None or f0_predicted is None:
        return None

    # Convert to numpy
    if isinstance(f0_extracted, torch.Tensor):
        f0_extracted_numpy = f0_extracted.cpu().numpy().flatten()
    else:
        f0_extracted_numpy = np.array(f0_extracted).flatten()
    if isinstance(f0_predicted, torch.Tensor):
        f0_predicted_numpy = f0_predicted.cpu().numpy().flatten()
    else:
        f0_predicted_numpy = np.array(f0_predicted).flatten()

    # Create plot with both contours
    fig, ax = plt.subplots(figsize=(12, 5))

    # Normalized time axis
    time_extracted = np.arange(len(f0_extracted_numpy)) / len(f0_extracted_numpy)
    time_predicted = np.arange(len(f0_predicted_numpy)) / len(f0_predicted_numpy)

    # Plot both F0 contours
    ax.plot(time_extracted, f0_extracted_numpy, linewidth=2.5, color='#3498db',
            label='Extracted F0 (TorchYIN)', alpha=0.8)
    ax.plot(time_predicted, f0_predicted_numpy, linewidth=2.5, color='#e74c3c',
            linestyle='--', label='Predicted F0 (Decoder)', alpha=0.8)

    # Configure plot
    ax.set_xlabel('Normalized Time', fontsize=12)
    ax.set_ylabel('F0 (Hz)', fontsize=12)
    ax.set_title('F0 Comparison: Extracted vs Predicted', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=11, framealpha=0.9)
    plt.tight_layout()
    return fig
def create_tone_visualization(tone_indices):
    """Create visual representation of tone sequence"""
    if not tone_indices:
        return None

    fig, ax = plt.subplots(figsize=(max(12, len(tone_indices) * 0.8), 3))

    # Prepare data
    x_positions = []
    colors = []
    labels = []
    position = 0
    for idx in tone_indices:
        info = config.get_tone_info(idx)
        # Space tokens get different visual treatment
        if idx == 4:
            # Draw vertical line for space
            ax.axvline(x=position - 0.25, color=info['color'],
                       linewidth=3, linestyle='--', alpha=0.7)
        else:
            x_positions.append(position)
            colors.append(info['color'])
            labels.append(info['symbol'])
        position += 1

    # Draw tone bars
    if x_positions:
        ax.bar(x_positions, [1] * len(x_positions), color=colors, alpha=0.7,
               edgecolor='black', linewidth=2, width=0.8)
        # Add tone symbols
        for pos, label in zip(x_positions, labels):
            ax.text(pos, 0.5, label, ha='center', va='center',
                    fontsize=20, fontweight='bold')
        # Configure plot
        ax.set_xlim(-0.5, max(x_positions) + 0.5)
        ax.set_ylim(0, 1.2)
        ax.set_xticks(x_positions)
        ax.set_xticklabels([f"T{i+1}" for i in range(len(x_positions))])

    ax.set_ylabel('Tone', fontsize=12)
    ax.set_title('Tone Sequence Visualization (| = word boundary)',
                 fontsize=14, fontweight='bold')
    ax.set_yticks([])
    plt.tight_layout()
    return fig
# ============ PREDICTION FUNCTION ============
def predict_tone(audio_file):
    """Main prediction function for the Gradio interface"""
    if tone_recognizer is None:
        return "❌ Model not loaded. Please check configuration.", None, None, ""
    if audio_file is None:
        return "⚠️ Please provide an audio file", None, None, ""

    try:
        # Get predictions (both the extracted and the predicted F0 contours)
        tone_indices, tone_names, f0_extracted, f0_predicted = tone_recognizer.classify_file(audio_file)

        # Format output
        tone_text = format_tone_sequence(tone_indices, tone_names)

        # Create visualizations: combined F0 comparison plot and tone sequence
        f0_comparison_plot = create_f0_comparison_plot(f0_extracted, f0_predicted)
        tone_viz = create_tone_visualization(tone_indices)

        # Create statistics (index 4 marks a word boundary, not a tone)
        num_tones = len([t for t in tone_indices if t != 4])
        num_spaces = len([t for t in tone_indices if t == 4])
        stats = f"""
πŸ“Š **Prediction Statistics:**
- Total tones detected: {num_tones}
- Word boundaries detected: {num_spaces}
- Sequence length: {len(tone_indices)}

🎡 **Tone Distribution:**
- High tones (H): {tone_indices.count(1)}
- Low tones (B): {tone_indices.count(2)}
- Mid tones (M): {tone_indices.count(3)}

βš™οΈ **Detection Settings:**
- Space detection: {'ENABLED' if config.ENABLE_SPACE_DETECTION else 'DISABLED'}
{f"- Method: {config.SPACE_DETECTION_METHOD}" if config.ENABLE_SPACE_DETECTION else ""}
"""
        return tone_text, f0_comparison_plot, tone_viz, stats
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        return f"❌ Error during prediction: {str(e)}\n\n{error_details}", None, None, ""
# ============ JSON API FUNCTION ============
def predict_tone_json(audio_file):
    """API endpoint that returns a pure JSON response for programmatic access"""
    if tone_recognizer is None:
        return {
            "success": False,
            "error": "Model not loaded. Please check configuration."
        }
    if audio_file is None:
        return {
            "success": False,
            "error": "No audio file provided"
        }

    try:
        # Handle the different input types Gradio may pass through the API.
        # With a gr.Textbox input this is normally a plain string path.
        if isinstance(audio_file, str):
            file_path = audio_file
        elif hasattr(audio_file, 'name'):
            # File-like object with a name attribute
            file_path = audio_file.name
        elif isinstance(audio_file, dict):
            # FileData format from the API - prefer 'path' over 'name'/'orig_name'
            file_path = audio_file.get('path') or audio_file.get('name', str(audio_file))
        else:
            # Fall back to a path attribute or string conversion
            file_path = getattr(audio_file, 'path', str(audio_file))

        # Get predictions
        tone_indices, tone_names, f0_extracted, f0_predicted = tone_recognizer.classify_file(file_path)

        # Convert F0 tensors to plain Python lists for JSON serialization
        if hasattr(f0_extracted, 'cpu'):
            f0_extracted_list = f0_extracted.cpu().numpy().flatten().tolist()
        else:
            f0_extracted_list = np.asarray(f0_extracted).flatten().tolist()
        if hasattr(f0_predicted, 'cpu'):
            f0_predicted_list = f0_predicted.cpu().numpy().flatten().tolist()
        else:
            f0_predicted_list = np.asarray(f0_predicted).flatten().tolist()

        # Build response
        return {
            "success": True,
            "tone_sequence": [
                {
                    "index": idx,
                    "label": name,
                    "name": config.get_tone_info(idx)["name"],
                    "symbol": config.get_tone_info(idx)["symbol"]
                }
                for idx, name in zip(tone_indices, tone_names)
            ],
            "tone_string": " β†’ ".join(tone_names),
            "statistics": {
                "total_tones": len([t for t in tone_indices if t != 4]),
                "word_boundaries": len([t for t in tone_indices if t == 4]),
                "sequence_length": len(tone_indices),
                "high_tones": tone_indices.count(1),
                "low_tones": tone_indices.count(2),
                "mid_tones": tone_indices.count(3)
            },
            "f0_data": {
                "extracted": f0_extracted_list,
                "predicted": f0_predicted_list,
                "length": len(f0_extracted_list)
            },
            "settings": {
                "space_detection_enabled": config.ENABLE_SPACE_DETECTION,
                "space_detection_method": config.SPACE_DETECTION_METHOD if config.ENABLE_SPACE_DETECTION else None
            }
        }
    except Exception as e:
        import traceback
        return {
            "success": False,
            "error": str(e),
            "traceback": traceback.format_exc()
        }
# ============ GRADIO INTERFACE ============
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
.output-text {
    font-size: 18px;
    font-weight: bold;
}
"""
with gr.Blocks(css=custom_css, title="Pro-TeVA Tone Recognition") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(
                f"""
# Pro-TeVA: Prototype-based Explainable Tone Recognition for Yoruba

Upload an audio file or record your voice to detect Yoruba tone patterns.

**Yoruba Tones:**
- **High Tone (H)** (β—ŒΜ): Syllable with high pitch
- **Low Tone (B)** (β—ŒΜ€): Syllable with low pitch
- **Mid Tone (M)** (β—Œ): Syllable with neutral/middle pitch
- **Space ( | )**: Word boundary (detected automatically)

**Space Detection:** {config.SPACE_DETECTION_METHOD if config.ENABLE_SPACE_DETECTION else 'OFF'}
"""
            )
        with gr.Column(scale=1):
            gr.Markdown("### 🎧 Audio Examples")
            gr.Markdown("**Click on an example to load it**")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎀 Input Audio")
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
                waveform_options={"show_controls": True}
            )
            # Female voice examples
            gr.Markdown("**πŸ‘© Female Voice (Yof):**")
            gr.Examples(
                examples=[
                    ["examples/yof_00295_00024634140.wav"],
                    ["examples/yof_00295_00151151204.wav"],
                    ["examples/yof_00295_00427144639.wav"],
                    ["examples/yof_00295_00564596981.wav"],
                ],
                inputs=audio_input,
                label="",
                examples_per_page=4
            )

            # Male voice examples
            gr.Markdown("**πŸ‘¨ Male Voice (Yom):**")
            gr.Examples(
                examples=[
                    ["examples/yom_08784_01544027142.wav"],
                    ["examples/yom_08784_01792196659.wav"],
                    ["examples/yom_09334_00045442417.wav"],
                    ["examples/yom_09334_00091591408.wav"],
                ],
                inputs=audio_input,
                label="",
                examples_per_page=4
            )

            predict_btn = gr.Button("πŸ” Predict Tones", variant="primary", size="lg")
            gr.Markdown(
                """
### πŸ“ Tips:
- Speak clearly in Yoruba
- Keep recordings under 10 seconds
- Avoid background noise
- Pause slightly between words for better boundary detection
"""
            )
        with gr.Column(scale=2):
            gr.Markdown("### 🎯 Results")
            tone_output = gr.Textbox(
                label="Predicted Tone Sequence",
                lines=3,
                elem_classes="output-text"
            )
            stats_output = gr.Markdown(label="Statistics")
            with gr.Tabs():
                with gr.Tab("F0 Comparison"):
                    f0_comparison_plot = gr.Plot(label="Extracted vs Predicted F0")
                with gr.Tab("Tone Visualization"):
                    tone_viz = gr.Plot(label="Tone Sequence")

    predict_btn.click(
        fn=predict_tone,
        inputs=audio_input,
        outputs=[tone_output, f0_comparison_plot, tone_viz, stats_output]
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                f"""
---
**About Pro-TeVA:**

**Pro-TeVA** (Prototype-based Temporal Variational Autoencoder) is an explainable neural model for tone recognition.
Unlike black-box models, Pro-TeVA provides transparency through:
- Interpretable F0 (pitch) features
- Visualizable tone prototypes
- F0 reconstruction for explainability
- High performance: 17.74% Tone Error Rate

**Model Architecture:**
- Feature Extractor: HuBERT (Orange/SSA-HuBERT-base-60k)
- Encoder: {config.RNN_LAYERS}-layer Bidirectional GRU ({config.RNN_NEURONS} neurons)
- Variational Autoencoder: Compact latent representations
- Prototype Layer: {config.N_PROTOTYPES} learnable tone prototypes
- Decoder: F0 reconstruction (VanillaNN)
- Output: CTC-based sequence prediction

**Space Detection:**
- Method: {config.SPACE_DETECTION_METHOD if config.ENABLE_SPACE_DETECTION else 'Disabled'}
- Uses F0 contours, silence patterns, and tone duration
- Automatically detects word boundaries in continuous speech

**API Access:**
- REST API enabled for programmatic access
- Use the Gradio client: `pip install gradio_client`
- See the README for full API documentation

Built with ❀️ using SpeechBrain and Gradio

**Model Checkpoint:** {config.CHECKPOINT_FOLDER}
"""
            )
        with gr.Column(scale=1):
            gr.Image(
                value="proteva_archi.png",
                label="Pro-TeVA Architecture",
                show_label=True
            )
# JSON API interface: use gr.Textbox so the API receives the raw file path.
# This avoids Gradio's file-component preprocessing issues.
json_api = gr.Interface(
    fn=predict_tone_json,
    inputs=gr.Textbox(label="Audio File Path", placeholder="Path to uploaded audio file"),
    outputs=gr.JSON(label="Prediction Result"),
    api_name="predict_json",
    title="Pro-TeVA JSON API",
    description="Upload the file first via /gradio_api/upload, then pass the returned path here"
)
# API Documentation tab
with gr.Blocks() as api_docs:
    gr.Markdown(
        """
# API Documentation
Pro-TeVA provides two API endpoints for programmatic access to tone recognition.
---
## Available Endpoints
| Endpoint | Description | Output Type |
|----------|-------------|-------------|
| `/predict` | UI endpoint with visualizations | Text + Plots |
| `/predict_json` | Pure JSON for APIs | Structured JSON |
---
## JSON API Endpoint (Recommended)
**Endpoint:** `/predict_json`
This is the recommended endpoint for programmatic access as it returns pure JSON data.
### Input
- **audio_file**: Path to an audio file that the Space can read (WAV, MP3, FLAC)
  - For local files, upload first via `/gradio_api/upload` and pass the returned server path (see the sketch below)
  - Recommended: 16 kHz sample rate, mono
  - Max duration: ~10 seconds
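A minimal upload sketch with cURL. This is hedged: the `/gradio_api/upload` route is the one named in this Space's own endpoint description, and in current Gradio versions it returns a JSON array of server-side temp paths (the `<hash>` segment below is illustrative):

```bash
# Upload a local file to the Space's temp storage
curl -X POST "https://Obiang-Pro-TeVA.hf.space/gradio_api/upload" \\
  -F "files=@audio.wav"
# Example response: ["/tmp/gradio/<hash>/audio.wav"]
# Pass that returned path as audio_file to /predict_json
```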
### Output Schema
```json
{
  "success": true,
  "tone_sequence": [
    {
      "index": 1,
      "label": "H",
      "name": "High Tone",
      "symbol": "β—ŒΜ"
    }
  ],
  "tone_string": "H β†’ B β†’ M",
  "statistics": {
    "total_tones": 3,
    "word_boundaries": 1,
    "sequence_length": 4,
    "high_tones": 1,
    "low_tones": 1,
    "mid_tones": 1
  },
  "f0_data": {
    "extracted": [120.5, 125.3, ...],
    "predicted": [118.2, 123.8, ...],
    "length": 100
  },
  "settings": {
    "space_detection_enabled": true,
    "space_detection_method": "combined"
  }
}
```
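The two `f0_data` arrays are plain lists of Hz values, so they can be re-plotted client-side. A minimal sketch (assuming a successful `result` dict shaped as above; the two contours may differ in length, so each gets its own normalized time axis, as in the Space's own F0 comparison plot):

```python
import matplotlib.pyplot as plt

f0 = result["f0_data"]
for series, style in (("extracted", "-"), ("predicted", "--")):
    values = f0[series]
    t = [i / len(values) for i in range(len(values))]  # normalized time axis
    plt.plot(t, values, style, label=f"{series} F0")
plt.xlabel("Normalized Time")
plt.ylabel("F0 (Hz)")
plt.legend()
plt.show()
```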
---
## Python Examples
### Installation
```bash
pip install gradio_client
```
### Basic Usage
```python
from gradio_client import Client

# Connect to the Pro-TeVA Space
client = Client("Obiang/Pro-TeVA")

# Get the JSON response; the path must be readable by the Space
# (e.g. a server path returned by /gradio_api/upload)
result = client.predict(
    audio_file="path/to/audio.wav",
    api_name="/predict_json"
)

# Parse results
print(f"Success: {result['success']}")
print(f"Tones: {result['tone_string']}")
print(f"Statistics: {result['statistics']}")
```
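For a local file, one workable pattern is to upload it first and then pass the returned server path. This is a sketch, not a guaranteed API: it assumes the `/gradio_api/upload` route named in this Space's description, its list-of-paths response, and the `requests` package:

```python
import requests
from gradio_client import Client

SPACE_URL = "https://Obiang-Pro-TeVA.hf.space"

# Step 1: upload the local file; the route returns a JSON list of server paths
with open("audio.wav", "rb") as f:
    server_paths = requests.post(f"{SPACE_URL}/gradio_api/upload",
                                 files={"files": f}).json()

# Step 2: pass the server-side path to the JSON endpoint
client = Client("Obiang/Pro-TeVA")
result = client.predict(audio_file=server_paths[0], api_name="/predict_json")
print(result["tone_string"])
```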
### Batch Processing
```python
from gradio_client import Client

client = Client("Obiang/Pro-TeVA")
audio_files = ["audio1.wav", "audio2.wav", "audio3.wav"]  # server paths or URLs

for audio_path in audio_files:
    result = client.predict(
        audio_file=audio_path,
        api_name="/predict_json"
    )
    if result['success']:
        print(f"{audio_path}: {result['tone_string']}")
    else:
        print(f"{audio_path}: Error - {result['error']}")
```
---
## cURL Examples
### Step 1: Submit Request
```bash
curl -X POST "https://Obiang-Pro-TeVA.hf.space/call/predict_json" \\
  -H "Content-Type: application/json" \\
  -d '{
    "data": ["https://example.com/audio.wav"]
  }'
```
**Response:**
```json
{"event_id": "abc123def456"}
```
### Step 2: Get Results
```bash
curl -N "https://Obiang-Pro-TeVA.hf.space/call/predict_json/abc123def456"
```
**Response (Server-Sent Events):**
```
event: complete
data: {"success": true, "tone_sequence": [...], ...}
```
### One-liner with jq
```bash
# Submit and capture the event_id
EVENT_ID=$(curl -s -X POST "https://Obiang-Pro-TeVA.hf.space/call/predict_json" \\
  -H "Content-Type: application/json" \\
  -d '{"data": ["audio.wav"]}' | jq -r '.event_id')

# Get results
curl -N "https://Obiang-Pro-TeVA.hf.space/call/predict_json/$EVENT_ID"
```
---
## JavaScript Example
```javascript
import { client } from "@gradio/client";

async function predictTones(audioPath) {
  // audioPath must be readable by the Space (an uploaded server path or a URL)
  const app = await client("Obiang/Pro-TeVA");
  const result = await app.predict("/predict_json", [audioPath]);
  const prediction = result.data[0];
  console.log("Tones:", prediction.tone_string);
  console.log("Statistics:", prediction.statistics);
  return prediction;
}
```
---
## Error Handling
### Error Response Schema
```json
{
  "success": false,
  "error": "Error message here",
  "traceback": "Full error traceback..."
}
```
### Python Error Handling
```python
from gradio_client import Client

try:
    client = Client("Obiang/Pro-TeVA")
    result = client.predict(
        audio_file="audio.wav",
        api_name="/predict_json"
    )
    if result['success']:
        print(f"Tones: {result['tone_string']}")
    else:
        print(f"Error: {result['error']}")
except Exception as e:
    print(f"Connection error: {str(e)}")
```
---
## Rate Limits
- Hugging Face Spaces: Standard rate limits apply
- Free tier: Suitable for development and testing
- For high-volume usage: Consider deploying your own instance
---
## Tone Labels Reference
| Index | Label | Name | Symbol |
|-------|-------|------|--------|
| 0 | BLANK | CTC Blank | - |
| 1 | H | High Tone | β—ŒΜ |
| 2 | B | Low Tone | β—ŒΜ€ |
| 3 | M | Mid Tone | β—Œ |
| 4 | SPACE | Word Boundary | \\| |
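
Based on this mapping, a small usage example for filtering a `tone_sequence` response down to true tones (dropping blanks and word boundaries):

```python
tones = [t for t in result["tone_sequence"] if t["index"] in (1, 2, 3)]
print(" β†’ ".join(t["label"] for t in tones))  # e.g. "H β†’ B β†’ M"
```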
---
## Support
For questions or issues, please open an issue on the repository or check the README for more details.
"""
    )
# Combine all interfaces
app = gr.TabbedInterface(
    [demo, json_api, api_docs],
    ["Tone Recognition", "JSON API", "API Documentation"],
    title="Pro-TeVA: Prototype-based Explainable Tone Recognition for Yoruba"
)
if __name__ == "__main__":
    app.launch(
        share=config.GRADIO_SHARE,
        server_name=config.GRADIO_SERVER_NAME,
        server_port=config.GRADIO_SERVER_PORT,
        show_api=config.ENABLE_API
    )