| import gradio as gr |
| import numpy as np |
| import torch |
| import struct |
| from transformers import ( |
| WhisperForConditionalGeneration, |
| WhisperTokenizer, |
| WhisperProcessor, |
| modeling_outputs, |
| ) |
|
|
# Load the Whisper checkpoint once at import time; `decode` reuses these
# module-level objects for every request.
_CHECKPOINT = "openai/whisper-small"

print("Loading Whisper...")
model = WhisperForConditionalGeneration.from_pretrained(_CHECKPOINT)
model.eval()  # inference mode: disables dropout
tokenizer = WhisperTokenizer.from_pretrained(_CHECKPOINT, predict_timestamps=True)
processor = WhisperProcessor.from_pretrained(_CHECKPOINT)
print("Ready!")
|
|
# Geometry of one encoder output as uploaded from the app:
# TIME_STEPS frames x HIDDEN_SIZE features, each a 4-byte big-endian float.
TIME_STEPS = 1500
HIDDEN_SIZE = 768
TOTAL_FLOATS = HIDDEN_SIZE * TIME_STEPS
TOTAL_BYTES = struct.calcsize(">f") * TOTAL_FLOATS  # ">f" is always 4 bytes
TIME_PRECISION = 0.02  # presumably seconds per encoder frame (1500 x 0.02 = 30 s)
|
|
def to_srt_time(sec):
    """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    ``None`` is treated as 0. Negative values are clamped to 0 because SRT
    cannot represent negative times.

    The value is rounded to whole milliseconds once and then split with
    integer arithmetic. Deriving milliseconds from ``sec % 1`` would truncate
    float noise (1.15 s would print as ``,149`` instead of ``,150``).
    """
    if sec is None:
        sec = 0
    total_ms = max(0, int(round(sec * 1000)))
    total_s, ms = divmod(total_ms, 1000)
    total_min, s = divmod(total_s, 60)
    h, m = divmod(total_min, 60)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
|
def _load_hidden_states(path):
    """Read the raw encoder dump at *path* as a ``(1, TIME_STEPS, HIDDEN_SIZE)`` tensor.

    The file must hold exactly ``TOTAL_FLOATS`` big-endian float32 values.

    Raises:
        ValueError: if the file is not exactly ``TOTAL_BYTES`` long. The message
            is the user-facing text shown in the UI.
    """
    with open(path, "rb") as f:
        data = f.read()
    if len(data) != TOTAL_BYTES:
        raise ValueError(
            f"Wrong size. Expected {TOTAL_BYTES} bytes, got {len(data)} bytes."
        )
    # One vectorised big-endian decode. struct.unpack would first build a
    # tuple of ~1.15M Python floats. astype() returns a native-endian,
    # writable copy that torch.from_numpy can wrap directly.
    hidden = np.frombuffer(data, dtype=">f4").astype(np.float32)
    return torch.from_numpy(hidden.reshape(1, TIME_STEPS, HIDDEN_SIZE))


def _group_words(token_ids, token_times):
    """Merge BPE text tokens into ``(word, start, end)`` tuples (times in seconds).

    Args:
        token_ids: the full generated id sequence (prompt, text and special ids).
        token_times: per-token timestamps, indexed like *token_ids*.

    Ids >= 50257 are treated as special (non-text) tokens and skipped. A new
    word begins at every piece with a leading space (U+0120 in byte-level BPE,
    or a literal space).

    A word starts at its first piece's timestamp. It ends at the timestamp of
    the token that follows its last piece in the full sequence; for the last
    word this is the end-of-text token. Only when nothing follows is 0.4 s
    added as a guess. End times are capped at the encoder window length.
    Words that decode to whitespace only are dropped.
    """
    special_start = 50257  # first non-text id in the multilingual Whisper vocab
    window_end = TIME_STEPS * TIME_PRECISION
    pieces = tokenizer.convert_ids_to_tokens(token_ids)

    words = []
    word_ids = []
    word_start = word_end = None
    for i, (tok, piece, ts) in enumerate(zip(token_ids, pieces, token_times)):
        if tok >= special_start:
            continue
        if word_ids and piece.startswith(("\u0120", " ")):
            words.append((tokenizer.decode(word_ids).strip(), word_start, word_end))
            word_ids = []
        if not word_ids:
            word_start = ts
        word_ids.append(tok)
        next_ts = token_times[i + 1] if i + 1 < len(token_times) else ts + 0.4
        word_end = min(next_ts, window_end)
    if word_ids:
        words.append((tokenizer.decode(word_ids).strip(), word_start, word_end))
    return [w for w in words if w[0]]


