Spaces:
Running
Running
Luis J Camargo
feat: Update Gradio theme, header text, and model description in the 'About' section.
9d375e9 | import os | |
| import gc | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import librosa | |
| import pandas as pd | |
| from transformers import WhisperProcessor, AutoConfig, AutoModel, WhisperConfig, WhisperPreTrainedModel | |
| from transformers.models.whisper.modeling_whisper import WhisperEncoder | |
| import torch.nn as nn | |
| import psutil | |
| import json | |
# --- CONFIGURATION ---
# Longest clip, in seconds, accepted for classification; predict_language
# rejects anything longer with a user-facing error.
MAX_AUDIO_SECONDS = 30
# Limit PyTorch to a single thread for CPU operations.
# NOTE(review): presumably sized for the small CPU allocation of the hosting
# Space — confirm before raising.
torch.set_num_threads(1)
| # === CUSTOM MODEL CLASSES === | |
class WhisperEncoderOnlyConfig(WhisperConfig):
    """Whisper configuration extended with the sizes of three classifier heads.

    Args:
        n_fam: Number of output classes of the language-family head.
        n_super: Number of output classes of the superlanguage head.
        n_code: Number of output classes of the language-code head.
        **kwargs: Forwarded unchanged to ``WhisperConfig``.
    """

    # Key under which this configuration is registered with AutoConfig.
    model_type = "whisper_encoder_classifier"

    def __init__(self, n_fam=None, n_super=None, n_code=None, **kwargs):
        super().__init__(**kwargs)
        self.n_fam, self.n_super, self.n_code = n_fam, n_super, n_code
class WhisperEncoderOnlyForClassification(WhisperPreTrainedModel):
    """Whisper encoder followed by three linear classification heads.

    The encoder output is averaged over dimension 1 and the pooled vector is
    fed to independent heads for language family, superlanguage and language
    code.
    """

    config_class = WhisperEncoderOnlyConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = WhisperEncoder(config)
        width = config.d_model
        # One linear head per label level, all reading the same pooled vector.
        self.fam_head = nn.Linear(width, config.n_fam)
        self.super_head = nn.Linear(width, config.n_super)
        self.code_head = nn.Linear(width, config.n_code)
        self.post_init()

    def get_input_embeddings(self):
        """Return ``None``; this model exposes no input embedding module."""
        return None

    def set_input_embeddings(self, value):
        """Ignore ``value``; there is no input embedding module to replace."""
        return None

    def enable_input_require_grads(self):
        """Do nothing."""
        return None

    def forward(self, input_features, labels=None):
        """Run the encoder and the three heads.

        Args:
            input_features: Features passed to the Whisper encoder.
            labels: Optional ``(family, super, code)`` triple of targets.
                When given, the loss is the sum of the three cross-entropy
                losses.

        Returns:
            dict with keys ``"loss"`` (``None`` when ``labels`` is ``None``),
            ``"fam_logits"``, ``"super_logits"`` and ``"code_logits"``.
        """
        hidden_states = self.encoder(input_features=input_features).last_hidden_state
        pooled = hidden_states.mean(dim=1)

        logits = {
            "fam_logits": self.fam_head(pooled),
            "super_logits": self.super_head(pooled),
            "code_logits": self.code_head(pooled),
        }

        if labels is None:
            loss = None
        else:
            fam_target, super_target, code_target = labels
            criterion = nn.CrossEntropyLoss()
            loss = (
                criterion(logits["fam_logits"], fam_target)
                + criterion(logits["super_logits"], super_target)
                + criterion(logits["code_logits"], code_target)
            )

        return {"loss": loss, **logits}
