import io
import re
import wave

import gradio as gr
import numpy as np

from .fish_e2e import FishE2EAgent, FishE2EEventType
from .schema import ServeMessage, ServeTextPart, ServeVQPart


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
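    """Build a minimal RIFF/WAVE header for the given format. The wave writer
    finalizes the header on close; since no frames are written, the buffer
    holds only the header bytes."""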
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes


class ChatState:
    def __init__(self):
        self.conversation = []
        self.added_systext = False
        self.added_sysaudio = False

    def get_history(self):
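        """Render the conversation for the chatbot: merge each transcribed
        question back into the preceding user turn and strip the
        'Question: ... Response: ...' scaffold from the assistant text."""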
        results = []
        for msg in self.conversation:
            results.append({"role": msg.role, "content": self.repr_message(msg)})

        # The system prompt asks the model to echo the user's speech as
        # "Question: ...\n\nResponse: ..."; undo that scaffold for display.
        for i, msg in enumerate(results):
            if msg["role"] == "assistant":
                match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
                if match and i > 0 and results[i - 1]["role"] == "user":
                    results[i - 1]["content"] += "\n" + match.group(1)
                    msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
        return results

    def repr_message(self, msg: ServeMessage):
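        """Flatten a message to display text; VQ audio parts render as an
        "<audio Ns>" placeholder (the codes appear to run at roughly 21
        frames per second, so length / 21 approximates the duration)."""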
        response = ""
        for part in msg.parts:
            if isinstance(part, ServeTextPart):
                response += part.text
            elif isinstance(part, ServeVQPart):
                response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
        return response


def clear_fn():
    # Reset chat history and state, and clear the audio and text widgets.
    return [], ChatState(), None, None, None


async def process_audio_input(
    sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
):
    if audio_input is None and not text_input:
        raise gr.Error("No input provided")

    agent = FishE2EAgent()

    # Unpack the (sample_rate, ndarray) tuple from gr.Audio; fall back to a
    # text-only turn at a nominal 44.1 kHz when no audio was recorded.
    if isinstance(audio_input, tuple):
        sr, audio_data = audio_input
    elif text_input:
        sr = 44100
        audio_data = None
    else:
        raise gr.Error("Invalid audio format")

    # Note: if reference audio is provided, its sample rate replaces sr below.
    if isinstance(sys_audio_input, tuple):
        sr, sys_audio_data = sys_audio_input
    else:
        sr = 44100
        sys_audio_data = None

    def append_to_chat_ctx(
        part: ServeTextPart | ServeVQPart, role: str = "assistant"
    ) -> None:
        # Start a new message when the speaker changes; otherwise append the
        # part to the current message.
        if not state.conversation or state.conversation[-1].role != role:
            state.conversation.append(ServeMessage(role=role, parts=[part]))
        else:
            state.conversation[-1].parts.append(part)

    if state.added_systext is False and sys_text_input:
        state.added_systext = True
        append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
    if text_input:
        append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
        audio_data = None  # typed text overrides any recorded audio

    result_audio = b""
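    # Event semantics, as handled below:
    #   USER_CODES     - VQ codes for the user's audio; logged as a user turn
    #   SPEECH_SEGMENT - a chunk of generated audio, streamed to the player
    #   TEXT_SEGMENT   - a chunk of generated text, shown in the chat history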
    async for event in agent.stream(
        sys_audio_data,
        audio_data,
        sr,
        1,
        chat_ctx={
            "messages": state.conversation,
            "added_sysaudio": state.added_sysaudio,
        },
    ):
        if event.type == FishE2EEventType.USER_CODES:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
        elif event.type == FishE2EEventType.SPEECH_SEGMENT:
            append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
            # Prefix each PCM chunk with a WAV header so the streaming player
            # can decode it on its own.
            yield state.get_history(), wav_chunk_header() + event.frame.data, None, None
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            append_to_chat_ctx(ServeTextPart(text=event.text))
            yield state.get_history(), None, None, None

    yield state.get_history(), None, None, None


async def process_text_input(
    sys_audio_input, sys_text_input, state: ChatState, text_input: str
):
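    """Text-only entry point: delegate to the audio pipeline with no audio."""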
    async for event in process_audio_input(
        sys_audio_input, sys_text_input, None, state, text_input
    ):
        yield event


def create_demo():
    with gr.Blocks() as demo:
        state = gr.State(ChatState())

        with gr.Row():
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=600,
                    type="messages",
                )

                notes = gr.Markdown(
                    """
# Fish Agent
1. This demo showcases Fish Audio's self-developed end-to-end language model, Fish Agent version 3B.
2. You can find the code and weights in our official repo on [GitHub](https://github.com/fishaudio/fish-speech) and [Hugging Face](https://huggingface.co/fishaudio/fish-agent-v0.1-3b); the content is released under a CC BY-NC-SA 4.0 licence.
3. This is an early alpha demo; inference speed still needs to be optimised.
# Features
1. The model integrates ASR and TTS itself, so there is no need to plug in other models, i.e., it is truly end-to-end rather than three-stage (ASR+LLM+TTS).
2. The model can use reference audio to control the speech timbre.
3. The model can generate speech with strong emotion.
"""
                )

            with gr.Column(scale=3):
                sys_audio_input = gr.Audio(
                    sources=["upload"],
                    type="numpy",
                    label="Give a timbre for your assistant",
                )
                sys_text_input = gr.Textbox(
                    label="What is your assistant's role?",
                    value="You are a voice assistant created by Fish Audio, offering end-to-end voice interaction for a seamless user experience. You are required to first transcribe the user's speech, then answer it in the following format: 'Question: [USER_SPEECH]\n\nResponse: [YOUR_RESPONSE]\n'. You are required to use the following voice in this conversation.",
                    type="text",
                )
                audio_input = gr.Audio(
                    sources=["microphone"], type="numpy", label="Speak your message"
                )

                text_input = gr.Textbox(label="Or type your message", type="text")

                output_audio = gr.Audio(
                    label="Assistant's Voice",
                    streaming=True,
                    autoplay=True,
                    interactive=False,
                )

                send_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")

        audio_input.stop_recording(
            process_audio_input,
            inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        send_button.click(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        text_input.submit(
            process_text_input,
            inputs=[sys_audio_input, sys_text_input, state, text_input],
            outputs=[chatbot, output_audio, audio_input, text_input],
            show_progress=True,
        )

        clear_button.click(
            clear_fn,
            inputs=[],
            outputs=[chatbot, state, audio_input, output_audio, text_input],
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(server_name="127.0.0.1", server_port=7860, share=True)
|