Token / app.py
don0726's picture
Update app.py
2e3a146 verified
Raw
History Blame Contribute Delete
5.38 kB
import gradio as gr
import numpy as np
import torch
import struct
from transformers import (
WhisperForConditionalGeneration,
WhisperTokenizer,
WhisperProcessor,
modeling_outputs,
)
def _load_whisper(checkpoint):
    """Return ``(model, tokenizer, processor)`` for *checkpoint*, model in eval mode."""
    # nn.Module.eval() returns the module itself, so it can be chained here.
    whisper = WhisperForConditionalGeneration.from_pretrained(checkpoint).eval()
    # NOTE(review): the standalone tokenizer is built with
    # predict_timestamps=True while the processor is loaded with defaults;
    # presumably intentional — confirm.
    tok = WhisperTokenizer.from_pretrained(checkpoint, predict_timestamps=True)
    proc = WhisperProcessor.from_pretrained(checkpoint)
    return whisper, tok, proc


print("Loading Whisper...")
model, tokenizer, processor = _load_whisper("openai/whisper-small")
print("Ready!")
# Expected layout of encoder_output.bin: TIME_STEPS frames of HIDDEN_SIZE
# float32 values each (4 bytes per value).
TIME_STEPS = 1500
HIDDEN_SIZE = 768
TOTAL_FLOATS = HIDDEN_SIZE * TIME_STEPS  # 1_152_000 values
TOTAL_BYTES = 4 * TOTAL_FLOATS  # 4_608_000 bytes
# Duration of one encoder frame in seconds (1500 frames -> 30 s).
TIME_PRECISION = 0.02
def to_srt_time(sec):
    """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        sec: Time in seconds (int or float). ``None`` is treated as 0.

    Returns:
        The formatted timestamp string. Negative times are clamped to
        ``00:00:00,000`` because SRT cannot represent them.
    """
    if sec is None:
        sec = 0
    # Round to whole milliseconds once, then split with integer arithmetic.
    # Truncating ``(sec % 1) * 1000`` turns e.g. 2.3 into ",299" and lets
    # 59.9996 show as "00:00:59,999" instead of rolling over to a minute.
    total_ms = max(0, round(float(sec) * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
def _load_hidden_states(data):
    """Convert a raw encoder dump into a ``(1, TIME_STEPS, HIDDEN_SIZE)`` tensor.

    ``data`` must contain exactly ``TOTAL_FLOATS`` big-endian float32 values,
    which is the byte order Java's ``DataOutputStream`` writes.
    """
    # np.frombuffer parses the buffer in C instead of materialising a
    # 1.15M-element Python tuple via struct.unpack. astype() yields a
    # native-endian copy (torch.from_numpy rejects non-native byte order)
    # that is also writable (frombuffer's view is read-only).
    hidden = np.frombuffer(data, dtype=">f4").astype(np.float32)
    return torch.from_numpy(hidden.reshape(1, TIME_STEPS, HIDDEN_SIZE))


def _group_words(token_ids, token_ts):
    """Merge BPE text tokens into ``(word, start_sec, end_sec)`` triples.

    ``token_ids`` and ``token_ts`` are the generated sequence and its
    per-token timestamps (1-D tensors). A token ends where the next token of
    the full sequence starts; the final token falls back to ``start + 0.4``
    seconds. Each group is decoded as a whole so that characters whose UTF-8
    bytes are split across BPE pieces (e.g. Hindi) come out intact.
    """
    # Ids at or above this value are treated as special tokens and skipped.
    special_start = 50257
    ids = token_ids.tolist()
    times = token_ts.tolist()
    # Only pair up positions that have both an id and a timestamp.
    count = min(len(ids), len(times))
    text_pos = [i for i, tok in enumerate(ids[:count]) if tok < special_start]
    pieces = tokenizer.convert_ids_to_tokens([ids[i] for i in text_pos])

    groups = []  # each entry: [list of token ids, start, end]
    for pos, piece in zip(text_pos, pieces):
        start = times[pos]
        end = times[pos + 1] if pos + 1 < count else start + 0.4
        # Byte-level BPE marks a leading space with "Ġ" (U+0120), i.e. the
        # token begins a new word.
        if not groups or piece.startswith(("\u0120", " ")):
            groups.append([[ids[pos]], start, end])
        else:
            groups[-1][0].append(ids[pos])
            groups[-1][2] = end

    words = []
    for group_ids, start, end in groups:
        text = tokenizer.decode(group_ids).strip()
        if text:
            words.append((text, start, end))
    return words


def _format_outputs(words):
    """Render word triples as a timestamp listing and as SRT text."""
    ts_lines = []
    srt_lines = []
    for index, (word, start, end) in enumerate(words, 1):
        span = f"{to_srt_time(start)} --> {to_srt_time(end)}"
        ts_lines.append(f"{span} {word}")
        srt_lines += [str(index), span, word, ""]
    return "\n".join(ts_lines), "\n".join(srt_lines)


def decode(bin_file, language):
    """Decode an uploaded Whisper encoder dump into text with word timings.

    Args:
        bin_file: Value from ``gr.File``: a file path, an object exposing the
            path as ``.name`` (some Gradio versions), or ``None``.
        language: Language name from the dropdown, e.g. ``"Hindi"``.

    Returns:
        A ``(plain_text, word_timestamps, srt)`` tuple of strings. On invalid
        input, the first element holds the message and the others are empty.
    """
    if bin_file is None:
        return "Upload encoder_output.bin", "", ""
    if not language:
        return "Select a language", "", ""
    path = getattr(bin_file, "name", bin_file)
    with open(path, "rb") as f:
        data = f.read()
    if len(data) != TOTAL_BYTES:
        return (
            f"Wrong size. Expected {TOTAL_BYTES} bytes, got {len(data)} bytes."
        ), "", ""

    encoder_outputs = modeling_outputs.BaseModelOutput(
        last_hidden_state=_load_hidden_states(data)
    )
    forced_decoder_ids = processor.get_decoder_prompt_ids(
        language=language.lower(),
        task="transcribe",
    )
    with torch.no_grad():
        out = model.generate(
            encoder_outputs=encoder_outputs,
            forced_decoder_ids=forced_decoder_ids,
            # NOTE(review): presumably the decoder's position limit minus the
            # forced prompt length — confirm against the model config.
            max_new_tokens=444,
            return_dict_in_generate=True,
            output_attentions=True,  # needed for cross-attention timestamps
            return_token_timestamps=True,  # timestamps from cross-attention
        )
    token_ids = out["sequences"][0]
    token_ts = out["token_timestamps"][0]

    plain = tokenizer.decode(token_ids, skip_special_tokens=True).strip()
    words = _group_words(token_ids, token_ts)
    if not words:
        return plain, "No words found", ""
    ts_text, srt_text = _format_outputs(words)
    return plain, ts_text, srt_text
# Gradio UI: upload + language picker, one button, three text outputs.
# NOTE(review): SOURCE lost its indentation; the button and the output boxes
# are placed below the Row here — confirm this matches the intended layout.
with gr.Blocks(title="Whisper Decoder") as demo:
    gr.Markdown(
        "\n"
        "# Whisper Word-Level Decoder\n"
        "Upload `encoder_output.bin` from your Android app.\n"
        "Word timestamps computed from **cross-attention weights** — no audio needed here.\n"
    )
    with gr.Row():
        file_input = gr.File(
            label="Upload encoder_output.bin",
            file_types=[".bin"],
        )
        language = gr.Dropdown(
            label="Language",
            value="Hindi",
            choices=[
                "Hindi", "English", "French", "German", "Spanish",
                "Chinese", "Japanese", "Arabic", "Korean", "Russian",
            ],
        )
    btn = gr.Button("Decode", variant="primary")
    out_plain = gr.Textbox(label="Plain Text", lines=3)
    out_ts = gr.Textbox(label="Word Timestamps", lines=12)
    out_srt = gr.Textbox(label="SRT Format", lines=12)
    btn.click(
        fn=decode,
        inputs=[file_input, language],
        outputs=[out_plain, out_ts, out_srt],
    )

demo.launch()