class LabelExtractor:
    """Derive ordered family / superlanguage / code label lists from a tokenizer.

    Every added-vocabulary token of the form ``<|...|>`` (except four Whisper
    control tokens) is placed in exactly one bucket:

    * family: the token starts with ``<|@``;
    * superlanguage: the character right after ``<|`` is uppercase;
    * code: everything else.

    Each bucket is ordered by token id so that position ``i`` in a label list
    corresponds to output index ``i`` of the matching classifier head.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

        # Control tokens that are never category labels.
        ignored = {
            "<|startoftranscript|>",
            "<|endoftext|>",
            "<|nospeech|>",
            "<|notimestamps|>",
        }

        families, supers, codes = [], [], []
        for text, tid in tokenizer.get_added_vocab().items():
            if not (text.startswith("<|") and text.endswith("|>")):
                continue
            if text in ignored:
                continue

            pair = (text, tid)
            if text.startswith("<|@"):
                families.append(pair)
            elif self._is_super_token(text):
                supers.append(pair)
            else:
                codes.append(pair)

        # Order by token id to match the model's output indices.
        self.family_tokens = sorted(families, key=lambda pair: pair[1])
        self.super_tokens = sorted(supers, key=lambda pair: pair[1])
        self.code_tokens = sorted(codes, key=lambda pair: pair[1])

        # Inference only needs the token strings, in index order.
        self.family_labels = [text for text, _ in self.family_tokens]
        self.super_labels = [text for text, _ in self.super_tokens]
        self.code_labels = [text for text, _ in self.code_tokens]

        print("Extracted labels:")
        print(f" Families: {len(self.family_labels)}")
        print(f" Superlanguages: {len(self.super_labels)}")
        print(f" Codes: {len(self.code_labels)}")

    def _is_super_token(self, token_str):
        """Return True when ``token_str`` looks like a superlanguage token.

        Mirrors the heuristic used at training time: not a family token
        (``<|@`` prefix) and the third character is uppercase.
        """
        if token_str.startswith("<|@"):
            return False
        return len(token_str) > 2 and token_str[2].isupper()
# === REGISTER MODEL ===
# Make the custom config/model pair resolvable through the Auto* factories.
AutoConfig.register("whisper_encoder_classifier", WhisperEncoderOnlyConfig)
AutoModel.register(WhisperEncoderOnlyConfig, WhisperEncoderOnlyForClassification)

# === LOAD MODEL ===
MODEL_REPO = "tachiwin/language_classification_enconly_model_2"

print("Loading model on CPU...")
processor = WhisperProcessor.from_pretrained(MODEL_REPO)
model = WhisperEncoderOnlyForClassification.from_pretrained(MODEL_REPO, low_cpu_mem_usage=True)
model.eval()

# The tokenizer's added vocabulary provides the index -> label text mappings.
label_extractor = LabelExtractor(processor.tokenizer)

# Language code -> display name. Best effort: the app still works without
# the file, it just shows raw codes.
print("Loading language mappings...")
try:
    with open("languages.json", encoding="utf-8") as f:
        languages_data = json.load(f)
    _code_names = {}
    for _entry in languages_data:
        _code = _entry.get("code")
        _display = _entry.get("inali_name")
        # Keep only entries that have both a code and a name.
        if _code and _display:
            _code_names[_code] = _display
    CODE_TO_NAME = _code_names
except Exception as e:
    print(f"Warning: Could not load languages.json: {e}")
    CODE_TO_NAME = {}

print("Model loaded successfully!")
def get_mem_usage():
    """Return the resident set size of the current process in mebibytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 2)
# === INFERENCE FUNCTION ===
def _build_prediction_table(probs_tensor, k, thresh, labels_list, name_map=None):
    """Turn one head's probabilities into a Prediction/Confidence table.

    Args:
        probs_tensor: Probabilities whose first row is the clip to report.
        k: Number of top entries to consider (coerced to ``int`` and capped
            at the number of classes).
        thresh: Entries with a probability below this value are dropped.
        labels_list: Label strings indexed by class id; surrounding ``<``,
            ``|`` and ``>`` characters are stripped for display.
        name_map: Optional mapping from stripped label to a display name;
            matches are rendered as ``"<name> (<label>)"``.

    Returns:
        pandas.DataFrame with columns ``Prediction`` and ``Confidence``.
    """
    columns = ["Prediction", "Confidence"]
    top_vals, top_idx = torch.topk(probs_tensor[0], min(int(k), probs_tensor.shape[-1]))

    rows = []
    for score, idx in zip(top_vals.tolist(), top_idx.tolist()):
        if score < thresh:
            continue
        if idx < len(labels_list):
            raw_label = labels_list[idx].strip("<|>")
        else:
            # The head has more outputs than the tokenizer has labels.
            raw_label = f"Unknown ({idx})"
        if name_map and raw_label in name_map:
            name = f"{name_map[raw_label]} ({raw_label})"
        else:
            name = raw_label
        rows.append([name, f"{score:.2%}"])

    if not rows:
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows, columns=columns)


