# Hugging Face Space page residue ("Spaces: Sleeping") captured by the scrape —
# not part of app.py itself.
# app.py
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import (
    AutoConfig,
    AutoModel,
    WhisperConfig,
    WhisperPreTrainedModel,
    WhisperProcessor,
)
from transformers.models.whisper.modeling_whisper import WhisperEncoder
# === CUSTOM MODEL CLASSES ===
class WhisperEncoderOnlyConfig(WhisperConfig):
    """Whisper config extended with the head sizes for the three
    hierarchical classification levels (family / superlanguage / code)."""

    model_type = "whisper_encoder_classifier"

    def __init__(self, n_fam=None, n_super=None, n_code=None, **kwargs):
        super().__init__(**kwargs)
        # Number of target classes at each level of the hierarchy.
        self.n_fam = n_fam
        self.n_super = n_super
        self.n_code = n_code
class WhisperEncoderOnlyForClassification(WhisperPreTrainedModel):
    """Whisper encoder plus three linear heads for hierarchical language ID."""

    config_class = WhisperEncoderOnlyConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = WhisperEncoder(config)
        width = config.d_model
        # One classification head per hierarchy level.
        self.fam_head = nn.Linear(width, config.n_fam)
        self.super_head = nn.Linear(width, config.n_super)
        self.code_head = nn.Linear(width, config.n_code)
        self.post_init()

    def get_input_embeddings(self):
        """Whisper consumes log-mel features; there are no token embeddings."""
        return None

    def set_input_embeddings(self, value):
        """No-op: there are no input embeddings to replace."""
        pass

    def enable_input_require_grads(self):
        # Nothing to do — inputs are raw features, not embeddings.
        return

    def forward(self, input_features, labels=None):
        """Encode audio features, mean-pool over time, and score each level.

        When ``labels`` is given it is a (fam, super, code) tuple of
        class-index tensors; the loss is the sum of the three
        cross-entropies.
        """
        encoded = self.encoder(input_features=input_features)
        pooled = encoded.last_hidden_state.mean(dim=1)

        fam_logits = self.fam_head(pooled)
        super_logits = self.super_head(pooled)
        code_logits = self.code_head(pooled)

        loss = None
        if labels is not None:
            fam_targets, super_targets, code_targets = labels
            ce = nn.CrossEntropyLoss()
            loss = (
                ce(fam_logits, fam_targets)
                + ce(super_logits, super_targets)
                + ce(code_logits, code_targets)
            )

        return {
            "loss": loss,
            "fam_logits": fam_logits,
            "super_logits": super_logits,
            "code_logits": code_logits,
        }
# === REGISTER MODEL ===
# Make the custom config/model pair discoverable via the Auto* factories.
AutoConfig.register("whisper_encoder_classifier", WhisperEncoderOnlyConfig)
AutoModel.register(WhisperEncoderOnlyConfig, WhisperEncoderOnlyForClassification)
# === LOAD MODEL ===
MODEL_REPO = "tachiwin/language_classification_enconly_model_2"

