import gradio as gr import numpy as np import torch import struct from transformers import ( WhisperForConditionalGeneration, WhisperTokenizer, WhisperProcessor, modeling_outputs, ) print("Loading Whisper...") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small") tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", predict_timestamps=True) processor = WhisperProcessor.from_pretrained("openai/whisper-small") model.eval() print("Ready!") TIME_STEPS = 1500 HIDDEN_SIZE = 768 TOTAL_FLOATS = TIME_STEPS * HIDDEN_SIZE # 1152000 TOTAL_BYTES = TOTAL_FLOATS * 4 # 4608000 TIME_PRECISION = 0.02 # each encoder frame = 0.02 seconds def to_srt_time(sec): if sec is None: sec = 0 h = int(sec // 3600) m = int((sec % 3600) // 60) s = int(sec % 60) ms = int((sec % 1) * 1000) return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}" def decode(bin_file, language): if bin_file is None: return "Upload encoder_output.bin", "", "" with open(bin_file, "rb") as f: data = f.read() if len(data) != TOTAL_BYTES: return ( f"Wrong size. Expected {TOTAL_BYTES} bytes, got {len(data)} bytes." ), "", "" # Load encoder hidden states — big-endian (Java DataOutputStream) floats = struct.unpack(f">{TOTAL_FLOATS}f", data) hidden = np.array(floats, dtype=np.float32).reshape(1, TIME_STEPS, HIDDEN_SIZE) hidden_t = torch.from_numpy(hidden) encoder_outputs = modeling_outputs.BaseModelOutput( last_hidden_state=hidden_t ) forced_decoder_ids = processor.get_decoder_prompt_ids( language=language.lower(), task="transcribe" ) with torch.no_grad(): out = model.generate( encoder_outputs=encoder_outputs, forced_decoder_ids=forced_decoder_ids, max_new_tokens=444, return_dict_in_generate=True, output_attentions=True, # needed for cross-attention timestamps return_token_timestamps=True, # whisper computes timestamps from cross-attention ) token_ids = out["sequences"][0] token_ts = out["token_timestamps"][0] # real timestamps from cross-attention weights # Plain text plain = tokenizer.decode(token_ids, skip_special_tokens=True).strip() # Group BPE tokens into words correctly SPECIAL_START = 50257 tokens_list = token_ids.tolist() ts_list = token_ts.tolist() # Collect only text tokens text_tokens = [] text_ts = [] for tok, ts in zip(tokens_list, ts_list): if tok < SPECIAL_START: text_tokens.append(tok) text_ts.append(ts) if not text_tokens: return plain, "No words found", "" # Get BPE pieces bpe_pieces = tokenizer.convert_ids_to_tokens(text_tokens) # Group into words — Whisper BPE uses Ġ (\u0120) as word boundary marker words_grouped = [] cur_tokens = [] cur_start = None cur_end = None for i, (piece, ts) in enumerate(zip(bpe_pieces, text_ts)): next_ts = text_ts[i+1] if i+1 < len(text_ts) else ts + 0.4 is_boundary = ( piece.startswith('\u0120') or # Ġ = word start in BPE piece.startswith(' ') or cur_start is None ) if is_boundary and cur_tokens: # Decode grouped tokens together — fixes broken UTF-8 for Hindi word_text = tokenizer.decode( tokenizer.convert_tokens_to_ids(cur_tokens) ).strip() if word_text: words_grouped.append((word_text, cur_start, cur_end)) cur_tokens = [] cur_start = None cur_tokens.append(piece) if cur_start is None: cur_start = ts cur_end = next_ts # Flush last word if cur_tokens: word_text = tokenizer.decode( tokenizer.convert_tokens_to_ids(cur_tokens) ).strip() if word_text: words_grouped.append((word_text, cur_start, cur_end)) # Build outputs ts_lines = [] srt_lines = [] for i, (word, start, end) in enumerate(words_grouped, 1): ts_lines.append(f"{to_srt_time(start)} --> {to_srt_time(end)} {word}") srt_lines += [str(i), f"{to_srt_time(start)} --> {to_srt_time(end)}", word, ""] return plain, "\n".join(ts_lines), "\n".join(srt_lines) with gr.Blocks(title="Whisper Decoder") as demo: gr.Markdown(""" # Whisper Word-Level Decoder Upload `encoder_output.bin` from your Android app. Word timestamps computed from **cross-attention weights** — no audio needed here. """) with gr.Row(): file_input = gr.File(label="Upload encoder_output.bin", file_types=[".bin"]) language = gr.Dropdown( choices=["Hindi","English","French","German", "Spanish","Chinese","Japanese","Arabic","Korean","Russian"], value="Hindi", label="Language" ) btn = gr.Button("Decode", variant="primary") out_plain = gr.Textbox(label="Plain Text", lines=3) out_ts = gr.Textbox(label="Word Timestamps", lines=12) out_srt = gr.Textbox(label="SRT Format", lines=12) btn.click(fn=decode, inputs=[file_input, language], outputs=[out_plain, out_ts, out_srt]) demo.launch()