| import base64
|
| import ctypes
|
| import io
|
| import json
|
| import os
|
| import struct
|
| from dataclasses import dataclass
|
| from enum import Enum
|
| from typing import AsyncGenerator, Union
|
|
|
| import httpx
|
| import numpy as np
|
| import ormsgpack
|
| import soundfile as sf
|
|
|
| from .schema import (
|
| ServeMessage,
|
| ServeRequest,
|
| ServeTextPart,
|
| ServeVQGANDecodeRequest,
|
| ServeVQGANEncodeRequest,
|
| ServeVQPart,
|
| )
|
|
|
|
|
class CustomAudioFrame:
    """Container for a chunk of interleaved signed 16-bit PCM audio.

    The constructor copies ``data`` into an internal ``bytearray`` and records
    the stream parameters alongside it. All attributes are exposed read-only
    through properties.
    """

    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
        """Validate the buffer size and store the frame.

        Args:
            data: Bytes-like object holding the samples.
            sample_rate: Samples per second, per channel.
            num_channels: Number of channels in ``data``.
            samples_per_channel: Number of samples for each channel.

        Raises:
            ValueError: If ``data`` is too short to hold
                ``num_channels * samples_per_channel`` int16 samples.
        """
        available_bytes = len(data)
        bytes_per_sample = ctypes.sizeof(ctypes.c_int16)
        required_bytes = num_channels * samples_per_channel * bytes_per_sample
        if available_bytes < required_bytes:
            raise ValueError(
                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
            )

        # Copy into a mutable buffer so the frame owns its samples.
        self._data = bytearray(data)
        self._sample_rate = sample_rate
        self._num_channels = num_channels
        self._samples_per_channel = samples_per_channel

    @property
    def data(self):
        """Samples as a ``memoryview`` of native-endian signed shorts."""
        raw_view = memoryview(self._data)
        return raw_view.cast("h")

    @property
    def sample_rate(self):
        """Sample rate given at construction."""
        return self._sample_rate

    @property
    def num_channels(self):
        """Channel count given at construction."""
        return self._num_channels

    @property
    def samples_per_channel(self):
        """Per-channel sample count given at construction."""
        return self._samples_per_channel

    @property
    def duration(self):
        """Frame length in seconds (samples per channel / sample rate)."""
        return self.samples_per_channel / self.sample_rate

    def __repr__(self):
        fields = ", ".join(
            (
                "sample_rate={}".format(self.sample_rate),
                "num_channels={}".format(self.num_channels),
                "samples_per_channel={}".format(self.samples_per_channel),
                "duration={:.3f}".format(self.duration),
            )
        )
        return "CustomAudioFrame(" + fields + ")"
|
|
|
|
|
class FishE2EEventType(Enum):
    """Kinds of events carried by ``FishE2EEvent``."""

    # Event carrying a decoded audio frame.
    SPEECH_SEGMENT = 1
    # Event carrying a piece of generated text.
    TEXT_SEGMENT = 2
    # Marks that no more text will follow.
    END_OF_TEXT = 3
    # Marks that no more speech will follow.
    END_OF_SPEECH = 4
    # Event carrying a speech-recognition transcript.
    ASR_RESULT = 5
    # Event carrying the VQ codes computed for the user's audio.
    USER_CODES = 6
|
|
|
|
|
@dataclass
class FishE2EEvent:
    """A single event produced while streaming a conversation turn.

    Only ``type`` is required; the remaining fields default to ``None`` and
    are filled in depending on the kind of event.
    """

    # Which kind of event this is.
    type: FishE2EEventType
    # Decoded audio for speech events. The producer in this file stores a
    # CustomAudioFrame here (not a NumPy array), so annotate it as such.
    frame: CustomAudioFrame | None = None
    # Text payload for text-bearing events.
    text: str | None = None
    # VQ codes associated with the event, as nested lists of ints.
    vq_codes: list[list[int]] | None = None
