STT-meta-ZH-100m

A dual-head Mandarin Chinese ASR model that simultaneously performs speech-to-text transcription and speaker attribute classification (age, gender, dialect) in a single forward pass.

Built on NVIDIA Citrinet-1024 with language-specific bottleneck adapters and a trailing tag classifier head, fine-tuned on 60 hours of meta-annotated Mandarin speech data using PromptingNemo.

| Metric | Value |
|---|---|
| Parameters | 157.7M |
| WER (with tags) | 19.22% |
| Tag Accuracy | 94.2% |
| Language | Mandarin Chinese (zh) |
| Audio | 16kHz mono |

Architecture

Audio (16kHz) ──▶ Mel Spectrogram (80-dim) ──▶ Citrinet-1024 Encoder (23 blocks)
                                                    │
                                          ┌─────────┴─────────┐
                                          ▼                   ▼
                                    CTC Decoder          Tag Classifier
                                   (5001 vocab)        (3 linear heads)
                                        │                    │
                                        ▼                    ▼
                                  Transcription +      AGE / GENDER /
                                  Entity Tags          DIALECT labels

Parameter Breakdown

| Component | Parameters | Description |
|---|---|---|
| Citrinet-1024 Encoder | 140.4M | 23 Jasper-style blocks with squeeze-excitation |
| Language Adapter | 12.1M | Bottleneck adapters (dim=256) in each encoder block |
| CTC Decoder | 5.1M | Conv1d projecting 1024 → 5001 (BPE vocab + blank) |
| Tag Classifier | 12.3K | 3 linear heads on mean-pooled encoder output |
| Total | 157.7M | |
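
The tag classifier is tiny because it operates on a single mean-pooled 1024-dim encoder vector per utterance: (1024 + 1) × (5 + 3 + 4) = 12,300 weights, matching the 12.3K figure above. Below is a minimal PyTorch sketch of such a head; the class and argument names are illustrative, not the actual PromptingNemo TrailingTagClassifier API.

import torch
import torch.nn as nn

class TrailingTagHeads(nn.Module):
    """Illustrative 3-head classifier over mean-pooled encoder features."""

    def __init__(self, enc_dim=1024, n_age=5, n_gender=3, n_dialect=4):
        super().__init__()
        self.age = nn.Linear(enc_dim, n_age)
        self.gender = nn.Linear(enc_dim, n_gender)
        self.dialect = nn.Linear(enc_dim, n_dialect)

    def forward(self, encoder_out, lengths):
        # encoder_out: [B, 1024, T]; lengths: valid encoder frames per utterance
        t = torch.arange(encoder_out.shape[-1], device=encoder_out.device)
        mask = (t[None, :] < lengths[:, None]).float()  # [B, T]
        pooled = (encoder_out * mask[:, None, :]).sum(-1) / mask.sum(-1, keepdim=True).clamp(min=1)
        return self.age(pooled), self.gender(pooled), self.dialect(pooled)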

Tag Categories

| Category | Classes | Labels |
|---|---|---|
| AGE | 5 | NONE, AGE_14_25, AGE_26_40, AGE_<14, AGE_>41 |
| GENDER | 3 | NONE, GENDER_FEMALE, GENDER_MALE |
| DIALECT | 4 | NONE, DIALECT_NORTH, DIALECT_OTHERS, DIALECT_SOUTH |

The CTC head also outputs inline entity tags (e.g., ENTITY_PERSON_NAME ... END, ENTITY_TEMPERATURE ... END) as part of the transcription vocabulary.
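
Downstream code typically needs to separate the plain transcript from these markers. A minimal sketch, assuming the tag format shown in the example outputs below; the regexes and the parse_transcript helper are illustrative, not part of the repository.

import re

ATTR_RE = re.compile(r"\b(AGE_\S+|GENDER_\S+|DIALECT_\S+)\b")
ENTITY_RE = re.compile(r"ENTITY_([A-Z_]+)\s+(.*?)\s+END")

def parse_transcript(raw: str):
    entities = [(m.group(1), m.group(2)) for m in ENTITY_RE.finditer(raw)]
    attributes = ATTR_RE.findall(raw)
    # Keep entity words but drop markers, then drop attribute tags
    text = ENTITY_RE.sub(lambda m: m.group(2), raw)
    text = ATTR_RE.sub("", text)
    return " ".join(text.split()), attributes, entities

text, attrs, ents = parse_transcript(
    "你好世界。 AGE_26_40 GENDER_MALE ENTITY_PERSON_NAME 张三 END"
)
# text  -> "你好世界。 张三"
# attrs -> ["AGE_26_40", "GENDER_MALE"]
# ents  -> [("PERSON_NAME", "张三")]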

Files

| File | Description |
|---|---|
| zh-citrinet-meta-v11.nemo | Full NeMo checkpoint (encoder + decoder + adapter + tag classifier) |
| onnx/model.onnx | ONNX model with dual outputs: logprobs (CTC) + encoder_output |
| onnx/tag_classifier.onnx | Standalone tag classifier (input: pooled encoder features) |
| onnx/tag_classifier.json | Tag classifier metadata (labels, class counts) |
| onnx/config.json | Preprocessor configuration (mel spectrogram parameters) |
| onnx/tokenizer.model | SentencePiece BPE tokenizer (5000 tokens) |
| onnx/vocabulary.json | Full vocabulary list with token mappings |
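
To fetch the whole repository at once (for example, to run the ONNX pipeline below offline), huggingface_hub's snapshot_download can be used:

from huggingface_hub import snapshot_download

# Downloads the .nemo checkpoint and the onnx/ directory into the local HF cache
local_dir = snapshot_download("WhissleAI/STT-meta-ZH-100m")
print(local_dir)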

Usage

NeMo Inference

import nemo.collections.asr as nemo_asr

# Standard NeMo transcription (CTC head only — tag classifier weights
# are stored in the checkpoint but EncDecCTCModelBPE does not load them
# by default). For full dual-head inference, use ONNX or PromptingNemo.
asr_model = nemo_asr.models.ASRModel.from_pretrained(
    "WhissleAI/STT-meta-ZH-100m"
)

transcriptions = asr_model.transcribe(["audio.wav"])
print(transcriptions[0])
# Output includes inline tags:
# "你好世界。 AGE_26_40 GENDER_MALE ENTITY_PERSON_NAME 张三 END"

PromptingNemo Inference (Full Dual-Head)

For full dual-head inference with the tag classifier, use the PromptingNemo training framework:

# Clone PromptingNemo
# git clone https://github.com/WhissleAI/PromptingNemo.git

import torch
from huggingface_hub import hf_hub_download

# Download the .nemo checkpoint
nemo_path = hf_hub_download(
    repo_id="WhissleAI/STT-meta-ZH-100m",
    filename="zh-citrinet-meta-v11.nemo"
)

# Load with PromptingNemo's custom model class that includes the tag classifier
# See: https://github.com/WhissleAI/PromptingNemo/blob/main/scripts/asr/meta-asr
from scripts.asr.meta_asr.tag_classifier import (
    TrailingTagClassifier,
    build_trailing_tag_maps,
    masked_mean_pool,
)

# The tag_classifier weights are stored inside the .nemo archive.
# PromptingNemo's training script loads them automatically.

ONNX Inference (Production — Recommended)

Self-contained inference using only onnxruntime, numpy, soundfile, and sentencepiece (plus huggingface_hub to download the files):

import json
import numpy as np
import onnxruntime as ort
import soundfile as sf
import sentencepiece as spm
from huggingface_hub import hf_hub_download

# Download model files
repo = "WhissleAI/STT-meta-ZH-100m"
model_path = hf_hub_download(repo, "onnx/model.onnx")
cls_path = hf_hub_download(repo, "onnx/tag_classifier.onnx")
cls_meta_path = hf_hub_download(repo, "onnx/tag_classifier.json")
tok_path = hf_hub_download(repo, "onnx/tokenizer.model")
vocab_path = hf_hub_download(repo, "onnx/vocabulary.json")
config_path = hf_hub_download(repo, "onnx/config.json")

# Load config and vocabulary
with open(config_path) as f:
    config = json.load(f)
with open(vocab_path) as f:
    vocab_data = json.load(f)
with open(cls_meta_path) as f:
    cls_meta = json.load(f)

vocabulary = vocab_data["vocabulary"]
blank_id = vocab_data.get("blank_id", len(vocabulary))

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.Load(tok_path)

# Load ONNX sessions
asr_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
cls_session = ort.InferenceSession(cls_path, providers=["CPUExecutionProvider"])

# --- Preprocessing ---
def preprocess_audio(audio_path, config):
    """Convert audio to log-mel spectrogram features."""
    audio, sr = sf.read(audio_path, dtype="float32")
    if sr != 16000:
        raise ValueError(f"Expected 16kHz audio, got {sr}Hz")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    # Preemphasis
    preemph = config["preprocessor"]["preemph"]
    audio = np.concatenate([[audio[0]], audio[1:] - preemph * audio[:-1]])

    # STFT
    n_fft = config["preprocessor"]["n_fft"]
    hop = config["preprocessor"]["hop_length"]
    win = config["preprocessor"]["win_length"]
    window = np.hanning(win + 1)[:-1].astype(np.float32)

    # Pad audio
    pad_len = (n_fft - hop) // 2
    audio = np.pad(audio, (pad_len, pad_len), mode="reflect")

    frames = []
    for start in range(0, len(audio) - n_fft + 1, hop):
        frame = audio[start : start + n_fft] * np.pad(window, (0, n_fft - win))
        frames.append(np.fft.rfft(frame))
    spec = np.abs(np.array(frames, dtype=np.complex64)) ** 2

    # Mel filterbank
    n_mels = config["preprocessor"]["features"]
    fmin = config["preprocessor"]["lowfreq"]
    fmax = sr / 2 if config["preprocessor"]["highfreq"] is None else config["preprocessor"]["highfreq"]
    mel_points = np.linspace(
        2595 * np.log10(1 + fmin / 700),
        2595 * np.log10(1 + fmax / 700),
        n_mels + 2,
    )
    hz_points = 700 * (10 ** (mel_points / 2595) - 1)
    bins = np.floor((n_fft + 1) * hz_points / sr).astype(int)
    fbank = np.zeros((n_mels, n_fft // 2 + 1))
    for i in range(n_mels):
        for j in range(bins[i], bins[i + 1]):
            fbank[i, j] = (j - bins[i]) / max(bins[i + 1] - bins[i], 1)
        for j in range(bins[i + 1], bins[i + 2]):
            fbank[i, j] = (bins[i + 2] - j) / max(bins[i + 2] - bins[i + 1], 1)

    mel_spec = spec @ fbank.T
    log_mel = np.log(mel_spec + config["preprocessor"]["log_zero_guard_value"])

    # Per-feature normalization
    mean = log_mel.mean(axis=0, keepdims=True)
    std = log_mel.std(axis=0, keepdims=True)
    log_mel = (log_mel - mean) / (std + 1e-5)

    # Pad to multiple of 16
    pad_to = config["preprocessor"].get("pad_to", 16)
    T = log_mel.shape[0]
    if T % pad_to != 0:
        pad_frames = pad_to - (T % pad_to)
        log_mel = np.pad(log_mel, ((0, pad_frames), (0, 0)))

    # Shape: [1, features, time]
    features = log_mel.T[np.newaxis, :, :].astype(np.float32)
    return features, T

# --- Inference ---
features, valid_len = preprocess_audio("audio.wav", config)
length = np.array([features.shape[2]], dtype=np.int64)

# Run ASR model (dual output)
logprobs, encoder_output = asr_session.run(
    ["logprobs", "encoder_output"],
    {"audio_signal": features, "length": length},
)

# Greedy CTC decode
pred_ids = np.argmax(logprobs[0], axis=-1)
# Collapse repeats and remove blanks
decoded_ids = []
prev = -1
for idx in pred_ids:
    if idx != prev and idx != blank_id:
        decoded_ids.append(int(idx))
    prev = idx
transcript = sp.DecodeIds(decoded_ids)
print(f"Transcript: {transcript}")

# --- Tag Classification ---
# encoder_output shape: [1, 1024, T] -> transpose to [1, T, 1024]
enc = encoder_output.transpose(0, 2, 1)
# Masked mean pooling
mask = np.zeros((1, enc.shape[1], 1), dtype=np.float32)
mask[0, :valid_len // 8, :] = 1.0  # Citrinet has 8x downsampling
pooled = (enc * mask).sum(axis=1) / mask.sum(axis=1).clip(min=1)

# Run tag classifier
tag_outputs = cls_session.run(None, {"pooled_encoder": pooled.astype(np.float32)})
categories = cls_meta["categories"]
for idx, (cat_name, cat_info) in enumerate(sorted(categories.items())):
    pred = int(np.argmax(tag_outputs[idx][0]))
    label = cat_info["labels"][pred]
    print(f"  {cat_name}: {label}")

Example output:

Transcript: 来首歌吻别。 AGE_14_25 GENDER_FEMALE
  AGE: AGE_14_25
  DIALECT: DIALECT_SOUTH
  GENDER: GENDER_FEMALE

Training Details

| Setting | Value |
|---|---|
| Base Model | stt_zh_citrinet_1024_gamma_0_25.nemo (NVIDIA) |
| Framework | NeMo + PromptingNemo |
| Training Data | 60,098 samples / 60 hours (AISHELL-3 with meta-tags) |
| Test Data | 24,772 samples / 22.4 hours |
| Optimizer | Adam (lr=5e-4, weight_decay=0, warmup=2000 steps) |
| LR Schedule | CosineAnnealing (min_lr=1e-6) |
| Batch Size | 16 (effective 32 with gradient accumulation of 2) |
| Max Duration | 16s |
| Mixed Precision | FP16 |
| SpecAugment | 4 time masks, width 80 |
| Adapter | Bottleneck (dim=256, activation=swish, norm=pre) |
| Tag Classifier Weight | 0.1 (auxiliary loss) |
| Hardware | 1x NVIDIA T4 16GB |
| Training Steps | 18,000+ (best at step 17,005) |
| Tokenizer | SentencePiece BPE (5,000 tokens) |
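
The last rows imply a combined objective: CTC loss on the tagged transcript plus a 0.1-weighted cross-entropy term per attribute head. A rough sketch of that objective follows; it is illustrative, not the actual PromptingNemo training code.

import torch.nn.functional as F

TAG_LOSS_WEIGHT = 0.1  # auxiliary loss weight from the table above

def total_loss(ctc_loss, age_logits, gender_logits, dialect_logits,
               age_y, gender_y, dialect_y):
    # CTC loss over the tagged transcript plus weighted attribute cross-entropies
    tag_loss = (
        F.cross_entropy(age_logits, age_y)
        + F.cross_entropy(gender_logits, gender_y)
        + F.cross_entropy(dialect_logits, dialect_y)
    )
    return ctc_loss + TAG_LOSS_WEIGHT * tag_loss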

What Makes This Model Different

Unlike standard ASR models, this model:

  1. Outputs structured metadata — AGE, GENDER, and DIALECT predictions via a separate classification head on the encoder output, without affecting CTC alignment
  2. Inline entity recognition — Named entities (PERSON_NAME, TEMPERATURE, DATE, etc.) are tagged directly in the transcript using ENTITY_TYPE ... END markers
  3. Adapter-based fine-tuning — Only the bottleneck adapters (12.1M params) and tag classifier (12.3K params) are trained; the base Citrinet encoder is frozen (see the sketch after this list)
  4. ONNX-ready — Dual-output ONNX graph exposes both CTC logprobs and raw encoder features for the tag classifier
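
A minimal PyTorch-style sketch of the adapter-only fine-tuning from item 3, assuming model is the loaded checkpoint as a torch.nn.Module; the "adapter"/"tag_classifier" name filters are illustrative, and the real selection logic lives in PromptingNemo.

# Freeze everything, then re-enable only adapter and tag-classifier weights
for name, p in model.named_parameters():
    p.requires_grad = ("adapter" in name) or ("tag_classifier" in name)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable / 1e6:.2f}M")  # expect ~12.1M + 12.3K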

Evaluation Results

Transcription Quality (CER)

Evaluated on 500 samples from the AISHELL-3 test set, with all meta-tags stripped for fair character-level comparison:

| Model | Params | CER | Additional Outputs |
|---|---|---|---|
| nvidia/stt_zh_citrinet_1024 | 140M | 3.19% | Transcription only |
| WhissleAI/STT-meta-ZH-100m | 157.7M | 11.31% | + AGE, GENDER, DIALECT, Entities |

The meta model trades roughly 8 percentage points of CER for rich per-utterance metadata. The CTC head must learn to output both transcription tokens and inline entity tags (e.g., ENTITY_PERSON_NAME ... END), which lowers pure transcription accuracy relative to the base model, which performs transcription only.
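
For reference, a tag-stripped character error rate can be computed along these lines; the strip_tags patterns and the example strings are illustrative, and this is not the evaluation script used for the numbers above.

import re

def strip_tags(text: str) -> str:
    # Remove entity markers, END tokens, and speaker-attribute tags (illustrative patterns)
    text = re.sub(r"\b(ENTITY_[A-Z_]+|END|AGE_\S+|GENDER_\S+|DIALECT_\S+)\b", "", text)
    return "".join(text.split())  # drop whitespace for character-level scoring

def cer(ref: str, hyp: str) -> float:
    # Character-level Levenshtein distance divided by reference length
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
    return d[len(hyp)] / max(len(ref), 1)

reference = "来首歌吻别。"                            # ground-truth text (illustrative)
hypothesis = "来首歌吻别。 AGE_14_25 GENDER_FEMALE"    # decoded output with tags
print(f"CER: {cer(strip_tags(reference), strip_tags(hypothesis)):.2%}")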

Tag Classification Accuracy

| Category | Accuracy |
|---|---|
| Overall tags | 94.2% |

Meta-ASR WER (including tags)

| Split | WER (with tags) |
|---|---|
| Test | 19.22% |

Limitations

  • Trained primarily on AISHELL-3 data — may not generalize well to spontaneous/noisy Mandarin speech
  • Limited dialect diversity (North/South/Others) — does not cover specific regional varieties
  • Age classification uses broad buckets (<14, 14-25, 26-40, >41)
  • Entity recognition is limited to entity types seen in training data

Citation

@misc{whissle2025sttmetazh,
  title={STT-meta-ZH-100m: Dual-Head Mandarin ASR with Speaker Attribute Classification},
  author={WhissleAI},
  year={2025},
  url={https://huggingface.co/WhissleAI/STT-meta-ZH-100m}
}

License

Apache 2.0
