Spaces:
Build error
Build error
| # App that jointly summarizes a video/audio input and an uploaded PDF file. | |
| import gradio as gr | |
| from transformers import pipeline | |
| import torch | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
| import torchaudio | |
# Pick the compute device: CUDA when it is available, otherwise the CPU.
# (Apple MPS selection is currently disabled; only CUDA and CPU are considered.)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Shared Whisper speech-recognition pipeline, created once at import time.
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-base",
    device=device,
)
| # for filler | |
| # load model and processor | |
def transcribe_with_timestamps(audio):
    """Transcribe an audio file and return its text plus word-level timestamps.

    Args:
        audio: Path of the audio file to transcribe. May be ``None`` or empty
            when the caller has no audio to offer (e.g. nothing was uploaded).

    Returns:
        A ``(text, chunks)`` tuple taken from the ASR pipeline result's
        ``"text"`` and ``"chunks"`` entries. When ``audio`` is empty the
        result is ``("", [])`` and the pipeline is not invoked.
    """
    # Nothing to transcribe: return empty outputs instead of letting the
    # pipeline fail on a missing input.
    if not audio:
        return "", []

    # Ask the pipeline for word-level timestamps alongside the text.
    result = asr_pipeline(audio, return_timestamps="word")
    return result["text"], result["chunks"]
_WHISPER_MODEL_ID = "openai/whisper-base"
_WHISPER_SAMPLE_RATE = 16000
_whisper_cache = {}


def _load_whisper():
    """Return the cached ``(processor, model)`` pair, loading it on first use."""
    if "pair" not in _whisper_cache:
        processor = WhisperProcessor.from_pretrained(_WHISPER_MODEL_ID)
        model = WhisperForConditionalGeneration.from_pretrained(_WHISPER_MODEL_ID)
        # The model must live on the same device as the features fed to it.
        _whisper_cache["pair"] = (processor, model.to(device))
    return _whisper_cache["pair"]


def filler_transcribe_with_timestamps(audio, filler=False):
    """Transcribe an audio file with Whisper, optionally without normalisation.

    The audio is loaded with torchaudio, mixed down to a single channel and
    resampled to 16 kHz before being passed to the Whisper processor.

    Args:
        audio: Path of the audio file to load.
        filler: When true, generate with ``return_timestamps=True`` and decode
            with ``normalize=False``; otherwise decode with ``normalize=True``.
            (Presumably the un-normalised output is meant to keep filler
            words -- confirm against the intended use.)

    Returns:
        The list of decoded strings produced by ``processor.batch_decode``.
    """
    processor, model = _load_whisper()

    waveform, sample_rate = torchaudio.load(audio)

    # Average the channels so multi-channel input becomes mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sample_rate != _WHISPER_SAMPLE_RATE:
        resample = torchaudio.transforms.Resample(sample_rate, _WHISPER_SAMPLE_RATE)
        waveform = resample(waveform)
        sample_rate = _WHISPER_SAMPLE_RATE

    # Feature extraction runs on the CPU waveform; only the resulting features
    # are moved to the model's device. squeeze(0) drops just the channel axis.
    input_features = processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=sample_rate,
        return_tensors="pt",
    ).input_features.to(device)

    if filler:
        predicted_ids = model.generate(input_features, return_timestamps=True)
        # Decode token ids to text without normalisation.
        return processor.batch_decode(
            predicted_ids, skip_special_tokens=True, normalize=False
        )

    predicted_ids = model.generate(input_features)
    # Decode token ids to text with normalisation.
    return processor.batch_decode(
        predicted_ids, skip_special_tokens=True, normalize=True
    )
| # print(transcription) | |
| # Use the pipeline to transcribe the audio with timestamps | |
| # return result["text"], result["chunks"] | |
| # # Set up Gradio interface | |
| # interface = gr.Interface( | |
| # fn=transcribe_with_timestamps, | |
| # inputs=gr.Audio(label="Upload audio", type="filepath"), | |
| # outputs=[gr.Textbox(label="Transcription"), gr.JSON(label="Timestamps")], | |
| # title="Academic presentation Agent", | |
| # ) | |
# Markdown shown at the top of the page.
Instructions = """
# Academic Presentation Agent
Upload a video/audio file to transcribe the audio with timestamps.
Also upload the pdf file to summarize the text. (Optional)
The model will return the transcription and timestamps of the audio.
"""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(Instructions)

    # Inputs: the audio to transcribe and a PDF upload.
    with gr.Column():
        with gr.Row():
            input_audio = gr.Audio(label="Upload audio", type="filepath")
            # Dummy PDF input: not connected to any handler yet.
            input_pdf = gr.File(label="Upload PDF", type="filepath")

    # Outputs and actions.
    with gr.Column():
        with gr.Row():
            transcription = gr.Textbox(label="Transcription")
        with gr.Row():
            # Label the accordion so its collapsed header is not blank.
            with gr.Accordion(label="Timestamps", open=False):
                timestamps = gr.JSON(label="Timestamps")
        with gr.Row():
            transcribe_button = gr.Button("Transcribe")
            # The ASR handler fills both the text box and the timestamp JSON.
            transcribe_button.click(
                transcribe_with_timestamps,
                input_audio,
                outputs=[transcription, timestamps],
            )
        with gr.Row():
            # No click handler is attached to this button yet.
            analyze_button = gr.Button("Analyze")

    # NOTE: a second column driving filler_transcribe_with_timestamps
    # ("Transcribe_filler" button) was previously sketched here but is not
    # wired into the UI.

# Launch the Gradio app
demo.launch(share=False)