# Wav2vecTest / app.py
# Source: Hugging Face Space "palli23/Wav2vecTest" (commit 512e2bc)
# app.py — wav2vec2 multi-aug (stable + high quality)
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:128"
import gradio as gr
import spaces
import torch
import gc
import re
import librosa
from transformers import (
Wav2Vec2Processor,
Wav2Vec2ForCTC
)
#MODEL_ID = "palli23/wav2vec2-icelandic-multi-aug-v2-5e-6"
MODEL_ID = "palli23/wav2vec2-icelandic-clean"
# MODEL_ID = "palli23/wav2vec2-xlsr-300m-icelandic"
# ——————————————————————————————
# Strong Icelandic cleanup
# ——————————————————————————————
def clean_text(text: str) -> str:
    """Normalize raw CTC decoder output.

    Lowercases, squashes long character runs left by CTC decoding,
    collapses whitespace, and removes stray spaces before punctuation.
    """
    normalized = text.lower()
    # CTC decoding can emit long runs of a single character; keep at most two.
    normalized = re.sub(r"(.)\1{3,}", r"\1\1", normalized)
    # Collapse any whitespace run to one space.
    normalized = re.sub(r"\s+", " ", normalized)
    # Drop the space the decoder leaves before punctuation marks.
    for mark in (",", ".", "?", "!"):
        normalized = normalized.replace(f" {mark}", mark)
    return normalized.strip()
# ——————————————————————————————
# Chunking helper (overlap improves WER)
# ——————————————————————————————
def chunk_audio(audio, sr, chunk_s=20, overlap_s=0):
    """Yield fixed-length windows of `audio` for chunked inference.

    Args:
        audio: 1-D array-like of samples (anything sliceable with `len`).
        sr: sample rate in Hz.
        chunk_s: window length in seconds.
        overlap_s: overlap between consecutive windows in seconds
            (a little overlap can reduce WER at chunk seams).

    Yields:
        Slices of `audio` up to `chunk_s` seconds long. A trailing
        remainder shorter than one second is dropped — too little
        context for the model to decode usefully.

    Raises:
        ValueError: if `chunk_s <= 0`, or `overlap_s >= chunk_s` (the
            window would never advance, previously causing either an
            opaque `range()` error or a silent empty generator).
    """
    if chunk_s <= 0:
        raise ValueError("chunk_s must be positive")
    if overlap_s >= chunk_s:
        raise ValueError("overlap_s must be smaller than chunk_s")
    chunk_len = int(chunk_s * sr)
    step_len = int((chunk_s - overlap_s) * sr)
    for start in range(0, len(audio), step_len):
        chunk = audio[start:start + chunk_len]
        if len(chunk) < sr:  # under one second of audio: skip the tail
            break
        yield chunk
# ——————————————————————————————
# ZeroGPU worker
# ——————————————————————————————
@spaces.GPU(duration=180)
def transcribe_3min(audio_path):
    """Transcribe an uploaded audio file with the Icelandic wav2vec2 model.

    Loads the model fresh per call (ZeroGPU workers are stateless),
    runs chunked CTC inference on GPU, and always releases GPU memory
    afterwards — previously the cleanup ran only on the success path,
    so any exception during loading or inference leaked the model on
    the reused worker.

    Args:
        audio_path: filesystem path to the uploaded audio (any format
            librosa can read), or a falsy value when nothing was uploaded.

    Returns:
        The cleaned transcription string, or an Icelandic prompt asking
        the user to upload a file.
    """
    if not audio_path:
        return "Hlaðið upp hljóðskrá"

    processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
    model.eval().to("cuda")

    try:
        # Resample to the 16 kHz mono float32 input the model expects.
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)
        audio = audio.astype("float32")

        texts = []
        for chunk in chunk_audio(audio, sr):
            inputs = processor(
                chunk,
                sampling_rate=16000,
                return_tensors="pt",
                padding=True,
            )
            with torch.no_grad():
                logits = model(inputs.input_values.to("cuda")).logits
            pred_ids = torch.argmax(logits, dim=-1)
            texts.append(processor.batch_decode(pred_ids)[0])

        return clean_text(" ".join(texts))
    finally:
        # Free GPU memory even when loading/inference raised — critical
        # on ZeroGPU where the worker process is reused across requests.
        del model
        del processor
        gc.collect()
        torch.cuda.empty_cache()
# ——————————————————————————————
# UI
# ——————————————————————————————
with gr.Blocks() as demo:
    # Header and contact info.
    gr.Markdown("# Íslenskt ASR – wav2vec2 (multi-aug)")
    gr.Markdown("**stöðugt · chunked · post-processed**")
    gr.Markdown("**Hafa samband:** pallinr1@protonmail.com")

    # Upload -> transcribe button -> text output wiring.
    audio_input = gr.Audio(type="filepath", label="Hlaðið upp .mp3 / .wav")
    transcribe_btn = gr.Button("Transcribe", variant="primary", size="lg")
    transcript_box = gr.Textbox(lines=20, label="Útskrift")
    transcribe_btn.click(
        fn=transcribe_3min,
        inputs=audio_input,
        outputs=transcript_box,
    )

# Bind on all interfaces so the Space's proxy can reach the server.
demo.launch(share=True, server_name="0.0.0.0", server_port=7860)