def _format_word_timings(words):
    """Render ``(word, start, end)`` tuples as a timestamp listing and as SRT text.

    Returns:
        ``(listing, srt)``. Each listing line is ``"<start> --> <end> <word>"``.
        The SRT text holds one numbered cue per word.
    """
    listing, srt = [], []
    for idx, (word, start, end) in enumerate(words, 1):
        span = f"{to_srt_time(start)} --> {to_srt_time(end)}"
        listing.append(f"{span} {word}")
        srt.extend([str(idx), span, word, ""])
    return "\n".join(listing), "\n".join(srt)


def decode(bin_file, language):
    """Decode an uploaded Whisper encoder dump into text and word-level timings.

    Args:
        bin_file: the uploaded ``encoder_output.bin``. This is either a filepath
            string or an object exposing the path as ``.name`` (e.g. a tempfile
            wrapper). It is ``None`` when nothing was uploaded.
        language: language name from the dropdown (e.g. ``"Hindi"``). It is
            lower-cased for ``processor.get_decoder_prompt_ids``.

    Returns:
        A ``(plain_text, word_timestamps, srt)`` triple of strings. On invalid
        input the first element carries the message and the others are empty.
    """
    if bin_file is None:
        return "Upload encoder_output.bin", "", ""

    # Plain paths pass through unchanged. File-like upload objects carry the
    # path on `.name`.
    path = getattr(bin_file, "name", bin_file)
    try:
        hidden_t = _load_hidden_states(path)
    except ValueError as err:
        return str(err), "", ""

    encoder_outputs = modeling_outputs.BaseModelOutput(last_hidden_state=hidden_t)
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language=language.lower(),
        task="transcribe",
    )

    with torch.no_grad():
        out = model.generate(
            encoder_outputs=encoder_outputs,
            forced_decoder_ids=forced_decoder_ids,
            # presumably sized so prompt + output fit Whisper's 448-position decoder
            max_new_tokens=444,
            return_dict_in_generate=True,
            output_attentions=True,
            # per-token times aligned from the cross-attention weights
            return_token_timestamps=True,
        )

    token_ids = out["sequences"][0]
    plain = tokenizer.decode(token_ids, skip_special_tokens=True).strip()

    words = _group_words(token_ids.tolist(), out["token_timestamps"][0].tolist())
    if not words:
        return plain, "No words found", ""

    word_ts, srt = _format_word_timings(words)
    return plain, word_ts, srt
|
|
|
|
# Languages offered in the dropdown; the first entry is the default.
_LANGUAGE_CHOICES = [
    "Hindi", "English", "French", "German", "Spanish",
    "Chinese", "Japanese", "Arabic", "Korean", "Russian",
]

# UI: upload + language picker -> Decode button -> three text outputs.
with gr.Blocks(title="Whisper Decoder") as demo:
    gr.Markdown("""
# Whisper Word-Level Decoder
Upload `encoder_output.bin` from your Android app.
Word timestamps computed from **cross-attention weights** — no audio needed here.
""")

    with gr.Row():
        file_input = gr.File(
            label="Upload encoder_output.bin",
            file_types=[".bin"],
        )
        language = gr.Dropdown(
            choices=_LANGUAGE_CHOICES,
            value=_LANGUAGE_CHOICES[0],
            label="Language",
        )

    btn = gr.Button("Decode", variant="primary")

    # Output order must match the (plain, timestamps, srt) triple from `decode`.
    out_plain = gr.Textbox(label="Plain Text", lines=3)
    out_ts = gr.Textbox(label="Word Timestamps", lines=12)
    out_srt = gr.Textbox(label="SRT Format", lines=12)

    btn.click(
        fn=decode,
        inputs=[file_input, language],
        outputs=[out_plain, out_ts, out_srt],
    )
|
|
# Start the server only when run as a script, so importing this module
# (e.g. by Gradio's reload mode or tests) does not launch a second server.
if __name__ == "__main__":
    demo.launch()