"""Gradio demo: identify indigenous languages of Mexico from audio.

A Whisper-encoder-only classifier with three linear heads predicts, for a
short audio clip, the language *family*, the *superlanguage*, and the
specific language *code* (variant). Inference runs on CPU.
"""

import gc
import json
import os

import gradio as gr
import librosa
import numpy as np
import pandas as pd
import psutil
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    WhisperConfig,
    WhisperPreTrainedModel,
    WhisperProcessor,
)
from transformers.models.whisper.modeling_whisper import WhisperEncoder

# --- CONFIGURATION ---
MAX_AUDIO_SECONDS = 30
torch.set_num_threads(1)  # keep CPU load predictable on a small shared host


# === CUSTOM MODEL CLASSES ===
class WhisperEncoderOnlyConfig(WhisperConfig):
    """WhisperConfig extended with the sizes of the three classifier heads."""

    model_type = "whisper_encoder_classifier"

    def __init__(self, n_fam=None, n_super=None, n_code=None, **kwargs):
        super().__init__(**kwargs)
        self.n_fam = n_fam      # number of language families
        self.n_super = n_super  # number of superlanguages
        self.n_code = n_code    # number of language codes (variants)


class WhisperEncoderOnlyForClassification(WhisperPreTrainedModel):
    """Whisper encoder + mean pooling + three linear classification heads."""

    config_class = WhisperEncoderOnlyConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = WhisperEncoder(config)
        hidden = config.d_model
        self.fam_head = nn.Linear(hidden, config.n_fam)
        self.super_head = nn.Linear(hidden, config.n_super)
        self.code_head = nn.Linear(hidden, config.n_code)
        self.post_init()

    # This model consumes log-mel features directly, so there are no token
    # embeddings; the overrides below keep HF utilities from touching them.
    def get_input_embeddings(self):
        return None

    def set_input_embeddings(self, value):
        pass

    def enable_input_require_grads(self):
        return

    def forward(self, input_features, labels=None):
        """Encode, mean-pool over the time axis, and apply the three heads.

        Args:
            input_features: log-mel feature batch for the Whisper encoder.
            labels: optional ``(fam_labels, super_labels, code_labels)``
                tuple; when given, the summed cross-entropy loss over the
                three heads is also returned.

        Returns:
            dict with ``loss`` (or ``None``) and ``fam_logits``,
            ``super_logits``, ``code_logits``.
        """
        enc_out = self.encoder(input_features=input_features)
        pooled = enc_out.last_hidden_state.mean(dim=1)  # mean over time

        fam_logits = self.fam_head(pooled)
        super_logits = self.super_head(pooled)
        code_logits = self.code_head(pooled)

        loss = None
        if labels is not None:
            fam_labels, super_labels, code_labels = labels
            loss_fn = nn.CrossEntropyLoss()
            loss = (
                loss_fn(fam_logits, fam_labels)
                + loss_fn(super_logits, super_labels)
                + loss_fn(code_logits, code_labels)
            )

        return {
            "loss": loss,
            "fam_logits": fam_logits,
            "super_logits": super_logits,
            "code_logits": code_logits,
        }


