Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from groq import Groq | |
| from transformers import pipeline | |
| from transformers.utils import is_flash_attn_2_available | |
# Decide the execution device once at import time: use the first CUDA GPU
# when one is visible and fall back to the CPU otherwise, so the app can
# still start on hardware without a GPU instead of failing at import.
_USE_CUDA = torch.cuda.is_available()

# Flash-Attention 2 is only requested when running on a GPU and the package
# is installed; otherwise PyTorch's scaled-dot-product attention is used.
_ATTN_IMPLEMENTATION = (
    "flash_attention_2" if _USE_CUDA and is_flash_attn_2_available() else "sdpa"
)

# Whisper speech-to-text pipeline shared by every request.
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    # Half precision is kept for the GPU path; full precision is used on the
    # CPU, where float16 inference is generally poorly supported.
    torch_dtype=torch.float16 if _USE_CUDA else torch.float32,
    device="cuda:0" if _USE_CUDA else "cpu",
    model_kwargs={"attn_implementation": _ATTN_IMPLEMENTATION},
)

# Groq API client; the key is read from the GROQ_API_KEY environment variable.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
def transcribe(stream, new_chunk):
    """
    Append a new audio chunk to the running stream and transcribe the whole
    stream with Whisper.

    Args:
        stream: Audio accumulated by previous calls (a 1-D array as returned
            by this function), or None on the first call.
        new_chunk: ``(sampling_rate, samples)`` tuple for the newest chunk.
            2-D samples are treated as ``(n_samples, n_channels)``.
            May be None or contain no samples, in which case nothing is
            appended and no transcription is run.

    Returns:
        ``(stream, text)``: the updated accumulated audio and the
        transcription of all of it. ``text`` is ``""`` when the chunk was
        missing or empty.
    """
    # Nothing new to process: keep the state as is and skip the model call.
    if new_chunk is None:
        return stream, ""

    sr, y = new_chunk
    y = np.asarray(y)
    if y.size == 0:
        return stream, ""

    # Down-mix any multi-channel input to mono by averaging the channels.
    if y.ndim == 2:
        y = y.mean(axis=1)
    y = y.astype(np.float32)

    # Peak-normalize the chunk. A completely silent chunk has a peak of 0;
    # dividing by it would fill the stream with NaN, so leave it untouched.
    peak = np.max(np.abs(y))
    if peak > 0:
        y /= peak

    stream = y if stream is None else np.concatenate([stream, y])
    return stream, transcriber({"sampling_rate": sr, "raw": stream})["text"]
def autocomplete(text):
    """
    Ask the Gemma chat model (served by Groq) to respond to the given text.

    An empty string is never sent to the API; in that case the function
    returns None. Otherwise the content of the first returned choice is
    returned.
    """
    if text == "":
        return None

    conversation = [
        {"role": "system", "content": "You are a friendly assistant named Gemma."},
        {"role": "user", "content": text},
    ]
    completion = groq_client.chat.completions.create(
        model='gemma-7b-it',
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
def process_audio(input_audio, new_chunk):
    """
    Gradio callback: transcribe the audio received so far, then pass the
    transcription to the chat model.

    ``input_audio`` is the audio state carried over from the previous call
    and ``new_chunk`` is the newest microphone chunk. Returns the updated
    audio state together with the model's reply.
    """
    updated_stream, heard = transcribe(input_audio, new_chunk)
    reply = autocomplete(heard)
    # Echo the transcription and the reply to stdout.
    print(heard, reply)
    return updated_stream, reply
# Live streaming UI: microphone audio goes in, the model's reply comes out.
demo = gr.Interface(
    fn=process_audio,
    # "state" carries the accumulated audio between streamed chunks; it is
    # both the first input and the first output of process_audio.
    inputs=["state", gr.Audio(sources=["microphone"], streaming=True)],
    outputs=["state", gr.Markdown()],
    title="Hey Gemma ☎️",
    # The description names the speech model that is actually loaded
    # (openai/whisper-large-v3).
    description="Powered by [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3), and [gemma-7b-it](https://huggingface.co/google/gemma-7b-it) (via [Groq](https://groq.com/))",
    live=True,
    allow_flagging="never",
)
demo.launch()