Qwen3-ASR-Micro / scripts /scale_corpus.py
Luigi's picture
Qwen3-ASR-0.3B: distilled zh-TW/en student + GGUF + reproduction scripts
198b449 verified
Raw
History Blame Contribute Delete
3.46 kB
# Scaled corpus build for Qwen3-ASR-0.3B distillation: download more YouTube-TW videos,
# segment, teacher-label (0.6B) -> s2twp Traditional -> append to kd_dataset.jsonl. Resumable.
import os, glob, subprocess, json, re, time, numpy as np, soundfile as sf, torch
from transformers import Qwen3ASRForConditionalGeneration, AutoProcessor
from opencc import OpenCC
# --- Configuration ----------------------------------------------------------
s2twp = OpenCC("s2twp")  # OpenCC converter ("s2twp" config) applied to every transcript before it is stored
dev = "cuda"
WIN = 16.0  # length of one audio segment, in seconds
YTDLP = "/home/luigi/jetson-stt/.venv/bin/yt-dlp"
N_VIDEOS = int(os.environ.get("N_VIDEOS", "150"))  # cap on videos handled per run (env override)
os.makedirs("yt_scale", exist_ok=True)
OUT = "kd_dataset.jsonl"

# --- Candidate video ids, in first-seen manifest order ----------------------
# The first tab-separated column is a file path; its basename minus extension
# is "<video id>_<suffix>".  Only the last "_"-separated piece is dropped, so
# ids that themselves contain underscores survive intact.
ids = []
seen = set()
with open("yt_manifest.tsv", encoding="utf-8") as manifest:
    for line in manifest:
        fn = line.split("\t")[0].rsplit("/", 1)[-1].rsplit(".", 1)[0]
        v = "_".join(fn.split("_")[:-1])
        # Blank lines and names without "_" produce an empty id; skip them so
        # they neither shift the offset below nor get retried on every run.
        if v and v not in seen:
            seen.add(v)
            ids.append(v)

# --- Resume state: ids that a previous run finished -------------------------
done = set()
if os.path.exists("scale_done.txt"):
    with open("scale_done.txt") as done_file:
        done = set(done_file.read().split())

# NOTE(review): the first 15 manifest videos are skipped — presumably they were
# already covered by an earlier corpus build; confirm before changing.
todo = [v for v in ids[15:] if v not in done][:N_VIDEOS]
print(f"scale: {len(todo)} new videos (target), {len(done)} already done",flush=True)

# --- Labelling model (fp16, on `dev`, eval mode) ----------------------------
proc = AutoProcessor.from_pretrained("/tmp/qwen3-asr-hf")
model = Qwen3ASRForConditionalGeneration.from_pretrained("/tmp/qwen3-asr-hf", dtype=torch.float16).to(dev).eval()
def dl(vid,out):
    """Fetch the audio track of YouTube video `vid` with yt-dlp into `out`.

    yt-dlp is asked for format 140 (falling back to the best m4a, then the
    best audio of any kind), to extract audio as WAV, and to pass
    ``-ar 16000 -ac 1`` to the ExtractAudio post-processor.  Live streams and
    videos outside the 120 s – 14400 s duration window are filtered out.

    The exit status is not inspected and all yt-dlp output is discarded, so a
    failed or filtered download is only visible to the caller as a missing or
    unreadable `out`.  ``subprocess.TimeoutExpired`` propagates if yt-dlp
    runs longer than 600 seconds.
    """
    url = f"https://www.youtube.com/watch?v={vid}"
    cmd = [
        YTDLP,
        "--js-runtimes", "deno",
        "-f", "140/bestaudio[ext=m4a]/bestaudio",
        "--match-filter", "!is_live & duration>120 & duration<14400",
        "-x", "--audio-format", "wav",
        "--postprocessor-args", "ExtractAudio:-ar 16000 -ac 1",
        "--no-playlist", "-q",
        "-o", out,
        url,
    ]
    # Child environment: same as ours, with ~/.local/bin placed first on PATH
    # (presumably where the deno runtime lives — verify on the target host).
    child_env = os.environ.copy()
    child_env["PATH"] = "/home/luigi/.local/bin:" + os.environ.get("PATH", "")
    subprocess.run(
        cmd,
        env=child_env,
        timeout=600,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
@torch.no_grad()
def label(a):
    """Transcribe one audio segment `a` with the loaded model and return the text.

    `a` is handed to the processor's chat template as the audio part of a
    single user turn (with an empty text part) at a declared 16 kHz sampling
    rate.  Decoding is greedy and limited to 96 new tokens.  Everything up to
    and including the last ``<asr_text>`` marker on a line is removed from the
    decoded continuation, and the remainder is returned stripped.
    """
    conversation = [{
        "role": "user",
        "content": [
            {"type": "audio", "audio": a},
            {"type": "text", "text": ""},
        ],
    }]
    encoded = proc.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        sampling_rate=16000,
    )
    # Move every tensor to the model device; float32 inputs are additionally
    # cast to half precision to match the fp16 model weights.
    batch = {}
    for key, tensor in encoded.items():
        moved = tensor.to(dev)
        if tensor.dtype == torch.float32:
            moved = moved.half()
        batch[key] = moved
    generated = model.generate(**batch, max_new_tokens=96, do_sample=False)
    # Drop the prompt tokens so only the newly generated continuation is decoded.
    prompt_len = batch["input_ids"].shape[1]
    decoded = proc.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)[0]
    return re.sub(r".*<asr_text>", "", decoded).strip()
