ASR_API2 / app.py
palli23's picture
diarization1Mæló
3407dd3
raw
history blame
2.32 kB
# app.py for HF Spaces (ZeroGPU safe pyannote)
import collections
import os
import tempfile
from typing import OrderedDict

import gradio as gr
import spaces
import torch
from pyannote.audio import Audio
from pyannote.audio import Pipeline
from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task, Specifications
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
from torch.serialization import safe_globals
from transformers import pipeline
# Required patches for ZeroGPU
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
torch.serialization.add_safe_globals({
"OrderedDict": OrderedDict,
})
MODEL_NAME = "palli23/whisper-small-sam_spjall"
@spaces.GPU(duration=120)
def transcribe_with_diarization(audio_path):
    """Diarize an audio file and transcribe every speaker turn.

    The diarization pipeline and the Whisper ASR pipeline are both created
    inside this call, i.e. inside the ``spaces.GPU``-decorated function, and
    both are placed on the CUDA device.

    Args:
        audio_path: Path of the audio file to process (the Gradio input is
            configured with ``type="filepath"``). A falsy value means that
            nothing was uploaded.

    Returns:
        A string with one ``[MÆLENDI <speaker>] <text>`` line per diarized
        turn. If ``audio_path`` is falsy, a prompt asking for an upload is
        returned; if no turns are produced, a "no text heard" message is
        returned.

    Raises:
        RuntimeError: If the diarization pipeline could not be loaded
            (``Pipeline.from_pretrained`` returned ``None``).
    """
    if not audio_path:
        return "Hladdu upp hljóðskrá"

    hf_token = os.getenv("HF_TOKEN")

    # Fix strict unpickling in torch 2.6 (ZeroGPU): the classes stored in the
    # pyannote checkpoints must be allow-listed while the pipeline is loaded.
    with safe_globals([
        torch.torch_version.TorchVersion,
        Model,
        Task,
        Specifications,
        SpeakerDiarization,
        collections.OrderedDict,
    ]):
        diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=hf_token,
        )

    # from_pretrained() returns None instead of raising when the gated model
    # cannot be downloaded; fail with an actionable message in that case.
    if diarization is None:
        raise RuntimeError(
            "Could not load pyannote/speaker-diarization-3.1; check HF_TOKEN "
            "and that the model's access conditions have been accepted."
        )

    # pyannote 3.x rejects a plain string here and requires a torch.device.
    diarization.to(torch.device("cuda"))

    # Run diarization
    dia = diarization(audio_path)

    # Whisper model
    asr = pipeline(
        "automatic-speech-recognition",
        model=MODEL_NAME,
        device=0,
        use_auth_token=hf_token,
    )

    # The diarization pipeline has no crop()/export() methods. Excerpts are
    # cut with pyannote's Audio helper instead and handed to the ASR pipeline
    # in memory, so no temporary files are created (or leaked on error).
    # 16 kHz mono is the input format Whisper checkpoints are trained on.
    reader = Audio(sample_rate=16000, mono="downmix")

    # segment-by-segment ASR
    lines = []
    for turn, _, speaker in dia.itertracks(yield_label=True):
        # mode="pad" zero-pads instead of raising when a turn slightly
        # overshoots the end of the file.
        waveform, sample_rate = reader.crop(audio_path, turn, mode="pad")
        samples = waveform.squeeze(0).numpy()  # (1, time) -> (time,)
        text = asr({"raw": samples, "sampling_rate": sample_rate})["text"].strip()
        lines.append(f"[MÆLENDI {speaker}] {text}")

    return "\n".join(lines) or "Enginn texti heyrðist."
# Gradio UI: one audio upload, one button, and a text box that receives the
# speaker-labelled transcript returned by transcribe_with_diarization().
with gr.Blocks() as demo:
    gr.Markdown("# Íslenskt ASR + Mælendagreining")
    gr.Markdown("Whisper-small + pyannote 3.1 (ZeroGPU örugg útgáfa)")
    audio = gr.Audio(type="filepath", label="Hljóðskrá")
    btn = gr.Button("Transcribe með mælendum")
    out = gr.Textbox(lines=35, label="Úttak")
    btn.click(transcribe_with_diarization, inputs=audio, outputs=out)

# NOTE(security): login credentials should not live in source control. They
# can now be overridden through environment variables (e.g. Space secrets);
# the previous hard-coded values remain as defaults so existing deployments
# keep working unchanged.
demo.launch(
    auth=(
        os.getenv("APP_AUTH_USER", "beta"),
        os.getenv("APP_AUTH_PASSWORD", "beta2025"),
    )
)