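"""Gradio demo: Greek medical speech transcription with LM reranking.

A PEFT-adapted Whisper-small model generates the n-best beam-search
hypotheses for an utterance; a PEFT-adapted Greek GPT-2 then reranks
them by language-model loss and the most fluent hypothesis is returned.
"""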
import torch
import gradio as gr
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torchaudio
import numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper: openai/whisper-small base with a Greek PEFT adapter on top
base_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
whisper_model = PeftModel.from_pretrained(base_model, "Vardis/Whisper-Small-Greek").to(device)
processor = WhisperProcessor.from_pretrained("Vardis/Whisper-Small-Greek")
# Pin decoding to Greek transcription so the model never auto-detects language/task
# (newer transformers versions also accept language/task kwargs on generate())
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="el", task="transcribe")

# GPT-2: Greek GPT-2 base with a medical-domain PEFT adapter, used for reranking
lm_tokenizer = AutoTokenizer.from_pretrained("Vardis/Medical_Speech_Greek_GPT2")
lm_base_model = AutoModelForCausalLM.from_pretrained("lighteternal/gpt2-finetuned-greek")
lm_model = PeftModel.from_pretrained(lm_base_model, "Vardis/Medical_Speech_Greek_GPT2").to(device)
# GPT-2 tokenizers often ship without a pad token; fall back to EOS so the
# padding/masking in calculate_perplexity() below cannot fail
if lm_tokenizer.pad_token is None:
    lm_tokenizer.pad_token = lm_tokenizer.eos_token
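
# Note: both PeftModel wrappers apply their adapters on the fly at inference;
# if latency mattered, merge_and_unload() could fold the adapter weights into
# the base models once at startup (optional, not required for correctness).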

def calculate_perplexity(sentence, model, tokenizer, device):
    """Score a hypothesis with the LM, returning its mean token cross-entropy.

    exp(loss) would be the actual perplexity, but exp() is monotonic, so
    ranking hypotheses by loss picks the same winner as ranking by perplexity.
    """
    model.eval()
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
    labels = inputs.input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100  # ignore padding in the loss
    with torch.no_grad():
        loss = model(**inputs, labels=labels).loss.item()
    return loss

def get_whisper_transcriptions(audio_array, sr, n_best=5):
    """Beam-search decode and return the n-best transcription hypotheses."""
    input_features = processor(audio_array, sampling_rate=sr, return_tensors="pt").input_features.to(device)
    beam_outputs = whisper_model.generate(
        input_features,
        num_beams=n_best,
        num_return_sequences=n_best,  # keep every beam, not just the top one
        return_dict_in_generate=True,
        max_length=225,
    )
    n_best_transcriptions = processor.batch_decode(beam_outputs.sequences, skip_special_tokens=True)
    return n_best_transcriptions

def rerank_hypotheses(hypotheses, model, tokenizer, device):
    """Return the hypothesis with the lowest LM loss (i.e. highest fluency)."""
    perplexities = [calculate_perplexity(hyp, model, tokenizer, device) for hyp in hypotheses]
    best_index = perplexities.index(min(perplexities))
    return hypotheses[best_index]
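
# Minimal usage sketch (hypothetical hypotheses; "πιρετό" misspells "πυρετό",
# fever, the kind of homophone confusion beam search can produce in Greek):
#   hyps = ["ο ασθενής έχει πυρετό", "ο ασθενής έχει πιρετό"]
#   rerank_hypotheses(hyps, lm_model, lm_tokenizer, device)
#   # -> should return the correctly spelled variant, which scores lower loss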

def transcribe_and_rerank(audio):
    """Resample to Whisper's expected 16 kHz, transcribe, and rerank."""
    sr, audio_array = audio
    target_sr = 16000
    # audio_array is already float32 here (the Gradio wrapper converts it),
    # so no int16 rescaling is needed after resampling
    waveform = torch.tensor(audio_array, dtype=torch.float32)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
        sr = target_sr
    audio_array = waveform.numpy().flatten()
    hypotheses = get_whisper_transcriptions(audio_array, sr, n_best=5)
    best = rerank_hypotheses(hypotheses, lm_model, lm_tokenizer, device)
    return best
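
# With type="numpy", gr.Audio hands the callback a (sample_rate, np.ndarray)
# pair; recordings typically arrive as int16 PCM and may be stereo, hence the
# mono mixdown and /32768 scaling below.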
def transcribe_and_rerank_gr(audio):
    """Gradio entry point: normalize the raw recording, then transcribe."""
    try:
        sr, audio_array = audio
        # stereo -> mono
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # int16 PCM -> float32 in [-1, 1]
        if audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32) / 32768.0
        return transcribe_and_rerank((sr, audio_array))
    except Exception as e:
        return f"Error: {e}"

demo = gr.Interface(
    fn=transcribe_and_rerank_gr,
    inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
    outputs="text",
    title="Medical Dictation (Whisper + GPT-2)",
)

if __name__ == "__main__":
    demo.launch()