class LabelExtractor:
    """Extracts family/super/code labels from tokenized sequences based on
    training design.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.family_tokens = []
        self.super_tokens = []
        self.code_tokens = []

        # Extract special tokens that represent categories from added_vocab
        for token_str, token_id in tokenizer.get_added_vocab().items():
            if token_str.startswith("<|") and token_str.endswith("|>"):
                # Skip Whisper's own control tokens.
                if token_str in ["<|startoftranscript|>", "<|endoftext|>",
                                 "<|nospeech|>", "<|notimestamps|>"]:
                    continue
                if token_str.startswith("<|@"):
                    self.family_tokens.append((token_str, token_id))
                elif self._is_super_token(token_str):
                    self.super_tokens.append((token_str, token_id))
                else:
                    self.code_tokens.append((token_str, token_id))

        # Sort by token_id to match model indices
        self.family_tokens.sort(key=lambda x: x[1])
        self.super_tokens.sort(key=lambda x: x[1])
        self.code_tokens.sort(key=lambda x: x[1])

        # We only need the flat lists of token names for inference mapping
        self.family_labels = [tok for tok, _ in self.family_tokens]
        self.super_labels = [tok for tok, _ in self.super_tokens]
        self.code_labels = [tok for tok, _ in self.code_tokens]

        # Fix: these were f-strings; only the counts need interpolation.
        print("Extracted labels:")
        print(f"  Families: {len(self.family_labels)}")
        print(f"  Superlanguages: {len(self.super_labels)}")
        print(f"  Codes: {len(self.code_labels)}")

    def _is_super_token(self, token_str):
        # Based on training heuristic: superlanguage tokens have an uppercase
        # letter right after "<|" and are not family tokens ("<|@...").
        return (len(token_str) > 2 and token_str[2].isupper()
                and not token_str.startswith("<|@"))


# === REGISTER MODEL ===
AutoConfig.register("whisper_encoder_classifier", WhisperEncoderOnlyConfig)
AutoModel.register(WhisperEncoderOnlyConfig, WhisperEncoderOnlyForClassification)

# === LOAD MODEL ===
MODEL_REPO = "tachiwin/language_classification_enconly_model_2"

print("Loading model on CPU...")
processor = WhisperProcessor.from_pretrained(MODEL_REPO)
model = WhisperEncoderOnlyForClassification.from_pretrained(
    MODEL_REPO,
    low_cpu_mem_usage=True
)
model.eval()

# Initialize LabelExtractor to build text mappings
label_extractor = LabelExtractor(processor.tokenizer)

# Load languages mapping (code -> human-readable INALI name); best effort —
# the app still works without it, showing raw codes instead of names.
print("Loading language mappings...")
try:
    with open("languages.json", "r", encoding="utf-8") as f:
        languages_data = json.load(f)
    CODE_TO_NAME = {
        item.get("code"): item.get("inali_name")
        for item in languages_data
        if item.get("code") and item.get("inali_name")
    }
except Exception as e:
    print(f"Warning: Could not load languages.json: {e}")
    CODE_TO_NAME = {}

print("Model loaded successfully!")


def get_mem_usage():
    """Return this process's resident set size in MiB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 ** 2)


# === INFERENCE FUNCTION ===
def predict_language(audio_path, fam_k=1, fam_thresh=0.0, super_k=1,
                     super_thresh=0.0, code_k=3, code_thresh=0.0):
    """Classify the audio at ``audio_path`` and return three result tables.

    Args:
        audio_path: path to the uploaded or recorded audio file.
        fam_k, super_k, code_k: top-K predictions to show per table.
        fam_thresh, super_thresh, code_thresh: minimum confidence to include.

    Returns:
        ``(family_df, superlanguage_df, code_df)`` pandas DataFrames with
        "Prediction" and "Confidence" columns.

    Raises:
        gr.Error: when no audio is given, the audio exceeds
            ``MAX_AUDIO_SECONDS``, or any processing step fails.
    """
    if not audio_path:
        raise gr.Error("No audio provided! Please upload or record an audio file.")

    gc.collect()
    start_mem = get_mem_usage()
    print("\n--- [LOG] New Request ---")
    print(f"[LOG] Start Memory: {start_mem:.2f} MB")

    try:
        print("[LOG] Step 1: Loading and resampling audio from file...")
        # Whisper expects 16 kHz mono; librosa resamples on load.
        audio_array, sample_rate = librosa.load(audio_path, sr=16000)
        audio_len_sec = len(audio_array) / 16000
        print(f"[LOG] Audio duration: {audio_len_sec:.2f}s, SR: 16000")
        print(f"[LOG] Memory after load: {get_mem_usage():.2f} MB")

        if audio_len_sec > MAX_AUDIO_SECONDS:
            # Release the buffer before bailing out to keep peak RSS down.
            del audio_array
            gc.collect()
            raise gr.Error(
                f"Audio too long ({audio_len_sec:.1f}s). "
                f"Please upload or record up to {MAX_AUDIO_SECONDS} seconds."
            )

        print("[LOG] Step 3: Extracting features...")
        inputs = processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        )
        del audio_array
        gc.collect()
        print(f"[LOG] Memory after preprocessing: {get_mem_usage():.2f} MB")

        print("[LOG] Step 4: Running model inference...")
        with torch.no_grad():
            outputs = model(input_features=inputs.input_features)
        del inputs
        gc.collect()

        print("[LOG] Step 5: Post-processing results...")
        fam_probs = torch.softmax(outputs["fam_logits"], dim=-1)
        super_probs = torch.softmax(outputs["super_logits"], dim=-1)
        code_probs = torch.softmax(outputs["code_logits"], dim=-1)

        def build_df(probs_tensor, k, thresh, labels_list, apply_mapping=False):
            # Build a two-column DataFrame of the top-k predictions whose
            # confidence is at least `thresh`.
            k = int(k)
            top_vals, top_idx = torch.topk(probs_tensor[0], min(k, probs_tensor.shape[-1]))
            table_data = []
            for i in range(len(top_vals)):
                score = top_vals[i].item()
                if score < thresh:
                    continue
                idx = top_idx[i].item()
                # Strip the "<|...|>" decoration from the special token.
                raw_label = labels_list[idx].strip("<|>") if idx < len(labels_list) else f"Unknown ({idx})"
                if apply_mapping:
                    # Map the language code to a human-readable name when known.
                    name = f"{CODE_TO_NAME[raw_label]} ({raw_label})" if raw_label in CODE_TO_NAME else raw_label
                else:
                    name = raw_label
                table_data.append([name, f"{score:.2%}"])
            if not table_data:
                return pd.DataFrame(columns=["Prediction", "Confidence"])
            return pd.DataFrame(table_data, columns=["Prediction", "Confidence"])

        df_fam = build_df(fam_probs, fam_k, fam_thresh, label_extractor.family_labels)
        df_super = build_df(super_probs, super_k, super_thresh, label_extractor.super_labels)
        df_code = build_df(code_probs, code_k, code_thresh, label_extractor.code_labels, apply_mapping=True)

        print(f"[LOG] Final Memory: {get_mem_usage():.2f} MB")
        print("--- [LOG] Request Finished ---\n")
        return df_fam, df_super, df_code

    except gr.Error:
        # Fix: user-facing errors raised above (e.g. "Audio too long") were
        # previously swallowed by the generic handler and re-wrapped as
        # "Processing failed: ..."; re-raise them unchanged instead.
        raise
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"Processing failed: {str(e)}")


