import math

import torch
import gradio as gr
import torchaudio
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper: base model + Greek LoRA adapter, forced to Greek transcription.
whisper_base = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
whisper_model = PeftModel.from_pretrained(whisper_base, "Vardis/Whisper-Small-Greek").to(device)
processor = WhisperProcessor.from_pretrained("Vardis/Whisper-Small-Greek")
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="el", task="transcribe")

# GPT-2: Greek base model + medical-domain adapter, used as a rescoring LM.
lm_tokenizer = AutoTokenizer.from_pretrained("Vardis/Medical_Speech_Greek_GPT2")
gpt2_base = AutoModelForCausalLM.from_pretrained("lighteternal/gpt2-finetuned-greek")
lm_model = PeftModel.from_pretrained(gpt2_base, "Vardis/Medical_Speech_Greek_GPT2").to(device)

# GPT-2 tokenizers ship without a pad token; tokenizing with padding=True
# would raise without one, and the pad-masking below needs a concrete
# pad_token_id to compare against.
if lm_tokenizer.pad_token is None:
    lm_tokenizer.pad_token = lm_tokenizer.eos_token


def calculate_perplexity(sentence, model, tokenizer, device):
    """Perplexity of `sentence` under the LM (exp of the mean token NLL)."""
    model.eval()
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
    labels = inputs.input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100  # ignore pad positions in the loss
    with torch.no_grad():
        loss = model(**inputs, labels=labels).loss.item()
    # Ranking by the raw loss would be equivalent; exp() just matches the name.
    return math.exp(loss)


def get_whisper_transcriptions(audio_array, sr, n_best=5):
    """Return the n-best beam-search hypotheses from Whisper."""
    input_features = processor(audio_array, sampling_rate=sr, return_tensors="pt").input_features.to(device)
    beam_outputs = whisper_model.generate(
        input_features,
        num_beams=n_best,
        num_return_sequences=n_best,
        return_dict_in_generate=True,
        max_length=225,
    )
    return processor.batch_decode(beam_outputs.sequences, skip_special_tokens=True)


def rerank_hypotheses(hypotheses, model, tokenizer, device):
    """Pick the hypothesis the language model finds most probable."""
    perplexities = [calculate_perplexity(hyp, model, tokenizer, device) for hyp in hypotheses]
    return hypotheses[perplexities.index(min(perplexities))]


def transcribe_and_rerank(audio):
    sr, audio_array = audio  # expects mono float32 samples in [-1, 1]
    target_sr = 16000  # Whisper was trained on 16 kHz audio
    waveform = torch.tensor(audio_array, dtype=torch.float32)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
        sr = target_sr
    # Normalization already happened in the caller; rescaling here again
    # would divide the signal by 32768 twice.
    audio_array = waveform.numpy().flatten()
    hypotheses = get_whisper_transcriptions(audio_array, sr, n_best=5)
    return rerank_hypotheses(hypotheses, lm_model, lm_tokenizer, device)


def transcribe_and_rerank_gr(audio):
    """Gradio wrapper: normalizes the (sample_rate, numpy array) input."""
    try:
        sr, audio_array = audio
        audio_array = np.asarray(audio_array)
        # int16 PCM (Gradio's default for type="numpy") -> float32 in [-1, 1].
        # Do this before the mono downmix: mean() on int16 yields float64,
        # which would defeat a dtype check done afterwards.
        if np.issubdtype(audio_array.dtype, np.integer):
            audio_array = audio_array.astype(np.float32) / 32768.0
        else:
            audio_array = audio_array.astype(np.float32)
        # stereo -> mono
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        return transcribe_and_rerank((sr, audio_array))
    except Exception as e:
        return f"Error: {e}"


demo = gr.Interface(
    fn=transcribe_and_rerank_gr,
    inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
    outputs="text",
    title="Medical Dictation (Whisper + GPT-2)",
)

if __name__ == "__main__":
    demo.launch()
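
# ---------------------------------------------------------------------------
# Optional smoke test: a sketch, not part of the app itself. It pushes a
# synthetic 1-second 440 Hz tone (int16 PCM, the same shape and dtype as
# Gradio's microphone input) through the same path the UI uses, so the
# Whisper -> GPT-2 rerank pipeline can be exercised without a browser.
# `_smoke_test` is a name chosen here for illustration; the model downloads
# above must still succeed. Call it from an interactive session or via
#   python -c "import app; app._smoke_test()"
# (assuming this file is saved as app.py; importing skips demo.launch()).
# ---------------------------------------------------------------------------
def _smoke_test():
    sr = 16000
    t = np.linspace(0.0, 1.0, sr, endpoint=False)
    tone = (0.1 * np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)
    print(transcribe_and_rerank_gr((sr, tone)))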