# --- Main loop: download -> segment -> label -> append ----------------------
# For each pending video: download its audio, cut it into WIN-second windows,
# label each usable window, save the window as its own WAV and queue a
# {"wav", "text"} JSON record.  Records are written to OUT only once the whole
# video has been processed, immediately followed by the id in scale_done.txt,
# so an interruption mid-video leaves no partial rows to be duplicated when the
# run is resumed (segment WAVs are simply overwritten on the retry).
t0 = time.time()
nseg = 0   # segments kept during this run
nh = 0.0   # hours of kept audio during this run
W = int(WIN * 16000)  # window length in samples at 16 kHz
with open(OUT, "a", encoding="utf-8") as out, open("scale_done.txt", "a") as dlog:
    for vi, vid in enumerate(todo):
        raw = f"yt_scale/_tmp_{vid}.wav"
        try:
            dl(vid, raw)
            a, sr = sf.read(raw, dtype="float32")
            # Segmenting, labelling and sf.write below all assume 16 kHz audio.
            if sr != 16000:
                raise ValueError(f"unexpected sample rate {sr}")
            a = a.mean(1) if a.ndim > 1 else a  # down-mix multi-channel audio to mono
        except Exception as e:
            # Best effort: report, clean up any partial download, move on.
            # The id is NOT logged as done, so a later run will retry it.
            print(f" skip {vid}: {type(e).__name__}: {e}", flush=True)
            for p in glob.glob(f"yt_scale/_tmp_{vid}*"):
                try:
                    os.remove(p)
                except OSError:
                    pass
            continue
        records = []
        k = 0  # index of the next kept segment of this video
        for s in range(0, len(a), W):
            seg = a[s:s + W]
            # Skip tails shorter than one second and near-silent windows.
            if len(seg) < 16000 or float(np.abs(seg).max()) < 0.01:
                continue
            raw_txt = label(seg)
            # Keep only transcripts with at least 2 CJK ideographs
            # (U+4E00..U+9FFF) or at least 4 ASCII letters.
            cjk = sum(1 for c in raw_txt if '一' <= c <= '鿿')
            lat = sum(c.isalpha() and ord(c) < 128 for c in raw_txt)
            if cjk < 2 and lat < 4:
                continue
            sp = f"yt_scale/{vid}_{k}.wav"
            sf.write(sp, seg, 16000)
            records.append(json.dumps({"wav": os.path.abspath(sp), "text": s2twp.convert(raw_txt)}, ensure_ascii=False))
            nseg += 1
            nh += len(seg) / 16000 / 3600
            k += 1
        try:
            os.remove(raw)
        except OSError:
            pass
        # Commit the video: all of its records first, then the done marker.
        out.write("".join(r + "\n" for r in records))
        out.flush()
        dlog.write(vid + "\n")
        dlog.flush()
        print(f" +{vid} ({k} segs) -> {nseg} segs, {nh:.1f}h [{vi+1}/{len(todo)}] {(time.time()-t0)/60:.0f}min",flush=True)
print(f"SCALE-DONE: {nseg} segs, {nh:.1f}h -> {OUT}",flush=True)