def predict_language(audio_path, fam_k=1, fam_thresh=0.0, super_k=1, super_thresh=0.0, code_k=3, code_thresh=0.0):
    """Classify an audio file at family, superlanguage and code level.

    Args:
        audio_path: Path of the audio file to classify.
        fam_k, super_k, code_k: Number of top predictions per level.
        fam_thresh, super_thresh, code_thresh: Minimum probability a
            prediction needs in order to be listed for that level.

    Returns:
        Tuple ``(family_df, super_df, code_df)`` of pandas DataFrames with
        columns ``Prediction`` and ``Confidence``.

    Raises:
        gr.Error: When no audio is given, the clip exceeds
            ``MAX_AUDIO_SECONDS``, or any processing step fails.
    """
    if not audio_path:
        raise gr.Error("No audio provided! Please upload or record an audio file.")

    gc.collect()
    start_mem = get_mem_usage()
    print(f"\n--- [LOG] New Request ---")
    print(f"[LOG] Start Memory: {start_mem:.2f} MB")

    try:
        print("[LOG] Step 1: Loading and resampling audio from file...")
        audio_array, _ = librosa.load(audio_path, sr=16000)
        audio_len_sec = len(audio_array) / 16000
        print(f"[LOG] Audio duration: {audio_len_sec:.2f}s, SR: 16000")
        print(f"[LOG] Memory after load: {get_mem_usage():.2f} MB")

        if audio_len_sec > MAX_AUDIO_SECONDS:
            del audio_array
            gc.collect()
            raise gr.Error(f"Audio too long ({audio_len_sec:.1f}s). Please upload or record up to {MAX_AUDIO_SECONDS} seconds.")

        print("[LOG] Step 3: Extracting features...")
        inputs = processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        )
        # Release intermediates as soon as they are no longer needed to keep
        # peak memory low.
        del audio_array
        gc.collect()
        print(f"[LOG] Memory after preprocessing: {get_mem_usage():.2f} MB")

        print("[LOG] Step 4: Running model inference...")
        with torch.no_grad():
            outputs = model(input_features=inputs.input_features)
        del inputs
        gc.collect()

        print("[LOG] Step 5: Post-processing results...")
        fam_probs = torch.softmax(outputs["fam_logits"], dim=-1)
        super_probs = torch.softmax(outputs["super_logits"], dim=-1)
        code_probs = torch.softmax(outputs["code_logits"], dim=-1)

        df_fam = _build_prediction_table(fam_probs, fam_k, fam_thresh, label_extractor.family_labels)
        df_super = _build_prediction_table(super_probs, super_k, super_thresh, label_extractor.super_labels)
        # Only language codes have human-readable names in CODE_TO_NAME.
        df_code = _build_prediction_table(
            code_probs, code_k, code_thresh, label_extractor.code_labels, name_map=CODE_TO_NAME
        )

        print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
        print(f"--- [LOG] Request Finished ---\n")
        return df_fam, df_super, df_code
    except gr.Error:
        # Already a user-facing message (e.g. the duration limit): show it
        # as written instead of wrapping it in "Processing failed: ...".
        raise
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"Processing failed: {str(e)}") from e
# === UI COMPONENTS ===
def _topk_threshold_sliders(title, default_k):
    """Render a titled group with Top-K and Threshold sliders; return both.

    Must be called inside an active ``gr.Blocks`` context.
    """
    with gr.Group():
        gr.Markdown(f"#### {title}")
        with gr.Row():
            k_slider = gr.Slider(minimum=1, maximum=10, step=1, value=default_k, label="Top-K")
            thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Threshold")
    return k_slider, thresh_slider


def _results_table(label):
    """Render a read-only Prediction/Confidence table with the given label.

    Must be called inside an active ``gr.Blocks`` context.
    """
    return gr.Dataframe(
        headers=["Prediction", "Confidence"],
        datatype=["str", "str"],
        label=label,
        interactive=False,
        wrap=True,
    )


with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.HTML(
        """
        <div style="text-align: center; padding: 30px; background: linear-gradient(120deg, rgb(2, 132, 199) 0%, rgb(16, 185, 129) 60%, rgb(5, 150, 105) 100%); color: white; border-radius: 15px; margin-bottom: 25px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);">
        <h1 style="color: white; margin: 0; font-size: 2.5em;">🦡 Tachiwin Language Identifier 🦡</h1>
        <p style="font-size: 1.2em; opacity: 0.9; margin-top: 10px;">Identify any of the 68 languages of Mexico and their 360 variants</p>
        </div>
        """
    )

    with gr.Row():
        # Left column: audio input and tuning options.
        with gr.Column(scale=1):
            gr.Markdown("### 🎙️ 1. Input Audio")
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                # A file path is passed on because predict_language loads the
                # audio from disk with librosa.
                type="filepath",
                label="Upload or Record"
            )
            with gr.Accordion("⚙️ Advanced Options", open=False):
                fam_k, fam_thresh = _topk_threshold_sliders("Language Family", 1)
                super_k, super_thresh = _topk_threshold_sliders("Superlanguage", 1)
                code_k, code_thresh = _topk_threshold_sliders("Language Code", 3)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                submit_btn = gr.Button("🚀 Classify", variant="primary")

        # Right column: one result table per label level.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 2. Classification Results")
            fam_table = _results_table("🌍 Language Family")
            super_table = _results_table("🗣️ Superlanguage")
            code_table = _results_table("🔤 Language Code")

    submit_btn.click(
        fn=predict_language,
        inputs=[audio_input, fam_k, fam_thresh, super_k, super_thresh, code_k, code_thresh],
        outputs=[fam_table, super_table, code_table]
    )
    # Reset the audio input and all three result tables.
    clear_btn.click(
        fn=lambda: (None, None, None, None),
        inputs=None,
        outputs=[audio_input, fam_table, super_table, code_table]
    )

    gr.Markdown(
        """
        ---
        ### ℹ️ About this Model

        Tachiwin Multilingual Language Classifier is a finetune/fork of an encoder-only Whisper architecture trained to recognize any of the 68 indigenous superlanguages of México and their 360 variants.

        **Accuracy Overview:**

        - **Language Family**: ~73%
        - **Superlanguage**: ~59%
        - **Language Code**: ~52%

        *Developed by Tachiwin. May the indigenous languages never be lost.*
        """
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)