Streaming-STT-1.5B

Continued pretraining of Qwen/Qwen2.5-1.5B on malaysia-ai/Malaysian-STT, about 3B audio tokens (equivalent to roughly 72k hours of audio), with native support for:

  1. Streaming mode, using the <|streaming|> prefix.
  2. Semantic VAD, by predicting the probability of the <|endofspeech|> token in streaming mode.
  3. Whole mode, using the <|whole|> prefix.
  4. Segment-level timestamps, using the <|segment|> prefix.
  5. Word-level timestamps, using the <|word|> prefix.
  6. Prediction on audio longer than 30 seconds.
  7. Plug and play with any continuous batching serving framework such as vLLM, since it is just another Qwen2.5 model.
  8. The GLM4 Speech Tokenizer at 12.5 tokens per second (TPS). Discrete speech tokens work very well with prefix caching, especially for streaming.
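The token and hour figures above are linked by the tokenizer's 12.5 TPS rate. Here is a quick arithmetic check in Python. The helper names are illustrative only and are not part of any package.

```python
# Convert between GLM4 speech-token counts and audio duration,
# using the 12.5 tokens-per-second rate stated above.
TOKENS_PER_SECOND = 12.5


def tokens_to_hours(num_tokens: float) -> float:
    """Hours of audio represented by `num_tokens` speech tokens."""
    return num_tokens / TOKENS_PER_SECOND / 3600


def seconds_to_tokens(seconds: float) -> int:
    """Approximate number of speech tokens produced for `seconds` of audio."""
    return round(seconds * TOKENS_PER_SECOND)


print(f"{seconds_to_tokens(72_000 * 3600):,} tokens in 72k hours")
print(f"{tokens_to_hours(3e9):,.0f} hours in 3B tokens")
print(f"{seconds_to_tokens(30)} tokens in a 30-second clip")
```

72k hours comes to 3.24B tokens, which matches the roughly 3B quoted above. A 30-second clip is only 375 speech tokens.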

Benchmark


How we train

  1. Multipacking with proper document masking, at a context length of 10240 tokens.
  2. FP32-BF16 mixed precision training.
  3. Full parameter finetuning.
  4. WandB logs at https://wandb.ai/huseinzol05/Qwen-Qwen2.5-1.5B-STT-10k
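The card does not say which packing algorithm is used. The sketch below assumes greedy first-fit packing, which is an assumption and not necessarily what the training code does. It packs variable-length examples into 10240-token rows. For each row it then builds three things:

- position ids that restart at every document;
- cumulative sequence lengths (`cu_seqlens`, the layout that variable-length attention kernels typically take);
- an explicit causal mask that stops tokens from attending across document boundaries.

`pack` and `document_mask` are hypothetical names.

```python
from typing import List, Tuple


def pack(lengths: List[int], context_length: int = 10240) -> List[List[int]]:
    """Greedy first-fit packing (assumed strategy): returns, for each packed row,
    the indices of the documents placed in it."""
    rows: List[List[int]] = []
    free: List[int] = []  # remaining token budget per row
    for idx, n in enumerate(lengths):
        if n > context_length:
            raise ValueError(f"document {idx} has {n} tokens, more than {context_length}")
        for r, space in enumerate(free):
            if n <= space:
                rows[r].append(idx)
                free[r] -= n
                break
        else:  # no existing row has room, so open a new one
            rows.append([idx])
            free.append(context_length - n)
    return rows


def document_mask(doc_lengths: List[int]) -> Tuple[List[int], List[int], List[List[bool]]]:
    """For one packed row, return position ids that restart per document, cumulative
    sequence lengths, and a causal mask that blocks attention across documents."""
    doc_ids: List[int] = []
    position_ids: List[int] = []
    cu_seqlens = [0]
    for d, n in enumerate(doc_lengths):
        doc_ids.extend([d] * n)
        position_ids.extend(range(n))  # positions restart for each document
        cu_seqlens.append(cu_seqlens[-1] + n)
    size = len(doc_ids)
    # mask[i][j] is True when token i may attend to token j.
    mask = [[j <= i and doc_ids[i] == doc_ids[j] for j in range(size)] for i in range(size)]
    return position_ids, cu_seqlens, mask


print(pack([6000, 5000, 4000, 3000]))
position_ids, cu_seqlens, mask = document_mask([2, 3])
print(position_ids, cu_seqlens)
```

For example, `pack([6000, 5000, 4000, 3000])` returns `[[0, 2], [1, 3]]`. `document_mask([2, 3])` gives position ids `[0, 1, 0, 1, 2]` and `cu_seqlens` `[0, 2, 5]`. The explicit mask is quadratic in row length, so on real 10240-token rows the `cu_seqlens` form is the practical one.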

How to

First, install the speech tokenizer:

pip3 install git+https://github.com/malaysia-ai/glm4-audio-tokenizer

Then load the model:

from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from glm4_audio_tokenizer import Glm4Tokenizer
import torch

glm4 = Glm4Tokenizer().to(torch.float16).cuda()
model = AutoModelForCausalLM.from_pretrained('malaysia-ai/Streaming-STT-1.5B').cuda()
tokenizer = AutoTokenizer.from_pretrained('malaysia-ai/Streaming-STT-1.5B')
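Every example below builds its prompt the same way, in this order:

1. A mode prefix: <|whole|> or <|streaming|>.
2. A timestamp prefix: <|segment|> or <|word|>.
3. One <|s{t}|> token per speech token.
4. A closing <|endofspeech|>.

A small helper capturing that layout could look like this. `build_prompt` is hypothetical and does not ship with the model or the tokenizer package.

```python
from typing import Sequence

MODES = ("whole", "streaming")
TIMESTAMPS = ("segment", "word")


def build_prompt(speech_tokens: Sequence[int], mode: str = "whole",
                 timestamp: str = "segment", end_of_speech: bool = True) -> str:
    """Assemble a prompt string in the layout used by the examples in this card."""
    if mode not in MODES:
        raise ValueError(f"mode must be one of {MODES}, got {mode!r}")
    if timestamp not in TIMESTAMPS:
        raise ValueError(f"timestamp must be one of {TIMESTAMPS}, got {timestamp!r}")
    prefix = f"<|{mode}|><|{timestamp}|>"
    # One special token per discrete GLM4 speech token id.
    speech = "".join(f"<|s{t}|>" for t in speech_tokens)
    return prefix + speech + ("<|endofspeech|>" if end_of_speech else "")


print(build_prompt([1, 42]))
print(build_prompt([7], mode="streaming", timestamp="word", end_of_speech=False))
```

In the streaming examples only the first chunk carries the prefixes. The semantic VAD example leaves out <|endofspeech|> (`end_of_speech=False` here) because that is the token whose probability is being measured.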

Whole segment timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat.mp3
speech_tokens = glm4.tokenize(['husein-chat.mp3'])
token = ''.join([f'<|s{t}|>' for t in speech_tokens[0]]) + '<|endofspeech|>'
prompt = '<|whole|><|segment|>' + token
input_ids = tokenizer(prompt, return_tensors = 'pt').to('cuda')
generate_kwargs = dict(
    **input_ids,
    max_new_tokens=1024,
)
generation_output = model.generate(**generate_kwargs)
tokenizer.decode(generation_output[0, input_ids['input_ids'].shape[1]:])

Output:

<|0.30|> Hai,<|0.56|><|1.16|> saya adalah pembantu AI anda.<|3.02|><|3.62|> Selamat berkenalan!<|4.50|><|5.06|> Apa yang saya boleh tolong<|6.20|><|6.52|> untuk buatkan hari anda lebih ceria?<|8.62|><|endoftext|>
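The timestamps come back inline as <|start|> text <|end|> pairs. If you need structured segments, a regular expression is enough to pull them out. This is a sketch that assumes the output always follows the pattern shown above, and `parse_timestamps` is a hypothetical helper. The same function also handles the word-level output, because each word is wrapped in its own start/end pair.

```python
import re
from typing import List, Tuple

# <|start|> text <|end|>, where start and end are decimal seconds such as 0.30.
_PAIR = re.compile(r"<\|(\d+(?:\.\d+)?)\|>([^<]*)<\|(\d+(?:\.\d+)?)\|>")


def parse_timestamps(text: str) -> List[Tuple[float, float, str]]:
    """Return (start_seconds, end_seconds, text) tuples. Non-numeric tags such as
    <|endoftext|> are ignored because they never match the pattern."""
    return [(float(start), float(end), body.strip())
            for start, body, end in _PAIR.findall(text)]


output = ("<|0.30|> Hai,<|0.56|><|1.16|> saya adalah pembantu AI anda.<|3.02|>"
          "<|3.62|> Selamat berkenalan!<|4.50|><|endoftext|>")
for start, end, segment in parse_timestamps(output):
    print(f"{start:.2f}-{end:.2f} {segment}")
```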

Whole word timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat.mp3
speech_tokens = glm4.tokenize(['husein-chat.mp3'])
token = ''.join([f'<|s{t}|>' for t in speech_tokens[0]]) + '<|endofspeech|>'
prompt = '<|whole|><|word|>' + token
input_ids = tokenizer(prompt, return_tensors = 'pt').to('cuda')
generate_kwargs = dict(
    **input_ids,
    max_new_tokens=1024,
)
generation_output = model.generate(**generate_kwargs)
tokenizer.decode(generation_output[0, input_ids['input_ids'].shape[1]:])

Output:

<|0.30|> Hai,<|0.56|><|1.16|> saya<|1.38|><|1.50|> adalah<|1.78|><|1.84|> pembantu<|2.22|><|2.42|> AI<|2.68|><|2.86|> anda.<|3.06|><|3.66|> Selamat<|3.98|><|4.04|> berkenalan!<|4.52|><|5.08|> Apa<|5.22|><|5.30|> yang<|5.42|><|5.48|> saya<|5.62|><|5.68|> boleh<|5.84|><|5.88|> tolong<|6.20|><|6.54|> untuk<|6.74|><|6.80|> buatkan<|7.12|><|7.22|> hari<|7.42|><|7.54|> anda<|7.72|><|7.84|> lebih<|8.02|><|8.08|> ceria?<|8.62|><|endoftext|>

Streaming segment timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part1.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part2.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part3.mp3

speech_tokens = glm4.tokenize(['husein-chat-part1.mp3', 'husein-chat-part2.mp3', 'husein-chat-part3.mp3'])

prompt = '<|streaming|><|segment|>'
for i in range(len(speech_tokens)):
    token = ''.join([f'<|s{t}|>' for t in speech_tokens[i]]) + '<|endofspeech|>'
    
    input_ids = tokenizer(prompt + token, return_tensors = 'pt').to('cuda')
    generate_kwargs = dict(
        **input_ids,
        max_new_tokens=1024,
    )
    generation_output = model.generate(**generate_kwargs)
    new_prompt = tokenizer.decode(generation_output[0])
    prompt = new_prompt
    t = tokenizer.decode(generation_output[0, input_ids['input_ids'].shape[1]:])
    print(f'index {i + 1}: {t}')
    print()

Output:

index 1: <|0.02|> Hai. Saya adalah pembantu AI anda.<|3.24|><|endoftext|>

index 2: <|3.62|> Selamat berkenalan! Apa yang saya boleh tolong?<|6.84|><|endoftext|>

index 3: <|7.24|> Untuk buatkan hari anda lebih ceria.<|9.46|><|endoftext|>
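Each round of the loop above sends the full history back to the model: the previous chunks, their transcripts, then the new chunk. The original recovers this history by decoding the generated ids, which should give the same string as plain concatenation. Because earlier text is never rewritten, every prompt is a prefix of the next one. That is what lets a serving framework with prefix caching, such as vLLM, reuse work from earlier rounds.

The sketch below mirrors the loop, with the model call abstracted into a `generate(prompt) -> completion` callable. `stream_transcribe` and `fake_generate` are hypothetical and not part of any library.

```python
from typing import Callable, List, Sequence


def stream_transcribe(chunks: Sequence[Sequence[int]],
                      generate: Callable[[str], str],
                      prefix: str = "<|streaming|><|segment|>") -> List[str]:
    """Transcribe speech-token chunks one at a time, carrying the full history forward."""
    prompt = prefix
    completions: List[str] = []
    for chunk in chunks:
        # Append the new chunk, terminated by <|endofspeech|>, to the existing history.
        prompt += "".join(f"<|s{t}|>" for t in chunk) + "<|endofspeech|>"
        completion = generate(prompt)
        completions.append(completion)
        # Keep the transcript in the history so the next round sees it as context.
        prompt += completion
    return completions


seen: List[str] = []


def fake_generate(prompt: str) -> str:
    # Stand-in for a real model call; records every prompt it receives.
    seen.append(prompt)
    n = len(seen)
    return f"<|{n - 1}.00|> chunk {n}<|{n}.00|><|endoftext|>"


print(stream_transcribe([[1, 2], [3], [4, 5]], fake_generate))
# Check that every prompt extends the previous one (what a prefix cache relies on).
print(all(later.startswith(earlier) for earlier, later in zip(seen, seen[1:])))
```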

Streaming word timestamp mode

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part1.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part2.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-part3.mp3

speech_tokens = glm4.tokenize(['husein-chat-part1.mp3', 'husein-chat-part2.mp3', 'husein-chat-part3.mp3'])

# we found that whole mode gives better word timestamps for this example, so the prefix is <|whole|><|word|>
prompt = '<|whole|><|word|>'
for i in range(len(speech_tokens)):
    token = ''.join([f'<|s{t}|>' for t in speech_tokens[i]]) + '<|endofspeech|>'
    
    input_ids = tokenizer(prompt + token, return_tensors = 'pt').to('cuda')
    generate_kwargs = dict(
        **input_ids,
        max_new_tokens=1024,
    )
    generation_output = model.generate(**generate_kwargs)
    new_prompt = tokenizer.decode(generation_output[0])
    prompt = new_prompt
    t = tokenizer.decode(generation_output[0, input_ids['input_ids'].shape[1]:])
    print(f'index {i + 1}: {t}')
    print()

Output:

index 1: <|0.28|> Hai.<|0.54|><|1.14|> Saya<|1.36|><|1.48|> adalah<|1.76|><|1.82|> pembantu<|2.22|><|2.40|> AI<|2.68|><|2.84|> anda.<|3.20|><|endoftext|>

index 2: <|3.72|> Selamat<|4.44|><|4.50|> berkenalan!<|5.04|><|5.58|> Apa<|5.72|><|5.82|> yang<|5.92|><|5.98|> saya<|6.12|><|6.18|> boleh<|6.34|><|6.38|> tolong?<|6.96|><|endoftext|>

index 3: <|7.38|> Untuk<|7.64|><|7.72|> buatkan<|8.06|><|8.14|> hari<|8.36|><|8.48|> anda<|8.64|><|8.78|> lebih<|8.96|><|9.00|> ceria.<|9.60|><|endoftext|>

Semantic VAD

# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-not-proper-cut.mp3
# !wget https://github.com/mesolitica/malaya-speech/raw/refs/heads/master/speech/record/husein-chat-proper-cut.mp3

speech_tokens = glm4.tokenize(['husein-chat-not-proper-cut.mp3', 'husein-chat-proper-cut.mp3'])
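endofspeech_id = tokenizer.convert_tokens_to_ids('<|endofspeech|>') # 151665 is <|endofspeech|> token
for i in range(len(speech_tokens)):
    prompt = '<|streaming|><|word|>'
    # no <|endofspeech|> appended: the model should predict it
    token = ''.join([f'<|s{t}|>' for t in speech_tokens[i]])
    input_ids = tokenizer(prompt + token, return_tensors = 'pt').to('cuda')
    with torch.no_grad():
        logits = model(**input_ids).logits
    print(i, logits[0, -1, endofspeech_id]) # raw logit of <|endofspeech|> at the last position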

Output:

0 tensor(51.9292, device='cuda:0') # not proper cut
1 tensor(54.5208, device='cuda:0') # proper cut
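The example prints the raw logit of <|endofspeech|>. Raw logits are not probabilities, so to get the probability the feature list refers to, apply a softmax over the full vocabulary at the last position.

Below is a minimal pure-Python sketch. `token_probability` and `is_end_of_speech` are hypothetical helpers. The 0.5 threshold is an untuned placeholder that would need calibrating on your own audio.

```python
import math
from typing import Sequence


def token_probability(logits: Sequence[float], token_id: int) -> float:
    """Softmax probability of `token_id` given one position's logits (log-sum-exp for stability)."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return math.exp(logits[token_id] - log_norm)


def is_end_of_speech(logits: Sequence[float], endofspeech_id: int,
                     threshold: float = 0.5) -> bool:
    """Hypothetical turn-taking rule: treat the chunk as finished once
    P(<|endofspeech|>) reaches `threshold`."""
    return token_probability(logits, endofspeech_id) >= threshold


toy_logits = [2.0, 0.5, 4.0]  # toy 3-token vocabulary; index 2 plays <|endofspeech|>
print(token_probability(toy_logits, 2), is_end_of_speech(toy_logits, 2))
```

With the model in this card, the equivalent on the GPU tensor is `torch.softmax(logits[0, -1].float(), dim=-1)[endofspeech_id].item()`.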

Source code

The source code is available at https://github.com/malaysia-ai/cooking/tree/main/qwen-stt

Acknowledgement

Special thanks to the Lambda Research Grant program for the Lambda Cloud credits!

Safetensors
Model size: 2B params
Tensor type: F32

Model tree for malaysia-ai/Streaming-STT-1.5B

Base model: Qwen/Qwen2.5-1.5B
Finetuned: this model (one of 291 finetunes of the base model)
Quantizations of this model: 2 models

Dataset used to train malaysia-ai/Streaming-STT-1.5B