# === UI COMPONENTS ===
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # NOTE(review): the original header's HTML markup appears to have been
    # lost in extraction; the visible text content is preserved as-is.
    gr.HTML(
        """
🦑 Tachiwin Language Identifier 🦑

Identify any of the 68 languages of Mexico and their 360 variants
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎙️ 1. Input Audio")
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",  # Changed from numpy to filepath
                label="Upload or Record"
            )
            with gr.Accordion("⚙️ Advanced Options", open=False):
                with gr.Group():
                    gr.Markdown("#### Language Family")
                    with gr.Row():
                        fam_k = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Top-K")
                        fam_thresh = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Threshold")
                with gr.Group():
                    gr.Markdown("#### Superlanguage")
                    with gr.Row():
                        super_k = gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Top-K")
                        super_thresh = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Threshold")
                with gr.Group():
                    gr.Markdown("#### Language Code")
                    with gr.Row():
                        code_k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Top-K")
                        code_thresh = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.0, label="Threshold")
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                submit_btn = gr.Button("🚀 Classify", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("### 📊 2. Classification Results")
            fam_table = gr.Dataframe(headers=["Prediction", "Confidence"], datatype=["str", "str"],
                                     label="🌍 Language Family", interactive=False, wrap=True)
            super_table = gr.Dataframe(headers=["Prediction", "Confidence"], datatype=["str", "str"],
                                       label="🗣️ Superlanguage", interactive=False, wrap=True)
            code_table = gr.Dataframe(headers=["Prediction", "Confidence"], datatype=["str", "str"],
                                      label="🔤 Language Code", interactive=False, wrap=True)

    submit_btn.click(
        fn=predict_language,
        inputs=[audio_input, fam_k, fam_thresh, super_k, super_thresh, code_k, code_thresh],
        outputs=[fam_table, super_table, code_table]
    )
    # Clearing resets the audio input and all three result tables.
    clear_btn.click(
        fn=lambda: (None, None, None, None),
        inputs=None,
        outputs=[audio_input, fam_table, super_table, code_table]
    )

    gr.Markdown(
        """
---
### ℹ️ About this Model
Tachiwin Multilingual Language Classifier is a finetune/fork of an encoder-only Whisper architecture trained to recognize any of the 68 indigenous superlanguages of México and their 360 variants.

**Accuracy Overview:**
- **Language Family**: ~73%
- **Superlanguage**: ~59%
- **Language Code**: ~52%

*Developed by Tachiwin. May the indigenous languages never be lost.*
"""
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False)