Spaces: Running on Zero
import spaces
import gradio as gr
import argparse
from pathlib import Path

import torch
import torchaudio
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperFeatureExtractor,
)
CHECKPOINT_DIR = "zai-org/GLM-ASR-Nano-2512"
TOKENIZER_PATH = None
MAX_NEW_TOKENS = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

WHISPER_FEAT_CFG = {
    "chunk_length": 30,
    "feature_extractor_type": "WhisperFeatureExtractor",
    "feature_size": 128,
    "hop_length": 160,
    "n_fft": 400,
    "n_samples": 480000,
    "nb_max_frames": 3000,
    "padding_side": "right",
    "padding_value": 0.0,
    "processor_class": "WhisperProcessor",
    "return_attention_mask": False,
    "sampling_rate": 16000,
}
def get_audio_token_length(seconds, merge_factor=2):
    """Number of audio placeholder tokens produced for `seconds` of audio."""

    def get_T_after_cnn(L_in, dilation=1):
        # Two downsampling convolutions in the audio encoder: kernel 3 / stride 1, then kernel 3 / stride 2.
        for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
            L_out = 1 + L_out // stride
            L_in = L_out
        return L_out

    mel_len = int(seconds * 100)  # 100 mel frames per second (hop_length=160 at 16 kHz)
    audio_len_after_cnn = get_T_after_cnn(mel_len)
    audio_token_num = (audio_len_after_cnn - merge_factor) // merge_factor + 1
    audio_token_num = min(audio_token_num, 1500 // merge_factor)
    return audio_token_num
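# Worked example (illustrative, assuming merge_factor=2): a full 30 s chunk gives
# mel_len = 3000 frames; the stride-1 then stride-2 convolutions reduce this to 1500 frames,
# and (1500 - 2) // 2 + 1 = 750 placeholder tokens, which matches the 1500 // 2 cap exactly.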
def build_prompt(
    audio_path: Path,
    tokenizer,
    feature_extractor: WhisperFeatureExtractor,
    merge_factor: int,
    chunk_seconds: int = 30,
) -> dict:
    wav, sr = torchaudio.load(str(audio_path))
    wav = wav[:1, :]  # keep only the first channel (mono)
    if sr != feature_extractor.sampling_rate:
        wav = torchaudio.transforms.Resample(sr, feature_extractor.sampling_rate)(wav)

    tokens = []
    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\n")

    audios = []
    audio_offsets = []
    audio_length = []
    chunk_size = chunk_seconds * feature_extractor.sampling_rate

    # Split the waveform into 30-second chunks and insert one audio segment per chunk.
    for start in range(0, wav.shape[1], chunk_size):
        chunk = wav[:, start : start + chunk_size]
        mel = feature_extractor(
            chunk.numpy(),
            sampling_rate=feature_extractor.sampling_rate,
            return_tensors="pt",
            padding="max_length",
        )["input_features"]
        audios.append(mel)

        seconds = chunk.shape[1] / feature_extractor.sampling_rate
        num_tokens = get_audio_token_length(seconds, merge_factor)
        tokens += tokenizer.encode("<|begin_of_audio|>")
        audio_offsets.append(len(tokens))
        tokens += [0] * num_tokens  # placeholder positions for the audio features
        tokens += tokenizer.encode("<|end_of_audio|>")
        audio_length.append(num_tokens)

    if not audios:
        raise gr.Error("Audio content is empty or failed to load.")

    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\nPlease transcribe this audio into text")
    tokens += tokenizer.encode("<|assistant|>")
    tokens += tokenizer.encode("\n")

    batch = {
        "input_ids": torch.tensor([tokens], dtype=torch.long),
        "audios": torch.cat(audios, dim=0),
        "audio_offsets": [audio_offsets],
        "audio_length": [audio_length],
        "attention_mask": torch.ones(1, len(tokens), dtype=torch.long),
    }
    return batch
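# For reference, the prompt assembled above has the following shape, with one audio segment
# per 30 s chunk:
#   <|user|>\n <|begin_of_audio|> [N placeholder ids] <|end_of_audio|> ...
#   <|user|>\nPlease transcribe this audio into text <|assistant|>\n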
def prepare_inputs(batch: dict, device: torch.device) -> tuple[dict, int]:
    tokens = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    audios = batch["audios"].to(device)
    model_inputs = {
        "inputs": tokens,
        "attention_mask": attention_mask,
        "audios": audios.to(torch.bfloat16),
        "audio_offsets": batch["audio_offsets"],
        "audio_length": batch["audio_length"],
    }
    return model_inputs, tokens.size(1)
# Model Loading
print(f"Loading model from {CHECKPOINT_DIR} to device {DEVICE}...")
try:
    # 1. Load Tokenizer & Feature Extractor
    tokenizer_source = TOKENIZER_PATH if TOKENIZER_PATH else CHECKPOINT_DIR
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    feature_extractor = WhisperFeatureExtractor(**WHISPER_FEAT_CFG)
    config = AutoConfig.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True)

    # 2. Load Model
    model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_DIR,
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(DEVICE)
    model.eval()

    # 3. Get merge_factor for build_prompt
    MERGE_FACTOR = config.merge_factor
except Exception as e:
    print(f"Failed to load model/tokenizer: {e}")
    # Define placeholder variables to allow the Gradio UI to launch for testing/setup.
    tokenizer, feature_extractor, model, MERGE_FACTOR = None, None, None, 2
    # The failure is reported to the user during the transcription step below.
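# The Space header indicates a ZeroGPU runtime and `spaces` is imported above; on ZeroGPU,
# functions that need the GPU are normally wrapped with the `spaces.GPU` decorator so a GPU
# is attached for the duration of the call. Adding it here is an assumption about the intended
# deployment, not part of the original script; drop the decorator when running locally
# without the `spaces` package.
@spaces.GPU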
def transcribe_wrapper(audio_file_path):
    """
    Wraps the core transcription logic for Gradio.
    Gradio provides the audio as a temporary file path.
    """
    if model is None:
        raise gr.Error("Model failed to load. Please check CHECKPOINT_DIR.")
    if audio_file_path is None:
        return "[Please upload an audio file or record one.]"

    try:
        audio_path = Path(audio_file_path)

        # Build the prompt (tokenize text, process audio, and integrate audio tokens)
        batch = build_prompt(
            audio_path,
            tokenizer,
            feature_extractor,
            merge_factor=MERGE_FACTOR,
        )

        # Prepare inputs for the model
        model_inputs, prompt_len = prepare_inputs(batch, DEVICE)

        # Run inference (text generation)
        with torch.inference_mode():
            generated = model.generate(
                **model_inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
            )

        # Decode and return only the newly generated tokens (everything after the prompt)
        transcript_ids = generated[0, prompt_len:].cpu().tolist()
        transcript = tokenizer.decode(transcript_ids, skip_special_tokens=True).strip()
        return transcript or "[Empty transcription]"
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"An error occurred during transcription: {e}"
# Gradio page
title = "✨ GLM-ASR-Nano-2512 Transcription Demo"
description = (
    "This demo uses the new GLM-ASR-Nano-2512 model to transcribe audio files with high accuracy. "
    "The architecture is simple and efficient, composed of a Whisper encoder and an LLM. "
    "Upload an audio file (or record one) to transcribe it into text."
)
# Define the Gradio Interface components
audio_input = gr.Audio(
    type="filepath",
    label="Audio Input (WAV/MP3)",
    sources=["upload", "microphone"],
)
output_text = gr.Textbox(label="Transcription Result", lines=5)

# Create the Interface
demo = gr.Interface(
    fn=transcribe_wrapper,
    inputs=[audio_input],
    outputs=[output_text],
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()