# HF Spaces page header (scraped residue): "Spaces" — build status: Runtime error
# Third-party dependencies for the speech-to-speech demo.
import gradio as gr
import numpy as np
import torch
from transformers import AutoProcessor, CsmForConditionalGeneration
from dia.model import Dia
from pyannote.audio import Pipeline as VAD

# ---- Model loading (runs once at import; downloads weights on first run) ----

# Speech-language model front end + backbone.
# NOTE(review): CsmForConditionalGeneration is the class for Sesame's CSM;
# loading a fixie-ai/ultravox checkpoint with it looks suspect — confirm the
# intended architecture class against the checkpoint's config.
ultra_proc = AutoProcessor.from_pretrained("fixie-ai/ultravox-v0_4")
ultra_model = CsmForConditionalGeneration.from_pretrained(
    "fixie-ai/ultravox-v0_4",
    device_map="auto",
    torch_dtype=torch.float16,
)

# Speech-emotion-recognition processor + torch.hub wav2vec2 backbone.
ser = AutoProcessor.from_pretrained(
    "r-f/wav2vec-english-speech-emotion-recognition"
)
ser_model = torch.hub.load(
    "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    "wav2vec2_large_xlsr",
    pretrained=True,
).to("cuda")

# Instrumental audio-diffusion pipeline.
diff_pipe = torch.hub.load(
    "teticio/audio-diffusion-instrumental-hiphop-256", "audio_diffusion"
).to("cuda")

# Residual vector quantizer (DAC codec) for encode/decode round trips.
rvq = torch.hub.load("ibm/DAC.speech.v1.0", "DAC_speech_v1_0").to("cuda")

# Voice-activity detection pipeline.
vad = VAD.from_pretrained("pyannote/voice-activity-detection")

# Dia text-to-speech model.
dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
def process(audio):
    """Run the full speech pipeline on one recording.

    Args:
        audio: Either a dict with ``"array"`` / ``"sampling_rate"`` keys
            (datasets style, as the original code assumed) or a
            ``(sampling_rate, np.ndarray)`` tuple — the format actually
            produced by ``gr.Audio(type="numpy")`` in the UI below.

    Returns:
        ``(sample_rate, waveform)`` tuple suitable for a ``gr.Audio`` output.
    """
    # Accept both input shapes; gr.Audio(type="numpy") delivers (sr, data),
    # so the original dict-only indexing raised a TypeError at runtime.
    if isinstance(audio, dict):
        wav, sr = audio["array"], audio["sampling_rate"]
    else:
        sr, wav = audio

    # Voice-activity detection. NOTE(review): result is currently unused
    # downstream — kept for parity with the original pipeline.
    speech = vad({"waveform": wav, "sample_rate": sr})

    # RVQ encode/decode round trip through the DAC codec (also unused below).
    codes = rvq.encode(wav)
    dec_audio = rvq.decode(codes)

    # Speech-emotion recognition: argmax over classifier logits.
    emo_inputs = ser(wav, sampling_rate=sr, return_tensors="pt")
    emotion = ser_model(**emo_inputs).logits.argmax(-1).item()

    # Ultravox speech generation.
    inputs = ultra_proc(wav, sampling_rate=sr, return_tensors="pt").to("cuda")
    speech_out = ultra_model.generate(**inputs, output_audio=True)

    # Diffusion pass over the generated speech (result unused below).
    audio_diff = diff_pipe(speech_out.audio).audios[0]

    # Text-to-speech with Dia, conditioned on the detected emotion label.
    text = f"[S1][emotion={emotion}]" + " ".join(["..."])  # placeholder
    dia_audio = dia.generate(text)

    # Peak-normalize to 0.95 full scale. Guard against all-zero output: the
    # original divided unconditionally and crashed (NaN/ZeroDivision) on
    # silent audio.
    peak = np.max(np.abs(dia_audio))
    if peak > 0:
        dia_audio = dia_audio / peak * 0.95
    return 44100, dia_audio
# ---- Gradio UI ----
with gr.Blocks() as demo:
    state = gr.State([])  # reserved for conversation history (unused so far)
    # Gradio 4.x takes `sources` (a list); `source=` is the removed 3.x kwarg
    # and raises TypeError on 4.x. The queue() call below already uses the
    # 4.x `concurrency_limit` signature, so 4.x is the target version.
    audio_in = gr.Audio(sources=["microphone"], type="numpy")
    chat = gr.Chatbot()
    record = gr.Button("Record")
    record.click(process, inputs=audio_in, outputs=[audio_in]).then(
        # The original lambda declared a parameter but the .then() passed no
        # inputs (TypeError on trigger) and declared no outputs (return value
        # discarded). Route the placeholder transcript to the chatbot.
        lambda: gr.update(value=[("User", ""), ("AI", "")]),
        inputs=None,
        outputs=[chat],
    )

demo.queue(concurrency_limit=20, max_size=50).launch()