print("Loading model on CPU...")
processor = WhisperProcessor.from_pretrained(MODEL_REPO)
# from_pretrained already restores the safetensors weights, so no manual
# hf_hub_download / load_state_dict step is needed.
model = WhisperEncoderOnlyForClassification.from_pretrained(MODEL_REPO)
model.eval()  # inference only: disable dropout etc.
print("Model loaded successfully!")
# === INFERENCE FUNCTION ===
def predict_language(audio):
    """Classify a recording at three levels: family, superlanguage, code.

    Args:
        audio: Gradio numpy audio — a ``(sample_rate, np.ndarray)`` tuple,
            or ``None`` when nothing was uploaded/recorded.

    Returns:
        Three ``{label: confidence}`` dicts (one per ``gr.Label`` output),
        or an error message plus two empty strings when no audio was given.
    """
    if audio is None:
        return "⚠️ No audio provided", "", ""

    sample_rate, audio_array = audio

    # Normalize integer PCM to float32 in [-1, 1).
    if audio_array.dtype == np.int16:
        audio_array = audio_array.astype(np.float32) / 32768.0
    elif audio_array.dtype == np.int32:
        audio_array = audio_array.astype(np.float32) / 2147483648.0

    # Downmix stereo (samples, channels) to mono — the Whisper feature
    # extractor expects 1-D audio.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Resample to Whisper's fixed 16 kHz input rate when needed.
    if sample_rate != 16000:
        import librosa  # lazy import: only needed for non-16 kHz input

        audio_array = librosa.resample(
            audio_array, orig_sr=sample_rate, target_sr=16000
        )

    # Log-mel feature extraction.
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")

    # Inference (no gradients needed).
    with torch.no_grad():
        outputs = model(input_features=inputs.input_features)

    # Per-head probabilities and top-1 class per level.
    fam_probs = torch.softmax(outputs["fam_logits"], dim=-1)
    super_probs = torch.softmax(outputs["super_logits"], dim=-1)
    code_probs = torch.softmax(outputs["code_logits"], dim=-1)

    fam_idx = outputs["fam_logits"].argmax(-1).item()
    super_idx = outputs["super_logits"].argmax(-1).item()
    code_idx = outputs["code_logits"].argmax(-1).item()

    # NOTE(review): labels are raw class indices — map them to human-readable
    # names once a label vocabulary is available.
    return (
        {f"{fam_idx}": fam_probs[0, fam_idx].item()},
        {f"{super_idx}": super_probs[0, super_idx].item()},
        {f"{code_idx}": code_probs[0, code_idx].item()},
    )
# === UI COMPONENTS ===
# NOTE(review): the emoji below were mojibake'd in the source (UTF-8 read
# through a Greek codepage); they have been restored to plausible originals —
# confirm against the deployed Space.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #4f46e5 0%, #3b82f6 100%); color: white; border-radius: 15px; margin-bottom: 25px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
            <h1 style="color: white; margin: 0; font-size: 2.5em;">🌎 Indigenous Language Classifier</h1>
            <p style="font-size: 1.2em; opacity: 0.9; margin-top: 10px;">Hierarchical identification of 300+ Mesoamerican languages</p>
        </div>
        """
    )
    with gr.Row():
        # Left column: audio input and action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 🎙️ 1. Input Audio")
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="numpy",
                label="Upload or Record",
            )
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                submit_btn = gr.Button("🚀 Classify", variant="primary")
        # Right column: one Label widget per hierarchy level.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 2. Classification Results")
            fam_output = gr.Label(num_top_classes=1, label="🌍 Language Family")
            super_output = gr.Label(num_top_classes=1, label="🗣️ Superlanguage")
            code_output = gr.Label(num_top_classes=1, label="🔤 Language Code")

    submit_btn.click(
        fn=predict_language,
        inputs=audio_input,
        outputs=[fam_output, super_output, code_output],
    )
    # Reset the audio widget and all three result labels.
    clear_btn.click(
        fn=lambda: (None, None, None, None),
        inputs=None,
        outputs=[audio_input, fam_output, super_output, code_output],
    )
    gr.Markdown(
        """
        ---
        ### ℹ️ About this Model
        This application uses a custom **Whisper Encoder-Only** architecture trained to recognize the hierarchical structure of indigenous languages (Family → Superlanguage → Code).
        **Accuracy Overview:**
        - **Language Family**: ~73%
        - **Superlanguage**: ~59%
        - **Language Code**: ~52%
        *Developed for the Tachiwin project.*
        """
    )
if __name__ == "__main__":
    # NOTE(review): Blocks.launch() does not accept a `theme` keyword —
    # themes belong on the gr.Blocks(...) constructor — so passing
    # gr.themes.Soft(...) here raises a TypeError on current Gradio.
    # Launch with defaults; move the theme to gr.Blocks(theme=...) if wanted.
    demo.launch()