# Hugging Face Space: voice-chat demo (faster-whisper STT -> Llama via vLLM -> Kokoro TTS).
from faster_whisper import WhisperModel
import numpy as np
import scipy.signal
import spaces

# Speech-to-text: small English-only whisper model, run on CPU in float32.
model_size = "base.en"
model = WhisperModel(model_size, device="cpu", compute_type="float32")
def whisper_process_audio(audio_file):
    """Convert a gradio ``(sample_rate, samples)`` tuple to 16 kHz float32 mono.

    Args:
        audio_file: tuple of (sample_rate: int, audio_data: np.ndarray).
            Gradio's microphone input with type="numpy" delivers int16 PCM;
            multi-channel input has shape (n_samples, n_channels).

    Returns:
        1-D float32 numpy array in [-1, 1), resampled to 16 kHz as expected
        by faster-whisper.
    """
    sample_rate, audio_data = audio_file
    # Record integer-ness BEFORE downmixing: np.mean promotes int16 to
    # float64, which would otherwise defeat the dtype check below.
    needs_scaling = not np.issubdtype(audio_data.dtype, np.floating)
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        # Mix stereo (or multi-channel) down to mono by averaging channels.
        audio_data = np.mean(audio_data, axis=1)
    if needs_scaling:
        # int16 PCM -> normalised float in [-1, 1).
        np_audio_float32 = audio_data.astype(np.float32) / 32768.0
    else:
        # Already floating point: assume it is pre-normalised.
        np_audio_float32 = audio_data.astype(np.float32)
    if sample_rate == 16000:
        # Already at whisper's rate; skip the (lossy) FFT resample.
        return np_audio_float32
    target_len = int(len(np_audio_float32) * 16000 / sample_rate)
    return scipy.signal.resample(np_audio_float32, target_len)
def transcribe(audio):
    """Transcribe a gradio audio tuple to English text via faster-whisper.

    Args:
        audio: (sample_rate, samples) tuple from the microphone widget.

    Returns:
        The concatenated text of all recognised segments.
    """
    processed = whisper_process_audio(audio)
    segments, _info = model.transcribe(processed, beam_size=5, language='en')
    return "".join(segment.text for segment in segments)
from kokoro import KModel, KPipeline
import os
import random
import torch
import numpy as np
import kokoro
import misaki

# Text-to-speech: Kokoro acoustic model on GPU. The pipeline is created with
# model=False, so it only does text processing (G2P); synthesis is driven
# manually through kkmodel in generate_tts below.
kkmodel = KModel().to('cuda').eval()
pipeline = KPipeline(lang_code='a', model=False)
def generate_tts(text, voice='af_heart', speed=1):
    """Synthesise speech for ``text`` with the Kokoro model.

    Args:
        text: text to speak.
        voice: Kokoro voice pack name.
        speed: playback speed multiplier.

    Returns:
        Tuple of (24000, np.ndarray) — sample rate and mono audio samples;
        the array is empty when nothing could be synthesised.
    """
    pack = pipeline.load_voice(voice)
    audio_chunks = []
    for _, ps, _ in pipeline(text, voice, speed):
        # The voice pack is indexed by phoneme-sequence length to pick a
        # matching reference style vector.
        ref_s = pack[len(ps) - 1]
        try:
            audio = kkmodel(ps, ref_s, speed)
            # NOTE(review): kkmodel lives on CUDA; .numpy() requires a CPU
            # tensor — confirm KModel returns CPU output.
            audio_chunks.append(audio.numpy())
        except Exception as exc:
            # Was a bare except with a joke message; at least report what
            # failed so synthesis errors are diagnosable.
            print(f"generate_tts: failed to synthesise a chunk: {exc!r}")
    if audio_chunks:
        return 24000, np.concatenate(audio_chunks)
    # Empty text or every chunk failed: return an empty clip.
    return 24000, np.array([])
