# Captured from the Hugging Face Spaces page for this app; the Space status
# was shown as "Runtime error" at capture time.
import contextlib
import html
import io
import os
import tempfile
import time
from typing import Tuple

import librosa
import numpy as np
import streamlit as st
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# ─── Page Config ──────────────────────────────────────────────────────────────
# Kept as the very first Streamlit call: older Streamlit releases raise if
# set_page_config() runs after any other st.* command.
st.set_page_config(
    page_title="RegionalCap · ASR",
    page_icon="🎙️",
    layout="centered",
    initial_sidebar_state="collapsed",
)
# ─── CSS ──────────────────────────────────────────────────────────────────────
# Global dark theme plus the helper classes used by the UI below
# (.hero, .sec-label, .divider, .badge + .bg/.by/.br/.bb, .result-*).
# It is injected as a raw <style> tag, which is why unsafe_allow_html=True.
# NOTE(review): `header { visibility: hidden; }` may also hide the control that
# re-opens the collapsed sidebar in some Streamlit versions; confirm the
# sidebar settings are still reachable.
_APP_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&family=IBM+Plex+Sans:wght@300;400;600&display=swap');
html, body, [class*="css"] {
font-family: 'IBM Plex Sans', sans-serif;
background-color: #0d0d0d;
color: #e8e8e8;
}
#MainMenu, footer, header { visibility: hidden; }
.block-container { padding: 2.4rem 2rem 4rem; max-width: 780px; }
.hero {
border-left: 3px solid #00e5a0;
padding: 0.5rem 0 0.5rem 1.2rem;
margin-bottom: 2rem;
}
.hero h1 {
font-family: 'IBM Plex Mono', monospace;
font-size: 1.65rem; font-weight: 600;
color: #fff; margin: 0; letter-spacing: -0.5px;
}
.hero p { font-size: 0.8rem; color: #777; margin: 0.25rem 0 0; }
.sec-label {
font-family: 'IBM Plex Mono', monospace;
font-size: 0.65rem; color: #00e5a0;
letter-spacing: 1.8px; text-transform: uppercase;
margin-bottom: 0.55rem;
}
.divider { border: none; border-top: 1px solid #1e1e1e; margin: 1.4rem 0; }
.badge {
display: inline-block;
font-family: 'IBM Plex Mono', monospace;
font-size: 0.68rem; padding: 0.15rem 0.5rem;
border-radius: 4px; margin-right: 0.35rem;
}
.bg { background:#003d2a; color:#00e5a0; border:1px solid #00e5a0; }
.by { background:#2e2700; color:#ffd84d; border:1px solid #ffd84d; }
.br { background:#2e0000; color:#ff6b6b; border:1px solid #ff6b6b; }
.bb { background:#001533; color:#4da6ff; border:1px solid #4da6ff; }
.result-wrap {
background: #0a160f;
border: 1px solid #00e5a0;
border-radius: 9px;
padding: 1.1rem 1.3rem;
margin-top: 0.8rem;
}
.result-text {
font-size: 1.05rem; line-height: 1.85;
color: #cff0e0; word-break: break-word;
}
.result-meta {
font-family: 'IBM Plex Mono', monospace;
font-size: 0.65rem; color: #456;
margin-top: 0.75rem; padding-top: 0.6rem;
border-top: 1px solid #1a2e22;
}
.stButton > button {
background: #00e5a0 !important; color: #000 !important;
font-family: 'IBM Plex Mono', monospace !important;
font-weight: 600 !important; font-size: 0.8rem !important;
letter-spacing: 0.8px !important; border: none !important;
border-radius: 6px !important; padding: 0.5rem 1.6rem !important;
}
.stButton > button:hover { opacity: 0.82 !important; }
div[data-testid="stFileUploader"] section {
background: #111 !important;
border: 1.5px dashed #2a2a2a !important;
border-radius: 8px !important;
}
</style>
"""
st.markdown(_APP_CSS, unsafe_allow_html=True)
# ─── Constants ────────────────────────────────────────────────────────────────
# Hub repo holding the fine-tuned weights; each checkpoint lives in its own
# subfolder (see `subfolder=` in load_model).
REPO_ID = "rocky250/RegionalCap"
# Processor (feature extractor + tokenizer) is loaded from this repo, not REPO_ID.
PROCESSOR_ID = "openai/whisper-small"
# Every upload is resampled to this rate (Hz) before feature extraction.
SAMPLE_RATE = 16000
# checkpoint-1000, checkpoint-2000, …, checkpoint-10000
CHECKPOINTS = [f"checkpoint-{step * 1000}" for step in range(1, 11)]
# The last (highest-step) checkpoint is preselected in the sidebar.
DEFAULT_CKPT = CHECKPOINTS[-1]
# File extensions the uploader accepts.
AUDIO_FMTS = ["wav", "mp3", "flac", "ogg", "m4a"]
# ─── Model loader ─────────────────────────────────────────────────────────────
# show_spinner=False: the caller already wraps this call in its own st.spinner.
@st.cache_resource(show_spinner=False)
def load_model(checkpoint: str) -> Tuple:
    """Load the Whisper processor and one fine-tuned checkpoint, cached per checkpoint.

    ``st.cache_resource`` memoises the result per distinct ``checkpoint`` value
    and shares it across reruns and sessions. Streamlit re-executes the whole
    script on every widget interaction, so without the cache the weights were
    reloaded on each click.

    Args:
        checkpoint: Subfolder of ``REPO_ID`` to load, e.g. ``"checkpoint-10000"``.

    Returns:
        ``(processor, model, device)``. ``device`` is ``"cuda"`` when a GPU is
        available, else ``"cpu"``. The model is float16 on CUDA and float32 on
        CPU, and is left in eval mode.

    Exceptions from ``from_pretrained`` propagate to the caller, which reports
    them in the UI.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    processor = WhisperProcessor.from_pretrained(PROCESSOR_ID)
    # NOTE(review): low_cpu_mem_usage=True requires the `accelerate` package on
    # several transformers versions; confirm it is listed in requirements.
    model = WhisperForConditionalGeneration.from_pretrained(
        REPO_ID,
        subfolder=checkpoint,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)

    # Drop any forced_decoder_ids baked into the checkpoint so the language/task
    # passed to generate() at call time are the ones that take effect.
    model.generation_config.forced_decoder_ids = None
    model.config.forced_decoder_ids = None
    # Cleared to avoid a duplicate suppress-tokens logits processor.
    # NOTE(review): this also turns off the checkpoint's token-suppression list
    # during generation; confirm that is intended.
    model.generation_config.suppress_tokens = []

    model.eval()
    return processor, model, device
# ─── Transcription ────────────────────────────────────────────────────────────
# Whisper's encoder sees at most this many seconds per input; the feature
# extractor pads/truncates every input to this length.
_WHISPER_WINDOW_SECONDS = 30.0


def _decode_audio(audio_bytes: bytes) -> np.ndarray:
    """Decode uploaded file bytes to a mono waveform resampled to SAMPLE_RATE."""
    try:
        # Fast path: decode straight from memory. This only works for formats
        # soundfile understands (MP3 support depends on the libsndfile version).
        audio, _ = librosa.load(io.BytesIO(audio_bytes), sr=SAMPLE_RATE, mono=True)
        return audio
    except RuntimeError:
        # soundfile could not decode the container (e.g. M4A). librosa's
        # audioread fallback only accepts a filesystem path, not a file-like
        # object, so spill the bytes to a temporary file and retry below.
        pass

    fd, path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(audio_bytes)
        audio, _ = librosa.load(path, sr=SAMPLE_RATE, mono=True)
    finally:
        with contextlib.suppress(OSError):
            os.remove(path)
    return audio


def _split_into_windows(audio: np.ndarray, window: int) -> list:
    """Split *audio* into consecutive slices of at most *window* samples."""
    if audio.size == 0:
        # Keep a single (empty) item so the model still runs once, as before.
        return [audio]
    return [audio[i:i + window] for i in range(0, audio.size, window)]


def run_transcription(
    audio_bytes: bytes,
    processor,
    model,
    device: str,
    language: str,
    task: str,
    chunk_seconds: float = _WHISPER_WINDOW_SECONDS,
    batch_size: int = 8,
) -> Tuple[str, float, float]:
    """Transcribe (or translate) an uploaded audio file with Whisper.

    A single Whisper pass only covers the first 30 s of audio: the feature
    extractor silently truncates the rest. The waveform is therefore cut into
    consecutive ``chunk_seconds`` windows (capped at 30 s). The windows are
    decoded ``batch_size`` at a time and the per-window texts are joined with
    single spaces. Fixed windows can split a word at a boundary.

    Language and task are passed to ``generate()`` as ``language=``/``task=``
    rather than via ``forced_decoder_ids``.

    Args:
        audio_bytes: Raw bytes of the uploaded audio file.
        processor: Whisper processor returned by ``load_model``.
        model: Whisper model returned by ``load_model``.
        device: ``"cuda"`` or ``"cpu"``; the device the model lives on.
        language: Language code forwarded to ``model.generate``.
        task: ``"transcribe"`` or ``"translate"``, forwarded to ``model.generate``.
        chunk_seconds: Window length in seconds (values above 30 are capped).
        batch_size: Number of windows decoded per ``generate`` call.

    Returns:
        ``(text, duration, elapsed)``: the stripped transcript, the audio
        length in seconds, and the total seconds spent inside ``model.generate``.

    Raises:
        ValueError: If ``chunk_seconds`` or ``batch_size`` is not positive.
    """
    if chunk_seconds <= 0 or batch_size <= 0:
        raise ValueError("chunk_seconds and batch_size must be positive")

    audio_np = _decode_audio(audio_bytes)
    duration = len(audio_np) / SAMPLE_RATE

    window = max(1, int(min(chunk_seconds, _WHISPER_WINDOW_SECONDS) * SAMPLE_RATE))
    windows = _split_into_windows(audio_np, window)

    texts = []
    elapsed = 0.0
    for start in range(0, len(windows), batch_size):
        batch = windows[start:start + batch_size]
        inputs = processor(
            batch,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # Match the model's weights (float16 on CUDA, float32 on CPU).
        input_features = inputs.input_features.to(device, dtype=model.dtype)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Only model inference is timed, so RTF excludes decoding/feature time.
        t0 = time.perf_counter()
        with torch.no_grad():
            predicted_ids = model.generate(
                input_features,
                attention_mask=attention_mask,
                language=language,
                task=task,
            )
        elapsed += time.perf_counter() - t0

        decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        texts.extend(piece.strip() for piece in decoded)

    return " ".join(piece for piece in texts if piece), duration, elapsed
# ──────────────────────────────────────────────────────────────────────────────
# UI
# ──────────────────────────────────────────────────────────────────────────────
# ─── Hero ─────────────────────────────────────────────────────────────────────
# Title banner; the repo name comes from REPO_ID so it cannot drift from the
# repo actually loaded.
st.markdown(f"""
<div class="hero">
<h1>🎙️ RegionalCap</h1>
<p>Bengali Dialect ASR · {REPO_ID} · Whisper fine-tune</p>
</div>
""", unsafe_allow_html=True)
# ─── Sidebar ──────────────────────────────────────────────────────────────────
# Inference settings; these module-level names are read by the sections below.
with st.sidebar:
    st.markdown("### ⚙️ Config")
    checkpoint = st.selectbox(
        "Checkpoint",
        CHECKPOINTS,
        index=CHECKPOINTS.index(DEFAULT_CKPT),
    )
    # Forwarded verbatim to model.generate(language=...).
    language = st.selectbox("Language", ["bn", "en"], index=0)
    # Forwarded verbatim to model.generate(task=...).
    task = st.selectbox("Task", ["transcribe", "translate"], index=0)
    st.markdown("---")
    st.caption(f"**Repo** `{REPO_ID}`")
    st.caption(f"**Processor** `{PROCESSOR_ID}`")
# ─── Load model ───────────────────────────────────────────────────────────────
# Placeholder for the status badges, filled once loading succeeds or fails.
_status = st.empty()
with st.spinner("Loading model… (first run downloads weights, subsequent runs are instant)"):
    try:
        processor, model, device = load_model(checkpoint)
    except Exception as exc:  # top-level UI boundary: report and halt this run
        _status.markdown(
            '<span class="badge br">✗ LOAD FAILED</span>',
            unsafe_allow_html=True,
        )
        st.error(str(exc))
        st.stop()

_status.markdown(
    '<span class="badge bg">✓ READY</span>'
    f'<span class="badge bb">{checkpoint}</span>'
    f'<span class="badge by">{device.upper()}</span>',
    unsafe_allow_html=True,
)
st.markdown('<hr class="divider">', unsafe_allow_html=True)
# ─── Audio upload ─────────────────────────────────────────────────────────────
st.markdown('<div class="sec-label">01 · Upload Audio</div>', unsafe_allow_html=True)
uploaded = st.file_uploader(
    "audio",
    type=AUDIO_FMTS,
    label_visibility="collapsed",
)
if uploaded is None:
    # Build the hint from AUDIO_FMTS so it always matches what the uploader accepts.
    supported = " · ".join(fmt.upper() for fmt in AUDIO_FMTS)
    st.markdown(
        '<p style="color:#444;font-size:0.78rem;margin:0.4rem 0 0;">'
        f'Supported: {supported}</p>',
        unsafe_allow_html=True,
    )
    # Nothing to transcribe yet: end this script run here.
    st.stop()
# ─── Preview ──────────────────────────────────────────────────────────────────
# getvalue() returns the whole upload regardless of the stream position,
# whereas read() returns only what is left after any earlier read.
audio_bytes = uploaded.getvalue()
ext = uploaded.name.rsplit(".", 1)[-1].lower()
# Conventional MIME types; "audio/mp3" and "audio/m4a" are not the standard names.
_AUDIO_MIME = {
    "wav": "audio/wav",
    "mp3": "audio/mpeg",
    "flac": "audio/flac",
    "ogg": "audio/ogg",
    "m4a": "audio/mp4",
}
st.audio(audio_bytes, format=_AUDIO_MIME.get(ext, f"audio/{ext}"))
# The filename is user-controlled and rendered as raw HTML, so escape it.
st.markdown(
    f'<span class="badge bb">{html.escape(uploaded.name)}</span>'
    f'<span class="badge by">{len(audio_bytes) / 1024:.1f} kB</span>',
    unsafe_allow_html=True,
)
st.markdown('<hr class="divider">', unsafe_allow_html=True)
# ─── Transcribe ───────────────────────────────────────────────────────────────
st.markdown('<div class="sec-label">02 · Transcribe</div>', unsafe_allow_html=True)
# NOTE(review): clicking the download button starts a new rerun in which
# st.button() returns False, so the result panel disappears afterwards.
# Persist the result in st.session_state if it should stay visible.
if st.button("▶ Run Transcription"):
    result_slot = st.empty()
    result_slot.info("Processing audio…")
    try:
        text, duration, elapsed = run_transcription(
            audio_bytes, processor, model, device, language, task
        )
    except Exception as exc:  # top-level UI boundary: show the failure, keep the app alive
        result_slot.empty()
        st.markdown('<span class="badge br">✗ ERROR</span>', unsafe_allow_html=True)
        st.error(str(exc))
    else:
        result_slot.empty()  # clear the "processing" message
        # Real-time factor: inference seconds per second of audio.
        rtf = elapsed / duration if duration > 0 else 0.0
        # The transcript is model output derived from user audio; escape it
        # before embedding it in raw HTML.
        st.markdown(
            '<div class="result-wrap">'
            f'<div class="result-text">{html.escape(text)}</div>'
            '<div class="result-meta">'
            f'audio {duration:.1f}s · '
            f'inference {elapsed:.2f}s · '
            f'RTF {rtf:.3f} · '
            f'{device}</div>'
            '</div>',
            unsafe_allow_html=True,
        )
        stem = uploaded.name.rsplit(".", 1)[0]
        st.download_button(
            label="⬇ Download .txt",
            data=text,
            file_name=f"{stem}_transcription.txt",
            mime="text/plain",
        )