# GLM-ASR-Nano / app.py
# Last change by YatharthS: "Update app.py" (commit 671c1b0, verified)
import spaces
import gradio as gr
import argparse
from pathlib import Path
import torch
import torchaudio
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
WhisperFeatureExtractor,
)
# Hugging Face Hub repository providing the config, weights and (by default) tokenizer.
CHECKPOINT_DIR = "zai-org/GLM-ASR-Nano-2512"
# Set to a path or repo id to load the tokenizer from somewhere other than CHECKPOINT_DIR.
TOKENIZER_PATH = None
# Cap on the number of tokens generated per transcription request.
MAX_NEW_TOKENS = 128
# Prefer the GPU whenever CUDA is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Keyword arguments unpacked into WhisperFeatureExtractor at model-load time.
# Derived values: 16 kHz audio with a 160-sample hop gives 100 mel frames per
# second, so a 30 s window is 480000 samples and 3000 frames.
WHISPER_FEAT_CFG = dict(
    chunk_length=30,
    feature_extractor_type="WhisperFeatureExtractor",
    feature_size=128,
    hop_length=160,
    n_fft=400,
    n_samples=480000,
    nb_max_frames=3000,
    padding_side="right",
    padding_value=0.0,
    processor_class="WhisperProcessor",
    return_attention_mask=False,
    sampling_rate=16000,
)
def get_audio_token_length(seconds: float, merge_factor: int = 2) -> int:
    """Return the number of audio placeholder tokens for a clip of ``seconds``.

    The clip length is converted to mel frames at 100 frames per second.
    The frame count is then pushed through the length formula of two 1-D
    convolutions, (padding=1, kernel=3, stride=1) and (padding=1, kernel=3,
    stride=2); presumably these mirror the audio encoder's input convolutions
    (confirm against the model code). Finally, every ``merge_factor`` frames
    are merged into one token.

    Args:
        seconds: Duration of the (unpadded) audio in seconds.
        merge_factor: How many encoder frames are merged into one token.

    Returns:
        The token count, capped at ``1500 // merge_factor``. 1500 is the frame
        count this formula yields for a 30 s window. Very short clips
        (under ~0.03 s) yield 0.
    """
    # (padding, kernel_size, stride) of each convolution, applied in order.
    # A plain literal replaces the former eval() of a constant string.
    conv_layers = ((1, 3, 1), (1, 3, 2))
    dilation = 1

    length = int(seconds * 100)  # mel frames: 100 per second
    for padding, kernel_size, stride in conv_layers:
        # Standard Conv1d output-length formula.
        length = 1 + (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride

    token_count = (length - merge_factor) // merge_factor + 1
    return min(token_count, 1500 // merge_factor)
def build_prompt(
    audio_path: Path,
    tokenizer,
    feature_extractor: WhisperFeatureExtractor,
    merge_factor: int,
    chunk_seconds: int = 30,
) -> dict:
    """Load an audio file and assemble the prompt batch for the model.

    The waveform is reduced to its first channel. If needed, it is resampled
    to the feature extractor's sampling rate, then cut into windows of
    ``chunk_seconds``. Each window contributes:

    * one mel-feature tensor, padded with ``padding="max_length"``;
    * a run of zero placeholder token ids wrapped in ``<|begin_of_audio|>`` /
      ``<|end_of_audio|>``. The run length comes from
      ``get_audio_token_length`` on the window's real duration.

    Args:
        audio_path: File readable by ``torchaudio.load``.
        tokenizer: Tokenizer whose ``encode`` output is appended to the id list
            (presumably a list of ints; verify against the tokenizer in use).
        feature_extractor: Supplies the target sampling rate and computes mels.
        merge_factor: Forwarded to ``get_audio_token_length``.
        chunk_seconds: Window length in seconds.

    Returns:
        A dict with:

        * ``input_ids``: (1, T) long tensor.
        * ``audios``: the per-window mel tensors concatenated along dim 0.
        * ``audio_offsets``: a single-row nested list holding the index of
          each window's first placeholder.
        * ``audio_length``: a single-row nested list holding each window's
          placeholder count.
        * ``attention_mask``: (1, T) tensor of ones.

    Raises:
        gr.Error: If the waveform has no samples, so no window was produced.
    """
    target_rate = feature_extractor.sampling_rate
    waveform, source_rate = torchaudio.load(str(audio_path))
    # Keep only the first channel (multi-channel audio is not downmixed).
    waveform = waveform[:1, :]
    if source_rate != target_rate:
        waveform = torchaudio.transforms.Resample(source_rate, target_rate)(waveform)

    ids = []
    ids.extend(tokenizer.encode("<|user|>"))
    ids.extend(tokenizer.encode("\n"))

    mel_windows, window_offsets, window_lengths = [], [], []
    window = chunk_seconds * target_rate
    for begin in range(0, waveform.shape[1], window):
        piece = waveform[:, begin : begin + window]
        features = feature_extractor(
            piece.numpy(),
            sampling_rate=target_rate,
            return_tensors="pt",
            padding="max_length",
        )
        mel_windows.append(features["input_features"])
        # The token budget uses the window's actual (unpadded) duration.
        placeholder_count = get_audio_token_length(piece.shape[1] / target_rate, merge_factor)
        ids.extend(tokenizer.encode("<|begin_of_audio|>"))
        # Offset points at the first placeholder id of this window.
        window_offsets.append(len(ids))
        ids.extend([0] * placeholder_count)
        ids.extend(tokenizer.encode("<|end_of_audio|>"))
        window_lengths.append(placeholder_count)

    if not mel_windows:
        raise gr.Error("Audio content is empty or failed to load.")

    for text in (
        "<|user|>",
        "\nPlease transcribe this audio into text",
        "<|assistant|>",
        "\n",
    ):
        ids.extend(tokenizer.encode(text))

    return {
        "input_ids": torch.tensor([ids], dtype=torch.long),
        "audios": torch.cat(mel_windows, dim=0),
        "audio_offsets": [window_offsets],
        "audio_length": [window_lengths],
        "attention_mask": torch.ones(1, len(ids), dtype=torch.long),
    }
def prepare_inputs(
    batch: dict,
    device: "torch.device | str",
    audio_dtype: torch.dtype = torch.bfloat16,
) -> tuple[dict, int]:
    """Move a ``build_prompt`` batch to ``device`` and key it for ``generate``.

    Args:
        batch: Dict with ``input_ids``, ``attention_mask`` and ``audios``
            tensors plus the ``audio_offsets`` / ``audio_length`` lists.
        device: Target device. Either a ``torch.device`` or a string such as
            ``"cuda"``; this module passes the string ``DEVICE``.
        audio_dtype: Dtype for the mel features. The default, bfloat16, is the
            dtype the model is loaded with in this file.

    Returns:
        A tuple ``(model_inputs, prompt_len)``:

        * ``model_inputs``: keyword arguments for ``model.generate``, with the
          token ids stored under ``"inputs"``.
        * ``prompt_len``: the number of prompt tokens (``input_ids.size(1)``),
          used to slice the prompt off the generated sequence.
    """
    input_ids = batch["input_ids"].to(device)
    model_inputs = {
        "inputs": input_ids,
        "attention_mask": batch["attention_mask"].to(device),
        # Cast and move in a single call so no full-precision copy of the
        # features is materialised on the device first.
        "audios": batch["audios"].to(device=device, dtype=audio_dtype),
        # Plain Python lists: passed through unchanged.
        "audio_offsets": batch["audio_offsets"],
        "audio_length": batch["audio_length"],
    }
    return model_inputs, input_ids.size(1)
# Model Loading
# Runs once at import time. On any failure the globals fall back to
# placeholders (model is None) so the Gradio UI can still start.
print(f"Loading model from {CHECKPOINT_DIR} to device {DEVICE}...")
try:
    # 1. Load Tokenizer & Feature Extractor
    # TOKENIZER_PATH overrides the tokenizer location; otherwise the checkpoint repo is used.
    tokenizer_source = TOKENIZER_PATH if TOKENIZER_PATH else CHECKPOINT_DIR
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    # Built from the hard-coded WHISPER_FEAT_CFG rather than loaded from the checkpoint.
    feature_extractor = WhisperFeatureExtractor(**WHISPER_FEAT_CFG)
    # trust_remote_code=True runs Python code shipped in the checkpoint repo;
    # only point CHECKPOINT_DIR at repositories you trust.
    config = AutoConfig.from_pretrained(CHECKPOINT_DIR, trust_remote_code=True)
    # 2. Load Model
    model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_DIR,
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(DEVICE)
    model.eval()
    # 3. Get merge_factor for build_prompt
    MERGE_FACTOR = config.merge_factor
except Exception as e:
    print(f"Failed to load model/tokenizer: {e}")
    # Define placeholder variables to allow the Gradio UI to launch for testing/setup
    # (2 matches get_audio_token_length's default merge_factor).
    tokenizer, feature_extractor, model, MERGE_FACTOR = None, None, None, 2
    # The exception is not re-raised. transcribe_wrapper checks `model is None`
    # and raises a gr.Error at request time instead.
@spaces.GPU(duration=60)
def transcribe_wrapper(audio_file_path):
    """Transcribe one uploaded or recorded audio file for the Gradio interface.

    Args:
        audio_file_path: Path string of the temporary audio file Gradio hands
            over (the input component uses ``type="filepath"``), or None when
            nothing was submitted.

    Returns:
        One of:

        * the decoded transcript;
        * ``"[Empty transcription]"`` when decoding yields no text;
        * a prompt string when no audio was supplied;
        * an error description for unexpected failures.

    Raises:
        gr.Error: If the model failed to load at startup, or if a step such as
            ``build_prompt`` raises a user-facing ``gr.Error`` (e.g. empty
            audio).
    """
    if model is None:
        raise gr.Error("Model failed to load. Please check CHECKPOINT_DIR.")
    if audio_file_path is None:
        return "[Please upload an audio file or record one.]"
    try:
        audio_path = Path(audio_file_path)
        # Build the prompt (tokenize text, process audio, and integrate audio tokens)
        batch = build_prompt(
            audio_path,
            tokenizer,
            feature_extractor,
            merge_factor=MERGE_FACTOR,
        )
        # Prepare inputs for the model
        model_inputs, prompt_len = prepare_inputs(batch, DEVICE)
        # Run inference (greedy decoding, no sampling)
        with torch.inference_mode():
            generated = model.generate(
                **model_inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
            )
        # Drop the prompt tokens; only the newly generated ids are decoded.
        transcript_ids = generated[0, prompt_len:].cpu().tolist()
        transcript = tokenizer.decode(transcript_ids, skip_special_tokens=True).strip()
        return transcript or "[Empty transcription]"
    except gr.Error:
        # User-facing errors are meant to reach Gradio as error messages.
        # Previously the generic handler below swallowed them and returned them
        # as ordinary transcript text.
        raise
    except Exception as e:
        print(f"Transcription error: {e}")
        return f"An error occurred during transcription: {e}"
# Gradio page
title = "✨ GLM-ASR-Nano-2512 Transcription Demo"
description = (
    "This demo uses the sota new GLM-ASR Nano model to transcribe audio files "
    "with great accuracy! The architecture is simple and efficient, composed of "
    "a whisper encoder and an llm. Upload an audio file (or record one) to "
    "transcribe it into text using the model."
)

# The input is delivered as a file path (type="filepath"), which is what
# transcribe_wrapper expects.
audio_input = gr.Audio(
    sources=["upload", "microphone"],
    type="filepath",
    label="Audio Input (WAV/MP3)",
)
output_text = gr.Textbox(lines=5, label="Transcription Result")

# Single input component -> transcribe_wrapper -> single text output.
demo = gr.Interface(
    title=title,
    description=description,
    fn=transcribe_wrapper,
    inputs=[audio_input],
    outputs=[output_text],
)

if __name__ == "__main__":
    demo.launch()