import io
import os
import time
from dataclasses import dataclass, field
from multiprocessing import freeze_support
# import groq
import gradio as gr
import numpy as np
from vllm import LLM
def initialize_model():
    """Build and return the quantized Llama vLLM pipeline.

    Called after multiprocessing has been set up properly (vLLM spawns
    worker processes).
    """
    return LLM(
        model="shuyuej/Llama-3.2-1B-Instruct-GPTQ",
        quantization="gptq",
        gpu_memory_utilization=0.5,
        max_model_len=1024,
    )
# Global variable to hold the model.
# NOTE(review): nothing in this file visibly calls initialize_model(), so
# this stays None unless a caller populates it — confirm intended.
llama3_pipe = None
default_sys_prompt = """You are a helpful chatbot. You respond very conversationally, and help the end user as best as you can."""
def llama_QA(message_history, system_prompt: str):
    """Ask the Llama pipeline to continue a conversation.

    Args:
        message_history: list of {"role", "content"} chat messages.
        system_prompt: system prompt prepended to the history.

    Returns:
        The assistant's response text.
    """
    global llama3_pipe
    # BUGFIX: nothing else initialises the pipeline, so the first call used
    # to crash with AttributeError on None. Create it lazily here.
    if llama3_pipe is None:
        llama3_pipe = initialize_model()
    # Cap generation at 512 tokens.
    sampling_params = llama3_pipe.get_default_sampling_params()
    sampling_params.max_tokens = 512
    input_message_history = [{"role": "system", "content": system_prompt}]
    input_message_history.extend(message_history)
    outputs = llama3_pipe.chat(input_message_history, sampling_params)[0].outputs[0].text
    return outputs
@dataclass
class AppState:
    """Per-session conversation state held in gr.State.

    BUGFIX: the class used field(default_factory=list) without being
    decorated with @dataclass, so ``conversation`` was a dataclasses.Field
    object at class level and ``state.conversation.append(...)`` crashed.
    """

    # Chat history as a list of {"role", "content"} message dicts.
    conversation: list = field(default_factory=list)
    # Whether the session has been stopped.
    stopped: bool = False
    # Scratch slot for model outputs (``any`` here is the builtin used
    # loosely as an annotation; kept for byte-compatibility of the API).
    model_outs: any = None
def process_audio(audio: tuple, state: AppState):
    """Streaming hook for start_recording: pass audio and state through unchanged."""
    return (audio, state)
def response(state: AppState, audio: tuple, system_prompt):
    """Handle one voice turn: transcribe, query the LLM, synthesise a reply.

    Args:
        state: per-session AppState.
        audio: (sample_rate, samples) tuple from the microphone, or None.
        system_prompt: system prompt text from the UI.

    Returns:
        (state, conversation messages for the chatbot, TTS audio or None).
    """
    if not audio:
        return state, state.conversation, None
    # Transcribe the recorded audio.
    transcription = transcribe(audio)
    # BUGFIX: an empty transcription used to be appended as an empty user
    # message and sent to the LLM; skip the turn instead.
    if not transcription:
        return state, state.conversation, None
    if transcription.startswith("Error"):
        transcription = "Error in audio transcription."
    # Append the user's message in the messages format gr.Chatbot expects.
    state.conversation.append({"role": "user", "content": transcription})
    # Generate the assistant response and record it.
    assistant_message = llama_QA(state.conversation, system_prompt)
    state.conversation.append({"role": "assistant", "content": assistant_message})
    # Speak the reply.
    response_audio = generate_tts(assistant_message)
    return state, state.conversation, response_audio
def start_recording_user(state: AppState):
    """Clear the microphone widget after a turn (returns None for input_audio)."""
    return None
