# classifier / app.py — Hugging Face Space source file
# Author: Luis J Camargo — "import fix" (commit 4e7deef)
# app.py
import os
import gradio as gr
import torch
import numpy as np
from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel
from transformers.models.whisper.modeling_whisper import WhisperEncoder
import torch.nn as nn
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
# === CUSTOM MODEL CLASSES ===
class WhisperEncoderOnlyConfig(WhisperConfig):
    """WhisperConfig extended with the three classification-head sizes.

    ``n_fam`` / ``n_super`` / ``n_code`` hold the number of classes at each
    level of the hierarchy (family -> superlanguage -> language code).
    """

    model_type = "whisper_encoder_classifier"

    def __init__(self, n_fam=None, n_super=None, n_code=None, **kwargs):
        super().__init__(**kwargs)
        # Head sizes default to None so the config can be instantiated
        # before the label spaces are known (e.g. by AutoConfig machinery).
        for attr, size in (("n_fam", n_fam), ("n_super", n_super), ("n_code", n_code)):
            setattr(self, attr, size)
class WhisperEncoderOnlyForClassification(WhisperPreTrainedModel):
    """Encoder-only Whisper with three linear classification heads.

    The decoder is dropped entirely: mean-pooled encoder hidden states feed
    three independent heads (family / superlanguage / language code).
    """

    config_class = WhisperEncoderOnlyConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = WhisperEncoder(config)
        width = config.d_model
        # One linear head per level of the label hierarchy.
        self.fam_head = nn.Linear(width, config.n_fam)
        self.super_head = nn.Linear(width, config.n_super)
        self.code_head = nn.Linear(width, config.n_code)
        self.post_init()

    def get_input_embeddings(self):
        """Whisper consumes log-mel features, so there are no token embeddings."""
        return None

    def set_input_embeddings(self, value):
        """No-op: there are no token embeddings to replace."""
        pass

    def enable_input_require_grads(self):
        # Nothing to do: inputs are raw audio features, not embeddings.
        return

    def forward(self, input_features, labels=None):
        """Encode, mean-pool over time, and score the three heads.

        ``labels``, when provided, is a ``(fam, super, code)`` triple of
        class-index tensors; the loss is the unweighted sum of the three
        cross-entropy terms. Returns a plain dict of loss and per-head logits.
        """
        encoded = self.encoder(input_features=input_features)
        pooled = encoded.last_hidden_state.mean(dim=1)

        fam_logits = self.fam_head(pooled)
        super_logits = self.super_head(pooled)
        code_logits = self.code_head(pooled)

        loss = None
        if labels is not None:
            fam_labels, super_labels, code_labels = labels
            criterion = nn.CrossEntropyLoss()
            loss = (
                criterion(fam_logits, fam_labels)
                + criterion(super_logits, super_labels)
                + criterion(code_logits, code_labels)
            )

        return {
            "loss": loss,
            "fam_logits": fam_logits,
            "super_logits": super_logits,
            "code_logits": code_logits,
        }
# === REGISTER MODEL ===
# Make the custom architecture discoverable through the Auto* factories so
# from_pretrained() can resolve model_type "whisper_encoder_classifier".
AutoConfig.register("whisper_encoder_classifier", WhisperEncoderOnlyConfig)
AutoModel.register(WhisperEncoderOnlyConfig, WhisperEncoderOnlyForClassification)

# === LOAD MODEL ===
MODEL_REPO = "tachiwin/language_classification_enconly_model_2"

