Regional_Dialects / src /streamlit_app.py
rocky250's picture
Update src/streamlit_app.py
d1d4652 verified
import html
import io
import time
from typing import Tuple

import librosa
import numpy as np
import streamlit as st
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# ─── Page Config ─────────────────────────────────────────────────────────────
# Browser-tab title/icon and overall layout. Kept ahead of every other st.*
# call in this script. The title separator and microphone icon are written as
# real Unicode characters (they were previously mis-decoded byte sequences).
st.set_page_config(
    page_title="RegionalCap · ASR",
    page_icon="🎙️",
    layout="centered",
    initial_sidebar_state="collapsed",
)
# ─── CSS ──────────────────────────────────────────────────────────────────────
# Page-wide stylesheet: dark theme, IBM Plex fonts, and the helper classes used
# by the HTML snippets further down (.hero, .sec-label, .divider, .badge with
# its colour variants .bg/.by/.br/.bb, and the .result-* box).
_PAGE_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=IBM+Plex+Sans:wght@300;400;600&display=swap');
html, body, [class*="css"] {
font-family: 'IBM Plex Sans', sans-serif;
background-color: #0d0d0d;
color: #e8e8e8;
}
#MainMenu, footer, header { visibility: hidden; }
.block-container { padding: 2.4rem 2rem 4rem; max-width: 780px; }
.hero {
border-left: 3px solid #00e5a0;
padding: 0.5rem 0 0.5rem 1.2rem;
margin-bottom: 2rem;
}
.hero h1 {
font-family: 'IBM Plex Mono', monospace;
font-size: 1.65rem; font-weight: 600;
color: #fff; margin: 0; letter-spacing: -0.5px;
}
.hero p { font-size: 0.8rem; color: #777; margin: 0.25rem 0 0; }
.sec-label {
font-family: 'IBM Plex Mono', monospace;
font-size: 0.65rem; color: #00e5a0;
letter-spacing: 1.8px; text-transform: uppercase;
margin-bottom: 0.55rem;
}
.divider { border: none; border-top: 1px solid #1e1e1e; margin: 1.4rem 0; }
.badge {
display: inline-block;
font-family: 'IBM Plex Mono', monospace;
font-size: 0.68rem; padding: 0.15rem 0.5rem;
border-radius: 4px; margin-right: 0.35rem;
}
.bg { background:#003d2a; color:#00e5a0; border:1px solid #00e5a0; }
.by { background:#2e2700; color:#ffd84d; border:1px solid #ffd84d; }
.br { background:#2e0000; color:#ff6b6b; border:1px solid #ff6b6b; }
.bb { background:#001533; color:#4da6ff; border:1px solid #4da6ff; }
.result-wrap {
background: #0a160f;
border: 1px solid #00e5a0;
border-radius: 9px;
padding: 1.1rem 1.3rem;
margin-top: 0.8rem;
}
.result-text {
font-size: 1.05rem; line-height: 1.85;
color: #cff0e0; word-break: break-word;
}
.result-meta {
font-family: 'IBM Plex Mono', monospace;
font-size: 0.65rem; color: #456;
margin-top: 0.75rem; padding-top: 0.6rem;
border-top: 1px solid #1a2e22;
}
.stButton > button {
background: #00e5a0 !important; color: #000 !important;
font-family: 'IBM Plex Mono', monospace !important;
font-weight: 600 !important; font-size: 0.8rem !important;
letter-spacing: 0.8px !important; border: none !important;
border-radius: 6px !important; padding: 0.5rem 1.6rem !important;
}
.stButton > button:hover { opacity: 0.82 !important; }
div[data-testid="stFileUploader"] section {
background: #111 !important;
border: 1.5px dashed #2a2a2a !important;
border-radius: 8px !important;
}
</style>
"""

# Raw HTML must be explicitly allowed for the <style> tag to be emitted.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# ─── Constants ────────────────────────────────────────────────────────────────
# Repository the model weights are loaded from (one sub-folder per checkpoint).
REPO_ID = "rocky250/RegionalCap"
# Identifier the processor is loaded from.
PROCESSOR_ID = "openai/whisper-small"
# Sample rate (Hz) used both when decoding audio and when extracting features.
SAMPLE_RATE = 16000
# Selectable checkpoint sub-folders: checkpoint-1000 … checkpoint-10000.
CHECKPOINTS = [f"checkpoint-{step * 1000}" for step in range(1, 11)]
# The last checkpoint in the list is pre-selected in the sidebar.
DEFAULT_CKPT = CHECKPOINTS[-1]
# File extensions accepted by the uploader.
AUDIO_FMTS = [
    "wav",
    "mp3",
    "flac",
    "ogg",
    "m4a",
]
# ─── Model loader ─────────────────────────────────────────────────────────────
@st.cache_resource(show_spinner=False)
def load_model(checkpoint: str) -> Tuple:
    """Return ``(processor, model, device)`` for one checkpoint sub-folder.

    The result is cached by Streamlit per ``checkpoint`` value, so a rerun of
    the script with the same selection reuses the already-loaded objects.

    Args:
        checkpoint: Sub-folder of ``REPO_ID`` to load the weights from.

    Returns:
        A 3-tuple of the Whisper processor, the model (in eval mode, moved to
        the chosen device) and the device string, ``"cuda"`` or ``"cpu"``.
    """
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    # Half precision on GPU, full precision on CPU.
    weights_dtype = torch.float16 if use_gpu else torch.float32

    processor = WhisperProcessor.from_pretrained(PROCESSOR_ID)
    model = WhisperForConditionalGeneration.from_pretrained(
        REPO_ID,
        subfolder=checkpoint,
        torch_dtype=weights_dtype,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)

    # Drop any forced_decoder_ids stored with the checkpoint so they cannot
    # conflict with the language/task supplied at generation time.
    for cfg in (model.generation_config, model.config):
        cfg.forced_decoder_ids = None

    # Emptied so that no duplicate suppress-tokens logits processor is created.
    model.generation_config.suppress_tokens = []

    model.eval()
    return processor, model, device
# ─── Transcription ────────────────────────────────────────────────────────────
def run_transcription(
    audio_bytes: bytes,
    processor,
    model,
    device: str,
    language: str,
    task: str,
) -> Tuple[str, float, float]:
    """Transcribe an in-memory audio file with the loaded Whisper model.

    The audio is decoded directly from memory (no temporary file), resampled
    to ``SAMPLE_RATE`` mono, and transcribed in consecutive fixed-length
    windows. The Whisper feature extractor pads or truncates every call to a
    single window (30 s by default), so feeding the whole clip at once would
    silently drop everything after the first window.

    Language and task are passed to ``generate()`` as keyword arguments rather
    than through ``forced_decoder_ids``.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file.
        processor: Whisper processor used for feature extraction and decoding.
        model: Whisper model used for generation.
        device: ``"cuda"`` or ``"cpu"``; tensors are moved there before
            generation, and features are cast to half precision on ``"cuda"``.
        language: Language code forwarded to ``generate()``.
        task: Task name forwarded to ``generate()``.

    Returns:
        ``(text, duration, elapsed)``: the stripped transcription, the audio
        length in seconds, and the total seconds spent inside ``generate()``.

    Raises:
        ValueError: If the decoded audio contains no samples.
    """
    # ── Load audio from memory (no disk write) ──
    audio_np, _ = librosa.load(io.BytesIO(audio_bytes), sr=SAMPLE_RATE, mono=True)
    if len(audio_np) == 0:
        raise ValueError("The uploaded audio contains no samples.")
    duration = len(audio_np) / SAMPLE_RATE

    # Window length in samples: taken from the feature extractor when it
    # exposes it, otherwise Whisper's standard 30 seconds.
    extractor = getattr(processor, "feature_extractor", None)
    window = getattr(extractor, "n_samples", None) or 30 * SAMPLE_RATE

    pieces = []
    elapsed = 0.0
    with torch.no_grad():
        # NOTE: windows are cut at fixed offsets, so a word that straddles a
        # boundary may be split between two pieces.
        for start in range(0, len(audio_np), window):
            # ── Feature extraction ──
            inputs = processor(
                audio_np[start:start + window],
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt",
                return_attention_mask=True,
            )
            input_features = inputs.input_features.to(device)
            if device == "cuda":
                # Match the float16 weights used on GPU.
                input_features = input_features.half()
            # Passed through to avoid the padding-ambiguity warning.
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)

            # ── Generate (only this call is timed) ──
            t0 = time.perf_counter()
            predicted_ids = model.generate(
                input_features,
                attention_mask=attention_mask,
                language=language,
                task=task,
            )
            elapsed += time.perf_counter() - t0

            decoded = processor.batch_decode(
                predicted_ids, skip_special_tokens=True
            )[0]
            pieces.append(decoded.strip())

    text = " ".join(piece for piece in pieces if piece)
    return text, duration, elapsed
# ══════════════════════════════════════════════════════════════════════════════
# UI
# ══════════════════════════════════════════════════════════════════════════════
# ─── Hero ─────────────────────────────────────────────────────────────────────
# Page heading. The icon and the "·" separators are real Unicode characters
# (they were previously mis-decoded byte sequences shown verbatim to users).
# The HTML is kept flush-left inside the string so Markdown does not treat it
# as an indented code block.
st.markdown(
    """
<div class="hero">
<h1>🎙️ RegionalCap</h1>
<p>Bengali Dialect ASR &nbsp;·&nbsp; rocky250/RegionalCap &nbsp;·&nbsp; Whisper fine-tune</p>
</div>
""",
    unsafe_allow_html=True,
)
# ─── Sidebar ──────────────────────────────────────────────────────────────────
# Runtime options. `checkpoint`, `language` and `task` are read by the model
# loading and transcription sections below.
with st.sidebar:
    # Header icon written as a real Unicode character (was mis-decoded bytes).
    st.markdown("### ⚙️ Config")
    checkpoint = st.selectbox(
        "Checkpoint",
        CHECKPOINTS,
        index=CHECKPOINTS.index(DEFAULT_CKPT),
    )
    language = st.selectbox("Language", ["bn", "en"], index=0)
    task = st.selectbox("Task", ["transcribe", "translate"], index=0)
    st.markdown("---")
    st.caption(f"**Repo** `{REPO_ID}`")
    st.caption(f"**Processor** `{PROCESSOR_ID}`")
# ─── Load model ───────────────────────────────────────────────────────────────
# Placeholder that ends up holding either the READY badges or the failure badge.
_status = st.empty()
with st.spinner("Loading model… (first run downloads weights, subsequent runs are instant)"):
    # Only the load itself is guarded; the success badges are drawn in `else`
    # so a rendering problem is not reported as a load failure.
    try:
        processor, model, device = load_model(checkpoint)
    except Exception as exc:
        _status.markdown(
            '<span class="badge br">✗ LOAD FAILED</span>',
            unsafe_allow_html=True,
        )
        st.error(str(exc))
        # Nothing below can work without a model; end this script run here.
        st.stop()
    else:
        _status.markdown(
            f'<span class="badge bg">✓ READY</span>'
            f'<span class="badge bb">{checkpoint}</span>'
            f'<span class="badge by">{device.upper()}</span>',
            unsafe_allow_html=True,
        )
st.markdown('<hr class="divider">', unsafe_allow_html=True)
# ─── Audio upload ─────────────────────────────────────────────────────────────
# Section label and format hint use real "·" separators (previously
# mis-decoded byte sequences).
st.markdown('<div class="sec-label">01 · Upload Audio</div>', unsafe_allow_html=True)
uploaded = st.file_uploader(
    "audio",
    type=AUDIO_FMTS,
    label_visibility="collapsed",
)
if uploaded is None:
    # No file yet: show the hint and end this script run so the preview and
    # transcription sections below are not rendered.
    st.markdown(
        '<p style="color:#444;font-size:0.78rem;margin:0.4rem 0 0;">'
        'Supported: WAV · MP3 · FLAC · OGG · M4A</p>',
        unsafe_allow_html=True,
    )
    st.stop()
# ─── Preview ──────────────────────────────────────────────────────────────────
# getvalue() returns the full buffer regardless of the current stream
# position, unlike read().
audio_bytes = uploaded.getvalue()
ext = uploaded.name.rsplit(".", 1)[-1].lower()
# Extension → registered MIME type; "audio/mp3" and "audio/m4a" are not
# standard types. Unknown extensions fall back to the previous "audio/<ext>".
_AUDIO_MIME = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "flac": "audio/flac",
    "ogg": "audio/ogg",
    "m4a": "audio/mp4",
}
st.audio(audio_bytes, format=_AUDIO_MIME.get(ext, f"audio/{ext}"))
# The file name is user-controlled and rendered as raw HTML, so it is escaped.
st.markdown(
    f'<span class="badge bb">{html.escape(uploaded.name)}</span>'
    f'<span class="badge by">{len(audio_bytes)/1024:.1f} kB</span>',
    unsafe_allow_html=True,
)
st.markdown('<hr class="divider">', unsafe_allow_html=True)
# ─── Transcribe ───────────────────────────────────────────────────────────────
# Label, button, separators and error badge use real Unicode characters
# (previously mis-decoded byte sequences).
st.markdown('<div class="sec-label">02 · Transcribe</div>', unsafe_allow_html=True)
if st.button("▶ Run Transcription"):
    result_slot = st.empty()
    result_slot.info("Processing audio…")
    # Top-level UI boundary: any failure during transcription is shown to the
    # user instead of crashing the script run.
    try:
        text, duration, elapsed = run_transcription(
            audio_bytes, processor, model, device, language, task
        )
    except Exception as exc:
        result_slot.empty()
        st.markdown('<span class="badge br">✗ ERROR</span>', unsafe_allow_html=True)
        st.error(str(exc))
    else:
        result_slot.empty()  # clear the "processing" message
        # Real-time factor: inference seconds per second of audio.
        rtf = elapsed / duration if duration > 0 else 0.0
        # The transcription is rendered as raw HTML, so it is escaped; model
        # output may contain characters such as "<" or "&".
        st.markdown(
            f'<div class="result-wrap">'
            f'<div class="result-text">{html.escape(text)}</div>'
            f'<div class="result-meta">'
            f'audio {duration:.1f}s &nbsp;·&nbsp; '
            f'inference {elapsed:.2f}s &nbsp;·&nbsp; '
            f'RTF {rtf:.3f} &nbsp;·&nbsp; '
            f'{device}</div>'
            f'</div>',
            unsafe_allow_html=True,
        )
        # The download carries the unescaped text, named after the upload.
        stem = uploaded.name.rsplit(".", 1)[0]
        st.download_button(
            label="⬇ Download .txt",
            data=text,
            file_name=f"{stem}_transcription.txt",
            mime="text/plain",
        )