|
|
|
|
|
# Module-level async HTTP client. Every limit is passed as ``None``: no
# request timeout, no cap on total or keep-alive connections, and no
# keep-alive expiry.
client = httpx.AsyncClient(
    limits=httpx.Limits(
        keepalive_expiry=None,
        max_keepalive_connections=None,
        max_connections=None,
    ),
    timeout=None,
)
|
|
|
|
|
| class FishE2EAgent:
|
| def __init__(self):
|
| self.llm_url = "http://localhost:8080/v1/chat"
|
| self.vqgan_url = "http://localhost:8080"
|
| self.client = httpx.AsyncClient(timeout=None)
|
|
|
| async def get_codes(self, audio_data, sample_rate):
|
| audio_buffer = io.BytesIO()
|
| sf.write(audio_buffer, audio_data, sample_rate, format="WAV")
|
| audio_buffer.seek(0)
|
|
|
| encode_request = ServeVQGANEncodeRequest(audios=[audio_buffer.read()])
|
| encode_request_bytes = ormsgpack.packb(
|
| encode_request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
|
| )
|
| encode_response = await self.client.post(
|
| f"{self.vqgan_url}/v1/vqgan/encode",
|
| data=encode_request_bytes,
|
| headers={"Content-Type": "application/msgpack"},
|
| )
|
| encode_response_data = ormsgpack.unpackb(encode_response.content)
|
| codes = encode_response_data["tokens"][0]
|
| return codes
|
|
|
| async def stream(
|
| self,
|
| system_audio_data: np.ndarray | None,
|
| user_audio_data: np.ndarray | None,
|
| sample_rate: int,
|
| num_channels: int,
|
| chat_ctx: dict | None = None,
|
| ) -> AsyncGenerator[bytes, None]:
|
|
|
| if system_audio_data is not None:
|
| sys_codes = await self.get_codes(system_audio_data, sample_rate)
|
| else:
|
| sys_codes = None
|
| if user_audio_data is not None:
|
| user_codes = await self.get_codes(user_audio_data, sample_rate)
|
|
|
| if chat_ctx is None:
|
| sys_parts = [
|
| ServeTextPart(
|
| text='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nAnswer: [你的回答]\n"。'
|
| ),
|
| ]
|
| if system_audio_data is not None:
|
| sys_parts.append(ServeVQPart(codes=sys_codes))
|
| chat_ctx = {
|
| "messages": [
|
| ServeMessage(
|
| role="system",
|
| parts=sys_parts,
|
| ),
|
| ],
|
| }
|
| else:
|
| if chat_ctx["added_sysaudio"] is False and sys_codes:
|
| chat_ctx["added_sysaudio"] = True
|
| chat_ctx["messages"][0].parts.append(ServeVQPart(codes=sys_codes))
|
|
|
| prev_messages = chat_ctx["messages"].copy()
|
| if user_audio_data is not None:
|
| yield FishE2EEvent(
|
| type=FishE2EEventType.USER_CODES,
|
| vq_codes=user_codes,
|
| )
|
| else:
|
| user_codes = None
|
|
|
| request = ServeRequest(
|
| messages=prev_messages
|
| + (
|
| [
|
| ServeMessage(
|
| role="user",
|
| parts=[ServeVQPart(codes=user_codes)],
|
| )
|
| ]
|
| if user_codes
|
| else []
|
| ),
|
| streaming=True,
|
| num_samples=1,
|
| )
|
|
|
|
|
| buffer = b""
|
| vq_codes = []
|
| current_vq = False
|
|
|
| async def decode_send():
|
| nonlocal current_vq
|
| nonlocal vq_codes
|
|
|
| data = np.concatenate(vq_codes, axis=1).tolist()
|
|
|
| decode_request = ServeVQGANDecodeRequest(tokens=[data])
|
| decode_response = await self.client.post(
|
| f"{self.vqgan_url}/v1/vqgan/decode",
|
| data=ormsgpack.packb(
|
| decode_request,
|
| option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
|
| ),
|
| headers={"Content-Type": "application/msgpack"},
|
| )
|
| decode_data = ormsgpack.unpackb(decode_response.content)
|
|
|
|
|
| audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
|
| audio_data = (audio_data * 32768).astype(np.int16).tobytes()
|
|
|
| audio_frame = CustomAudioFrame(
|
| data=audio_data,
|
| samples_per_channel=len(audio_data) // 2,
|
| sample_rate=44100,
|
| num_channels=1,
|
| )
|
| yield FishE2EEvent(
|
| type=FishE2EEventType.SPEECH_SEGMENT,
|
| frame=audio_frame,
|
| vq_codes=data,
|
| )
|
|
|
| current_vq = False
|
| vq_codes = []
|
|
|
| async with self.client.stream(
|
| "POST",
|
| self.llm_url,
|
| data=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
| headers={"Content-Type": "application/msgpack"},
|
| ) as response:
|
|
|
| async for chunk in response.aiter_bytes():
|
| buffer += chunk
|
|
|
| while len(buffer) >= 4:
|
| read_length = struct.unpack("I", buffer[:4])[0]
|
| if len(buffer) < 4 + read_length:
|
| break
|
|
|
| body = buffer[4 : 4 + read_length]
|
| buffer = buffer[4 + read_length :]
|
| data = ormsgpack.unpackb(body)
|
|
|
| if data["delta"] and data["delta"]["part"]:
|
| if current_vq and data["delta"]["part"]["type"] == "text":
|
| async for event in decode_send():
|
| yield event
|
| if data["delta"]["part"]["type"] == "text":
|
| yield FishE2EEvent(
|
| type=FishE2EEventType.TEXT_SEGMENT,
|
| text=data["delta"]["part"]["text"],
|
| )
|
| elif data["delta"]["part"]["type"] == "vq":
|
| vq_codes.append(np.array(data["delta"]["part"]["codes"]))
|
| current_vq = True
|
|
|
| if current_vq and vq_codes:
|
| async for event in decode_send():
|
| yield event
|
|
|
| yield FishE2EEvent(type=FishE2EEventType.END_OF_TEXT)
|
| yield FishE2EEvent(type=FishE2EEventType.END_OF_SPEECH)
|
|
|
|
|
|
|
async def main():
    """Demo: send one audio file as the user's turn and print/save the reply.

    Text events are printed to stdout. Each speech segment's raw int16
    samples are appended to ``audio_segment.wav`` (no WAV header is written,
    so the file holds bare PCM despite its extension).
    """
    import torchaudio

    agent = FishE2EAgent()

    # torchaudio returns a float tensor laid out as (channels, frames), while
    # soundfile (used by the agent to serialize audio) expects
    # (frames, channels), hence the transpose.
    waveform, sample_rate = torchaudio.load("uz_story_en.m4a")
    samples = waveform.numpy().T
    # Clip before the cast so full-scale samples do not wrap around.
    audio_data = np.clip(samples * 32768, -32768, 32767).astype(np.int16)

    # The recording is the *user* audio; there is no system audio in this demo.
    stream = agent.stream(None, audio_data, sample_rate, audio_data.shape[1])

    # Start from a clean output file since segments are appended below.
    if os.path.exists("audio_segment.wav"):
        os.remove("audio_segment.wav")

    async for event in stream:
        if event.type == FishE2EEventType.SPEECH_SEGMENT:
            with open("audio_segment.wav", "ab+") as f:
                f.write(event.frame.data)
        elif event.type == FishE2EEventType.ASR_RESULT:
            print(event.text, flush=True)
        elif event.type == FishE2EEventType.TEXT_SEGMENT:
            print(event.text, flush=True, end="")
        elif event.type == FishE2EEventType.END_OF_TEXT:
            print("\nEnd of text reached.")
        elif event.type == FishE2EEventType.END_OF_SPEECH:
            print("End of speech reached.")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the demo coroutine to completion on a fresh
    # event loop. asyncio is imported here because only this path needs it.
    import asyncio

    demo_coroutine = main()
    asyncio.run(demo_coroutine)
|
|
|