print("Loading model on CPU...")
processor = WhisperProcessor.from_pretrained(MODEL_REPO)
# from_pretrained() already downloads the config and safetensors weights from
# the Hub, so no manual hf_hub_download()/load_file() step is required.
model = WhisperEncoderOnlyForClassification.from_pretrained(MODEL_REPO)
model.eval()  # inference only: disable dropout / training-mode behavior
print("Model loaded successfully!")
# === INFERENCE FUNCTION ===
def predict_language(audio):
    """Classify a recording into family / superlanguage / language code.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        Gradio ``type="numpy"`` audio: ``(sample_rate, samples)``. The array
        may be int16/int32/float and mono or multi-channel; ``None`` means
        nothing was recorded/uploaded.

    Returns
    -------
    Three ``{label: confidence}`` dicts (one per ``gr.Label`` output), or a
    warning message triple when no audio was provided.
    """
    if audio is None:
        return "⚠️ No audio provided", "", ""

    sample_rate, audio_array = audio

    # Normalization: map integer PCM to float32 in [-1, 1).
    if audio_array.dtype == np.int16:
        audio_array = audio_array.astype(np.float32) / 32768.0
    elif audio_array.dtype == np.int32:
        audio_array = audio_array.astype(np.float32) / 2147483648.0

    # Downmix: Gradio delivers multi-channel audio as (samples, channels);
    # librosa.resample and the Whisper processor expect a mono 1-D signal.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Resampling to Whisper's fixed 16 kHz input rate.
    if sample_rate != 16000:
        import librosa
        audio_array = librosa.resample(
            audio_array.astype(np.float32), orig_sr=sample_rate, target_sr=16000
        )

    # Preprocessing: log-mel features, padded/truncated by the processor.
    inputs = processor(
        audio_array,
        sampling_rate=16000,
        return_tensors="pt"
    )

    # Inference
    with torch.no_grad():
        outputs = model(input_features=inputs.input_features)

    # Post-processing: per-head softmax, keep only the arg-max class.
    fam_probs = torch.softmax(outputs["fam_logits"], dim=-1)
    super_probs = torch.softmax(outputs["super_logits"], dim=-1)
    code_probs = torch.softmax(outputs["code_logits"], dim=-1)

    fam_idx = outputs["fam_logits"].argmax(-1).item()
    super_idx = outputs["super_logits"].argmax(-1).item()
    code_idx = outputs["code_logits"].argmax(-1).item()

    # NOTE(review): predictions are reported as raw class indices — no
    # index->name mapping is visible in this file; confirm label files exist.
    return (
        {f"{fam_idx}": fam_probs[0, fam_idx].item()},
        {f"{super_idx}": super_probs[0, super_idx].item()},
        {f"{code_idx}": code_probs[0, code_idx].item()},
    )
# === UI COMPONENTS ===
# Fix: `theme=` belongs to the gr.Blocks() constructor — Blocks.launch() has
# no `theme` parameter, so passing it there never applied the theme.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue")) as demo:
    gr.HTML(
        """
        <div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #4f46e5 0%, #3b82f6 100%); color: white; border-radius: 15px; margin-bottom: 25px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
        <h1 style="color: white; margin: 0; font-size: 2.5em;">🌎 Indigenous Language Classifier</h1>
        <p style="font-size: 1.2em; opacity: 0.9; margin-top: 10px;">Hierarchical identification of 300+ Mesoamerican languages</p>
        </div>
        """
    )

    with gr.Row():
        # Left column: audio input plus action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 🎙️ 1. Input Audio")
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Upload or Record"
            )
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                submit_btn = gr.Button("🚀 Classify", variant="primary")

        # Right column: one Label per level of the hierarchy.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 2. Classification Results")
            fam_output = gr.Label(num_top_classes=1, label="🌍 Language Family")
            super_output = gr.Label(num_top_classes=1, label="🗣️ Superlanguage")
            code_output = gr.Label(num_top_classes=1, label="🔤 Language Code")

    submit_btn.click(
        fn=predict_language,
        inputs=audio_input,
        outputs=[fam_output, super_output, code_output]
    )
    # Clear resets the input and all three result panels.
    clear_btn.click(
        fn=lambda: (None, None, None, None),
        inputs=None,
        outputs=[audio_input, fam_output, super_output, code_output]
    )

    gr.Markdown(
        """
        ---
        ### ℹ️ About this Model
        This application uses a custom **Whisper Encoder-Only** architecture trained to recognize the hierarchical structure of indigenous languages (Family → Superlanguage → Code).
        **Accuracy Overview:**
        - **Language Family**: ~73%
        - **Superlanguage**: ~59%
        - **Language Code**: ~52%
        *Developed for the Tachiwin project.*
        """
    )

if __name__ == "__main__":
    demo.launch()