Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| import torchaudio | |
| import numpy as np | |
# ---------------------------------------------------------------------------
# Model setup — runs once at import time (downloads weights on first launch).
# ---------------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# Whisper acoustic model: base checkpoint with a Greek LoRA adapter on top.
base_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
whisper_model = PeftModel.from_pretrained(base_model, "Vardis/Whisper-Small-Greek").to(device)
processor = WhisperProcessor.from_pretrained("Vardis/Whisper-Small-Greek")
# Force Greek transcription so generate() does not auto-detect the language.
whisper_model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="el", task="transcribe"
)

# GPT-2 language model used to re-rank Whisper hypotheses by fluency:
# Greek GPT-2 base plus a medical-domain PEFT adapter.
base_model = AutoModelForCausalLM.from_pretrained("lighteternal/gpt2-finetuned-greek")
lm_tokenizer = AutoTokenizer.from_pretrained("Vardis/Medical_Speech_Greek_GPT2")
lm_model = PeftModel.from_pretrained(base_model, "Vardis/Medical_Speech_Greek_GPT2").to(device)
def calculate_perplexity(sentence, model, tokenizer, device):
    """Score *sentence* with a causal LM; lower is more fluent.

    Despite the name, this returns the mean cross-entropy loss, not
    exp(loss) — since exp() is monotonic, the induced ranking is the same.

    Args:
        sentence: Text to score.
        model: Causal LM that exposes `.loss` when called with `labels`.
        tokenizer: Matching tokenizer (HF-style callable).
        device: Torch device string, e.g. "cuda" or "cpu".

    Returns:
        float: mean token-level cross-entropy loss of the sentence.
    """
    model.eval()
    inputs = tokenizer(
        sentence, return_tensors="pt", padding=True, truncation=True, max_length=128
    ).to(device)
    labels = inputs.input_ids.clone()
    # Mask padding out of the loss. GPT-2 tokenizers often have no pad token
    # (pad_token_id is None), in which case `labels == None` would not build a
    # valid mask — guard explicitly. With a single sentence, padding=True adds
    # no padding anyway, so this is purely defensive.
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    with torch.no_grad():
        loss = model(**inputs, labels=labels).loss.item()
    return loss
def get_whisper_transcriptions(audio_array, sr, n_best=5):
    """Decode *audio_array* with Whisper beam search; return n-best texts.

    Relies on the module-level `processor`, `whisper_model`, and `device`.

    Args:
        audio_array: 1-D float waveform.
        sr: Sampling rate of the waveform in Hz.
        n_best: Number of beams and returned hypotheses.

    Returns:
        list[str]: decoded transcriptions, beam-scored best first.
    """
    features = processor(audio_array, sampling_rate=sr, return_tensors="pt").input_features
    generated = whisper_model.generate(
        features.to(device),
        max_length=225,
        num_beams=n_best,
        num_return_sequences=n_best,
        return_dict_in_generate=True,
    )
    return processor.batch_decode(generated.sequences, skip_special_tokens=True)
def rerank_hypotheses(hypotheses, model, tokenizer, device):
    """Return the hypothesis the language model scores as most fluent.

    Each candidate is scored with `calculate_perplexity`; the one with the
    lowest loss wins. Ties go to the earliest candidate, i.e. the one
    Whisper's beam search already preferred.
    """
    return min(
        hypotheses,
        key=lambda hypothesis: calculate_perplexity(hypothesis, model, tokenizer, device),
    )
def transcribe_and_rerank(audio):
    """Transcribe a (sample_rate, waveform) pair and return the best hypothesis.

    Pipeline: normalize dtype -> resample to 16 kHz -> Whisper n-best beam
    search -> GPT-2 perplexity re-ranking.

    Args:
        audio: tuple (sr, audio_array) where audio_array is a 1-D numpy
            waveform (integer or float PCM).

    Returns:
        str: the re-ranked best transcription.
    """
    sr, audio_array = audio
    target_sr = 16000

    # BUG FIX: the original checked `audio_array.dtype != "float32"` only
    # AFTER casting the tensor to float32, so the / 32768.0 scaling branch
    # was dead code and int16 input would reach Whisper with values in
    # [-32768, 32767] instead of [-1, 1]. Normalize based on the ORIGINAL
    # numpy dtype, before building the tensor.
    if np.issubdtype(audio_array.dtype, np.integer):
        audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
    elif audio_array.dtype != np.float32:
        audio_array = audio_array.astype(np.float32)

    waveform = torch.tensor(audio_array, dtype=torch.float32)
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
        sr = target_sr

    audio_array = waveform.numpy().flatten()
    hypotheses = get_whisper_transcriptions(audio_array, sr, n_best=5)
    best = rerank_hypotheses(hypotheses, lm_model, lm_tokenizer, device)
    return best
def transcribe_and_rerank_gr(audio):
    """Gradio-facing wrapper: normalize the input and never raise.

    Converts multi-channel audio to mono and integer PCM to float32 in
    [-1, 1], then delegates to `transcribe_and_rerank`. Any failure is
    returned as an "Error: ..." string so the UI shows it instead of a
    traceback.
    """
    try:
        # Cleared microphone / empty upload arrives as None.
        if audio is None:
            return "Error: no audio received"
        sr, audio_array = audio
        # Stereo (or multi-channel) -> mono by averaging channels.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        # BUG FIX: the original divided every non-float32 dtype by 32768,
        # which is the wrong scale for int32 PCM and destructively attenuates
        # float64 audio that is already in [-1, 1]. Scale integer PCM by its
        # own dtype's max; merely cast other float dtypes.
        if np.issubdtype(audio_array.dtype, np.integer):
            audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
        elif audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)
        return transcribe_and_rerank((sr, audio_array))
    except Exception as e:
        # Deliberate top-level boundary: surface the error in the UI textbox.
        return f"Error: {e}"
# Gradio UI: one audio input (microphone or file upload, delivered as a
# numpy (sr, array) tuple) mapped to a plain-text transcript output.
demo = gr.Interface(
    fn=transcribe_and_rerank_gr,
    title="Medical Dictation (Whisper + GPT-2)",
    inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()