# Client-side JS injected via gr.Blocks(js=...). It loads onnxruntime-web and
# the @ricky0123/vad-web bundle from CDNs, then wires voice-activity-detection
# events to gradio's hidden record/stop buttons so the user can just talk:
# speech start clicks "record" (unless the assistant audio in #streaming-out
# is still playing), speech end clicks "stop".
js = """
async function main() {
  const script1 = document.createElement("script");
  script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.14.0/dist/ort.js";
  document.head.appendChild(script1)
  const script2 = document.createElement("script");
  script2.onload = async () => {
    console.log("vad loaded") ;
    var record = document.querySelector('.record-button');
    record.textContent = "Just Start Talking!"
    record.style = "width: fit-content; padding-right: 0.5vw;"
    const myvad = await vad.MicVAD.new({
      onSpeechStart: () => {
        var record = document.querySelector('.record-button');
        var player = document.querySelector('#streaming-out audio');
        if (record != null && (player == null || player.paused || player.ended)) {
          console.log("Starting recording", record);
          record.click();
        } else {
          console.log("Audio still playing, not starting recording");
        }
      },
      onSpeechEnd: (audio) => {
        var stop = document.querySelector('.stop-button');
        if (stop != null) {
          console.log("Stopping recording", stop);
          stop.click();
        }
      }
    })
    myvad.start()
  }
  script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.7/dist/bundle.min.js";
  script1.onload = () => {
    console.log("onnx loaded")
    document.head.appendChild(script2)
  };
}
"""

# Re-applied after each turn (see the .then(js=js_reset) wiring): restores the
# record button's label and width after gradio re-renders the widget.
js_reset = """
() => {
  var record = document.querySelector('.record-button');
  record.textContent = "Just Start Talking!"
  record.style = "width: fit-content; padding-right: 0.5vw;"
}
"""
def create_demo():
    """Create and return the Gradio demo interface."""
    with gr.Blocks(js=js) as demo:
        # Editable system prompt for the LLM.
        with gr.Row():
            system_prompt = gr.Textbox(
                value=default_sys_prompt,
                interactive=True
            )
        # Microphone input; type="numpy" yields (sample_rate, samples).
        with gr.Row():
            input_audio = gr.Audio(
                label="Input Audio",
                sources=["microphone"],
                type="numpy",
                streaming=False,
            )
        with gr.Row():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
        # Assistant speech; elem_id is matched by the VAD JS so recording
        # doesn't start while this is still playing.
        with gr.Row():
            output_audio = gr.Audio(
                label="Assistant Audio",
                interactive=False,
                autoplay=True,
                elem_id="streaming-out"
            )
        state = gr.State(value=AppState())
        # Recording start (triggered by the VAD JS clicking the record
        # button): pass audio and state straight through.
        stream = input_audio.start_recording(
            process_audio,
            [input_audio, state],
            [input_audio, state],
        )
        # Recording stop: run the full STT -> LLM -> TTS turn.
        respond = input_audio.stop_recording(
            response,
            inputs=[state, input_audio, system_prompt],
            outputs=[state, chatbot, output_audio]
        )
        # After responding, clear the mic widget, then restore its button
        # label/width client-side via js_reset.
        restart = respond.then(
            start_recording_user,
            [state],
            [input_audio]
        ).then(
            lambda state: state,
            state,
            state,
            js=js_reset
        )
        # Reset the conversation and cancel any in-flight turn.
        cancel = gr.Button("New Conversation", variant="stop")
        cancel.click(
            lambda: (AppState(), gr.Audio(recording=False)),
            None,
            [state, input_audio],
            cancels=[respond, restart],
        )
    return demo
# Build the UI at import time so platforms that import this module (e.g.
# HF Spaces) can find `demo`.
demo = create_demo()

if __name__ == "__main__":
    # vLLM spawns worker processes; guard the entry point so child processes
    # importing this module don't re-launch the app, and call freeze_support()
    # (imported above) for frozen/Windows spawn semantics.
    freeze_support()
    # BUGFIX: initialize_model() was defined but never called, leaving
    # llama3_pipe as None.
    llama3_pipe = initialize_model()
    demo.launch()