# GLM-ASR-Nano-2512 / inference.py
import argparse
from pathlib import Path
import torch
import torchaudio
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperFeatureExtractor,
)

WHISPER_FEAT_CFG = {
    "chunk_length": 30,
    "feature_extractor_type": "WhisperFeatureExtractor",
    "feature_size": 128,
    "hop_length": 160,
    "n_fft": 400,
    "n_samples": 480000,
    "nb_max_frames": 3000,
    "padding_side": "right",
    "padding_value": 0.0,
    "processor_class": "WhisperProcessor",
    "return_attention_mask": False,
    "sampling_rate": 16000,
}
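
# The settings above follow Whisper's 128-mel front end (as in Whisper large-v3):
# 16 kHz audio with a 160-sample hop (10 ms), so a padded 30 s window yields
# 3000 mel frames (nb_max_frames) and 480000 samples (n_samples).
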
def get_audio_token_length(seconds, merge_factor=2):
    """Number of audio placeholder tokens produced for a clip of `seconds` seconds."""

    def get_T_after_cnn(L_in, dilation=1):
        # (padding, kernel_size, stride) of the two conv layers in the Whisper encoder stem.
        for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
            L_out = 1 + L_out // stride
            L_in = L_out
        return L_out

    mel_len = int(seconds * 100)  # 100 mel frames per second (hop_length 160 at 16 kHz)
    audio_len_after_cnn = get_T_after_cnn(mel_len)
    audio_token_num = (audio_len_after_cnn - merge_factor) // merge_factor + 1
    # TODO: the current Whisper encoder can't process longer sequences; consider chunking further in the future.
    audio_token_num = min(audio_token_num, 1500 // merge_factor)
    return audio_token_num
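
# Worked example (values follow from the function above, assuming merge_factor=2):
# a full 30 s chunk gives mel_len = 3000 frames; the two convs (stride 1, then 2)
# reduce that to 1500, and merging by 2 leaves (1500 - 2) // 2 + 1 = 750 placeholder
# tokens, which also equals the 1500 // merge_factor cap.
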
def build_prompt(
    audio_path: Path,
    tokenizer,
    feature_extractor: WhisperFeatureExtractor,
    merge_factor: int,
    chunk_seconds: int = 30,
) -> dict:
    audio_path = Path(audio_path)
    wav, sr = torchaudio.load(str(audio_path))
    wav = wav[:1, :]  # keep only the first channel (mono)
    if sr != feature_extractor.sampling_rate:
        wav = torchaudio.transforms.Resample(sr, feature_extractor.sampling_rate)(wav)

    tokens = []
    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\n")

    audios = []
    audio_offsets = []
    audio_length = []
    chunk_size = chunk_seconds * feature_extractor.sampling_rate
    # Split the waveform into fixed-length chunks and extract log-mel features for each.
    for start in range(0, wav.shape[1], chunk_size):
        chunk = wav[:, start : start + chunk_size]
        mel = feature_extractor(
            chunk.numpy(),
            sampling_rate=feature_extractor.sampling_rate,
            return_tensors="pt",
            padding="max_length",
        )["input_features"]
        audios.append(mel)

        seconds = chunk.shape[1] / feature_extractor.sampling_rate
        num_tokens = get_audio_token_length(seconds, merge_factor)
        tokens += tokenizer.encode("<|begin_of_audio|>")
        audio_offsets.append(len(tokens))
        tokens += [0] * num_tokens  # placeholder ids to be filled with audio embeddings
        tokens += tokenizer.encode("<|end_of_audio|>")
        audio_length.append(num_tokens)

    if not audios:
        raise ValueError("Audio content is empty or failed to load.")

    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\nPlease transcribe this audio into text")
    tokens += tokenizer.encode("<|assistant|>")
    tokens += tokenizer.encode("\n")

    batch = {
        "input_ids": torch.tensor([tokens], dtype=torch.long),
        "audios": torch.cat(audios, dim=0),
        "audio_offsets": [audio_offsets],
        "audio_length": [audio_length],
        "attention_mask": torch.ones(1, len(tokens), dtype=torch.long),
    }
    return batch
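
# The assembled prompt has roughly this shape per 30 s chunk (a sketch, not literal ids):
#   <|user|>\n <|begin_of_audio|> [num_tokens zero placeholders] <|end_of_audio|> ...
#   <|user|>\nPlease transcribe this audio into text <|assistant|>\n
# `audio_offsets` / `audio_length` record where the placeholders sit; presumably the
# model's custom code splices the Whisper-encoder embeddings in at those positions.
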
def prepare_inputs(batch: dict, device: torch.device) -> tuple[dict, int]:
    tokens = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    audios = batch["audios"].to(device)
    model_inputs = {
        "inputs": tokens,
        "attention_mask": attention_mask,
        "audios": audios.to(torch.bfloat16),  # match the model's bfloat16 weights
        "audio_offsets": batch["audio_offsets"],
        "audio_length": batch["audio_length"],
    }
    return model_inputs, tokens.size(1)
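
# Note: `audios`, `audio_offsets`, and `audio_length` are non-standard generate() kwargs;
# they are presumably consumed by the model's trust_remote_code implementation rather
# than by the generic transformers generation loop.
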
def transcribe(
    checkpoint_dir: Path,
    audio_path: Path,
    tokenizer_path: str | None,
    max_new_tokens: int,
    device: str,
):
    tokenizer_source = tokenizer_path if tokenizer_path else checkpoint_dir
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    feature_extractor = WhisperFeatureExtractor(**WHISPER_FEAT_CFG)

    config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)
    model.eval()

    batch = build_prompt(
        audio_path,
        tokenizer,
        feature_extractor,
        merge_factor=config.merge_factor,
    )
    model_inputs, prompt_len = prepare_inputs(batch, device)

    with torch.inference_mode():
        generated = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    transcript_ids = generated[0, prompt_len:].cpu().tolist()
    transcript = tokenizer.decode(transcript_ids, skip_special_tokens=True).strip()
    print("----------")
    print(transcript or "[Empty transcription]")

def main():
    parser = argparse.ArgumentParser(description="Minimal ASR transcription demo.")
    parser.add_argument("--checkpoint_dir", type=str, default=str(Path(__file__).parent))
    parser.add_argument("--audio", type=str, required=True, help="Path to audio file.")
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="Tokenizer directory (defaults to checkpoint dir when omitted).",
    )
    parser.add_argument("--max_new_tokens", type=int, default=128)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    transcribe(
        checkpoint_dir=Path(args.checkpoint_dir),
        audio_path=Path(args.audio),
        tokenizer_path=args.tokenizer_path,
        max_new_tokens=args.max_new_tokens,
        device=args.device,
    )


if __name__ == "__main__":
    main()
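
# Example invocation (the audio path and token budget are illustrative):
#   python inference.py --audio /path/to/sample.wav --max_new_tokens 256
# --checkpoint_dir defaults to this script's directory and --device to CUDA when available.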