File size: 3,579 Bytes
e3207ee
2102ae8
6adf5a9
 
f00cb9e
6adf5a9
 
 
45c12a4
 
8e34a29
f00cb9e
8e34a29
 
 
e3207ee
8e34a29
f00cb9e
e3207ee
 
 
e37e472
e3207ee
 
 
 
395778a
6642f61
e3207ee
 
6642f61
e3207ee
 
6642f61
e3207ee
 
 
6642f61
8e34a29
 
e3207ee
 
 
512e2bc
e3207ee
 
 
f00cb9e
e3207ee
 
 
f00cb9e
 
 
e3207ee
 
 
2102ae8
9648db0
 
 
2102ae8
e3207ee
 
 
2102ae8
e3207ee
f00cb9e
 
2102ae8
e3207ee
2102ae8
f00cb9e
6642f61
 
 
 
 
 
c675e00
f00cb9e
e3207ee
 
 
395778a
f00cb9e
e3207ee
 
395778a
e3207ee
2102ae8
e3207ee
8e34a29
 
f00cb9e
 
 
e3207ee
f00cb9e
e3207ee
f00cb9e
e3207ee
f00cb9e
e3207ee
 
f00cb9e
 
e3207ee
f00cb9e
e3207ee
f00cb9e
e3207ee
f00cb9e
 
 
 
 
e3207ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# app.py — wav2vec2 multi-aug (stable + high quality)

import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["PYTORCH_ALLOC_CONF"] = "max_split_size_mb:128"

import gradio as gr
import spaces
import torch
import gc
import re
import librosa

from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)

# Hub repo id of the CTC checkpoint; transcribe_3min loads both the processor
# and the model from it via from_pretrained on every call.
# The commented-out lines are alternative checkpoints - swap the active line to try one.
# NOTE(review): the UI title says "multi-aug" while the active id is the "-clean"
# checkpoint - confirm which label is intended.
#MODEL_ID = "palli23/wav2vec2-icelandic-multi-aug-v2-5e-6"
MODEL_ID = "palli23/wav2vec2-icelandic-clean"
# MODEL_ID = "palli23/wav2vec2-xlsr-300m-icelandic"

# ——————————————————————————————
# Transcript cleanup (CTC repeat artifacts, whitespace, punctuation spacing)
# ——————————————————————————————
def clean_text(text: str) -> str:
    """Normalise a raw CTC transcript for display.

    Lower-cases the text, shrinks runs of four or more identical characters
    to two, collapses every whitespace run to a single space, removes the
    space in front of ``,`` ``.`` ``?`` ``!`` and trims both ends.
    """
    # Greedy CTC output can stutter one character many times in a row.
    result = re.sub(r"(.)\1{3,}", r"\1\1", text.lower())
    result = re.sub(r"\s+", " ", result)

    # Glue punctuation onto the preceding word (applied in this fixed order).
    for spaced, glued in ((" ,", ","), (" .", "."), (" ?", "?"), (" !", "!")):
        result = result.replace(spaced, glued)

    return result.strip()

# ——————————————————————————————
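# Chunking helper (overlap is off by default: chunk transcripts are joined without de-duplication, so overlapping windows would repeat words)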
# ——————————————————————————————
def chunk_audio(audio, sr, chunk_s=20, overlap_s=0):
    """Yield consecutive (optionally overlapping) windows of *audio*.

    Args:
        audio: Sliceable sample sequence (e.g. the array from librosa.load).
        sr: Sample rate in Hz; converts seconds to sample counts.
        chunk_s: Window length in seconds.
        overlap_s: Seconds shared by consecutive windows; must be < chunk_s.

    Yields:
        Slices of *audio*. A trailing remainder shorter than one second is
        merged into the preceding window instead of being dropped, so the end
        of the recording is always transcribed. A clip shorter than one second
        yields nothing.

    Raises:
        ValueError: If the window or step length is <= 0 samples. Because
            this is a generator, it is raised on first iteration.
    """
    chunk_len = int(chunk_s * sr)
    step_len = int((chunk_s - overlap_s) * sr)
    if chunk_len <= 0 or step_len <= 0:
        # A zero step would make range() fail cryptically; a negative one
        # would silently yield nothing.
        raise ValueError(
            f"chunk_s must be > 0 and overlap_s < chunk_s "
            f"(got chunk_s={chunk_s}, overlap_s={overlap_s}, sr={sr})"
        )

    min_len = sr  # shortest piece worth transcribing on its own (1 s)
    total = len(audio)
    if total < min_len:
        return

    start = 0
    while True:
        end = start + chunk_len
        if total - end < min_len:
            # The leftover after this window is too short to stand alone, or
            # the window already reaches the end: stretch it to the end.
            end = total
        yield audio[start:end]
        if end >= total:
            return
        start += step_len

# ——————————————————————————————
# ZeroGPU worker
# ——————————————————————————————
@spaces.GPU(duration=180)
def transcribe_3min(audio_path):
    """Transcribe an uploaded audio file with the wav2vec2 CTC model in MODEL_ID.

    Runs under ``spaces.GPU(duration=180)``. The processor and model are
    loaded fresh on every call and are always released afterwards, even if
    loading or inference raises.

    Args:
        audio_path: Path of the uploaded file, or a falsy value when nothing
            was uploaded.

    Returns:
        The cleaned transcript, or an upload prompt when ``audio_path`` is
        empty. Returns an empty string when ``chunk_audio`` yields no chunks
        (e.g. very short clips).
    """
    if not audio_path:
        return "Hlaðið upp hljóðskrá"

    # Decode the audio before loading the model, so an unreadable upload
    # fails fast instead of after a slow checkpoint load inside the GPU call.
    # Resampled to 16 kHz mono and forced to float32.
    audio, sr = librosa.load(audio_path, sr=16000, mono=True)
    audio = audio.astype("float32")

    processor = model = None
    try:
        processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
        model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
        model.eval().to("cuda")
        texts = _decode_chunks(processor, model, audio, sr)
    finally:
        # Cleanup (critical): drop the references first so gc/empty_cache can
        # actually release the GPU memory - also on the error path.
        del model, processor
        gc.collect()
        torch.cuda.empty_cache()

    return clean_text(" ".join(texts))


def _decode_chunks(processor, model, audio, sr):
    """Greedy-CTC-decode each chunk of *audio* on CUDA; return one raw string per chunk.

    Kept separate so that per-chunk tensors (inputs, logits) go out of scope
    when this returns, before the caller empties the CUDA cache.
    """
    texts = []
    # Pure inference: no autograd bookkeeping is needed.
    with torch.inference_mode():
        for chunk in chunk_audio(audio, sr):
            inputs = processor(
                chunk,
                sampling_rate=sr,
                return_tensors="pt",
                padding=True,
            )
            logits = model(inputs.input_values.to("cuda")).logits
            pred_ids = torch.argmax(logits, dim=-1)
            texts.append(processor.batch_decode(pred_ids)[0])
    return texts

# ——————————————————————————————
# UI
# ——————————————————————————————
# Gradio front end: a header, an audio upload, a button and a transcript box.
with gr.Blocks() as demo:
    # NOTE(review): the title says "multi-aug" but MODEL_ID currently points at
    # the "-clean" checkpoint - confirm which label is intended.
    gr.Markdown("# Íslenskt ASR – wav2vec2 (multi-aug)")
    gr.Markdown("**stöðugt · chunked · post-processed**")
    gr.Markdown("**Hafa samband:** pallinr1@protonmail.com")

    # type="filepath": the handler receives a path, which transcribe_3min
    # passes to librosa.load.
    audio_in = gr.Audio(type="filepath", label="Hlaðið upp .mp3 / .wav")
    btn = gr.Button("Transcribe", variant="primary", size="lg")
    output = gr.Textbox(lines=20, label="Útskrift")

    # On click, run the GPU worker on the uploaded file and show its return
    # value (transcript or prompt) in the textbox.
    btn.click(fn=transcribe_3min, inputs=audio_in, outputs=output)

# Start the web server only when run as a script (`python app.py`), so that
# importing this module (tests, tooling) does not launch it as a side effect.
# Listens on all interfaces on port 7860 and requests a public share link.
if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )