import spaces
import gradio as gr
from pathlib import Path

import torch
import torchaudio
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperFeatureExtractor,
)

CHECKPOINT_DIR = "zai-org/GLM-ASR-Nano-2512"
TOKENIZER_PATH = None
MAX_NEW_TOKENS = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

WHISPER_FEAT_CFG = {
    "chunk_length": 30,
    "feature_extractor_type": "WhisperFeatureExtractor",
    "feature_size": 128,
    "hop_length": 160,
    "n_fft": 400,
    "n_samples": 480000,
    "nb_max_frames": 3000,
    "padding_side": "right",
    "padding_value": 0.0,
    "processor_class": "WhisperProcessor",
    "return_attention_mask": False,
    "sampling_rate": 16000,
}


def get_audio_token_length(seconds, merge_factor=2):
    """Number of audio placeholder tokens produced for `seconds` of audio."""

    def get_T_after_cnn(L_in, dilation=1):
        # Whisper encoder conv stack: (padding, kernel_size, stride) for conv1 and conv2.
        for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
            L_out = 1 + L_out // stride
            L_in = L_out
        return L_out

    mel_len = int(seconds * 100)  # 100 mel frames per second (hop_length=160 at 16 kHz)
    audio_len_after_cnn = get_T_after_cnn(mel_len)
    audio_token_num = (audio_len_after_cnn - merge_factor) // merge_factor + 1
    # Cap at the encoder's 30 s context (1500 frames after the conv stack).
    audio_token_num = min(audio_token_num, 1500 // merge_factor)
    return audio_token_num
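# Worked example (for reference): a full 30 s chunk gives mel_len = 3000 frames;
# the two conv layers (stride 1, then stride 2) reduce this to 1500; with
# merge_factor=2 the prompt then reserves (1500 - 2) // 2 + 1 = 750 placeholder tokens.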
def build_prompt(
    audio_path: Path,
    tokenizer,
    feature_extractor: WhisperFeatureExtractor,
    merge_factor: int,
    chunk_seconds: int = 30,
) -> dict:
    """Build the chat-style prompt: text tokens plus placeholder slots for audio features."""
    wav, sr = torchaudio.load(str(audio_path))
    wav = wav[:1, :]  # keep the first channel only (mono)
    if sr != feature_extractor.sampling_rate:
        wav = torchaudio.transforms.Resample(sr, feature_extractor.sampling_rate)(wav)

    tokens = []
    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\n")

    audios = []
    audio_offsets = []
    audio_length = []
    chunk_size = chunk_seconds * feature_extractor.sampling_rate
    for start in range(0, wav.shape[1], chunk_size):
        chunk = wav[:, start : start + chunk_size]
        mel = feature_extractor(
            chunk.numpy(),
            sampling_rate=feature_extractor.sampling_rate,
            return_tensors="pt",
            padding="max_length",
        )["input_features"]
        audios.append(mel)

        seconds = chunk.shape[1] / feature_extractor.sampling_rate
        num_tokens = get_audio_token_length(seconds, merge_factor)
        tokens += tokenizer.encode("<|begin_of_audio|>")
        audio_offsets.append(len(tokens))
        tokens += [0] * num_tokens  # placeholder ids, replaced by audio embeddings inside the model
        tokens += tokenizer.encode("<|end_of_audio|>")
        audio_length.append(num_tokens)

    if not audios:
        raise gr.Error("Audio content is empty or failed to load.")

    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\nPlease transcribe this audio into text")
    tokens += tokenizer.encode("<|assistant|>")
    tokens += tokenizer.encode("\n")

    batch = {
        "input_ids": torch.tensor([tokens], dtype=torch.long),
        "audios": torch.cat(audios, dim=0),
        "audio_offsets": [audio_offsets],
        "audio_length": [audio_length],
        "attention_mask": torch.ones(1, len(tokens), dtype=torch.long),
    }
    return batch


def prepare_inputs(batch: dict, device: torch.device) -> tuple[dict, int]:
    """Move the batch to the target device and cast mel features to bfloat16."""
    tokens = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    audios = batch["audios"].to(device)
    model_inputs = {
        "inputs": tokens,
        "attention_mask": attention_mask,
        "audios": audios.to(torch.bfloat16),
        "audio_offsets": batch["audio_offsets"],
        "audio_length": batch["audio_length"],
    }
    return model_inputs, tokens.size(1)


# Model Loading
print(f"Loading model from {CHECKPOINT_DIR} to device {DEVICE}...")
try:
    # 1. Load Tokenizer & Feature Extractor
    tokenizer_source = TOKENIZER_PATH if TOKENIZER_PATH else CHECKPOINT_DIR
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    feature_extractor = WhisperFeatureExtractor(**WHISPER_FEAT_CFG)
    config = AutoConfig.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True)

    # 2. Load Model
    model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_DIR,
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(DEVICE)
    model.eval()

    # 3. Get merge_factor for build_prompt
    MERGE_FACTOR = config.merge_factor
except Exception as e:
    print(f"Failed to load model/tokenizer: {e}")
    # Define placeholder variables so the Gradio UI can still launch for testing/setup;
    # the failure is surfaced to the user in transcribe_wrapper() below.
    tokenizer, feature_extractor, model, MERGE_FACTOR = None, None, None, 2


@spaces.GPU(duration=60)
def transcribe_wrapper(audio_file_path):
    """
    Wraps the core transcription logic for Gradio.
    Gradio provides the audio as a temporary file path.
    """
    if model is None:
        raise gr.Error("Model failed to load. Please check CHECKPOINT_DIR.")
    if audio_file_path is None:
        return "[Please upload an audio file or record one.]"

    try:
        audio_path = Path(audio_file_path)

        # Build the prompt (tokenize text, process audio, and interleave audio placeholder tokens)
        batch = build_prompt(
            audio_path,
            tokenizer,
            feature_extractor,
            merge_factor=MERGE_FACTOR,
        )

        # Prepare inputs for the model
        model_inputs, prompt_len = prepare_inputs(batch, DEVICE)

        # Run inference (greedy text generation)
        with torch.inference_mode():
            generated = model.generate(
                **model_inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
            )

        # Decode only the newly generated tokens and return the result
        transcript_ids = generated[0, prompt_len:].cpu().tolist()
        transcript = tokenizer.decode(transcript_ids, skip_special_tokens=True).strip()
        return transcript or "[Empty transcription]"

    except Exception as e:
        print(f"Transcription error: {e}")
        return f"An error occurred during transcription: {e}"


# Gradio page
title = "✨ GLM-ASR-Nano-2512 Transcription Demo"
description = (
    "This demo uses the GLM-ASR-Nano-2512 model to transcribe audio files. "
    "The architecture is simple and efficient: a Whisper encoder paired with an LLM decoder. "
    "Upload an audio file (or record one) to transcribe it into text."
)

# Define the Gradio Interface components
audio_input = gr.Audio(
    type="filepath",
    label="Audio Input (WAV/MP3)",
    sources=["upload", "microphone"],
)
output_text = gr.Textbox(label="Transcription Result", lines=5)

# Create the Interface
demo = gr.Interface(
    fn=transcribe_wrapper,
    inputs=[audio_input],
    outputs=[output_text],
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()
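# When running this script locally (outside Hugging Face Spaces), the standard Gradio
# launch options apply, e.g. demo.launch(share=True) for a temporary public link or
# demo.launch(server_name="0.0.0.0", server_port=7860) to bind a specific address.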