# STREAM_TTS — app/main.py
# Real-time speech-to-speech server (LiquidAI LFM2.5-Audio over FastAPI WebSocket).
# Author: drrobot9 — "Update app/main.py" (commit 89f06ea, verified)
import asyncio
import json
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState
# --- Model / runtime configuration -----------------------------------------
HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"  # Hugging Face repo for model + processor
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 24000  # Hz; used for both the user's audio input and the WAV output
CHUNK_SIZE = 20  # audio tokens buffered before each incremental decode-to-PCM
# bf16 only when the GPU supports it; otherwise fall back to fp32.
DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on CUDA for speed
# --- Energy-based VAD tuning (frames classified by normalized RMS) ----------
VAD_SILENCE_THRESHOLD = 0.01  # normalized RMS at or below this counts as silence
VAD_SILENCE_FRAMES = 30  # consecutive silent frames that end an utterance
VAD_MIN_SPEECH_FRAMES = 10  # minimum speech frames for an utterance to be processed
print(f"[BOOT] Loading model on {DEVICE}...")
processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
model = LFM2AudioModel.from_pretrained(HF_REPO).to(device=DEVICE, dtype=DTYPE).eval()
print("[BOOT] Model loaded")
app = FastAPI(title="LFM2.5 Real-Time S2S", version="4.0")
# Helpers
def wav_header(sr=SAMPLE_RATE, ch=1, bits=16) -> bytes:
    """Return a 44-byte PCM WAV header for open-ended streaming.

    Both the RIFF chunk size and the data chunk size are written as
    0xFFFFFFFF ("unknown length"), so the header can be sent to the client
    before the total amount of audio is known.
    """
    frame_size = ch * bits // 8   # bytes per sample frame (block align)
    byte_rate = sr * frame_size   # bytes of audio per second
    unknown = b"\xff\xff\xff\xff"  # placeholder for an unknown chunk length

    header = b"RIFF" + unknown + b"WAVEfmt "
    header += (16).to_bytes(4, "little")        # fmt sub-chunk size
    header += (1).to_bytes(2, "little")         # audio format tag: PCM
    header += ch.to_bytes(2, "little")          # channel count
    header += sr.to_bytes(4, "little")          # sample rate
    header += byte_rate.to_bytes(4, "little")   # average byte rate
    header += frame_size.to_bytes(2, "little")  # block align
    header += bits.to_bytes(2, "little")        # bits per sample
    header += b"data" + unknown
    return header
def decode_chunk(buf: list) -> bytes | None:
    """Decode buffered audio tokens into raw little-endian int16 PCM bytes.

    All tokens except the last are stacked along dim=1 and handed to the
    processor unchanged (no offset subtraction); the decoded waveform is
    clipped to [-1, 1] and scaled to int16. Returns None on any decode error.
    """
    try:
        codes = torch.stack(buf[:-1], dim=1).unsqueeze(0).to(DEVICE)
        waveform = processor.decode(codes).squeeze().cpu().numpy()
        clipped = np.clip(waveform, -1.0, 1.0)
        pcm = (clipped * 32767).astype(np.int16)
        return pcm.tobytes()
    except Exception as e:
        print(f"[WARN] decode: {e}")
        return None
def is_speech(pcm_int16: np.ndarray) -> bool:
    """Energy-based VAD: True when the frame's normalized RMS exceeds the threshold.

    The int16 samples are promoted to float32, their RMS is normalized to
    [0, 1] by the int16 full scale (32767), and compared against
    VAD_SILENCE_THRESHOLD. An empty frame is never speech.
    """
    if pcm_int16.size == 0:
        return False
    samples = pcm_int16.astype(np.float32)
    normalized_rms = np.sqrt((samples ** 2).mean()) / 32767.0
    return normalized_rms > VAD_SILENCE_THRESHOLD
def run_generation(audio_np: np.ndarray) -> list[bytes]:
    """Synchronous generation — called via run_in_executor.

    Builds a system/user/assistant chat state, attaches the captured
    utterance as audio, and runs interleaved text+audio generation,
    decoding the audio tokens into int16 PCM chunks as they arrive.

    Parameters:
        audio_np: 1-D float waveform; the caller normalizes int16 samples
            into [-1, 1] at SAMPLE_RATE before passing it here.

    Returns:
        A list of raw 16-bit little-endian PCM byte chunks (mono).
    """
    chat = ChatState(processor)
    # System prompt fixes the assistant persona and response style.
    chat.new_turn("system")
    chat.add_text(
        "You are a helpful real-time voice assistant called chioma. "
        "Respond naturally and concisely with audio. "
        "When asked who built you, say Kelvin Jackson, an AI Engineer."
    )
    chat.end_turn()
    # User turn carries the utterance as a float32 tensor with a leading
    # batch/channel axis added via np.newaxis.
    chat.new_turn("user")
    audio_tensor = torch.from_numpy(audio_np[np.newaxis, :]).to(dtype=torch.float32)
    chat.add_audio(audio_tensor, sampling_rate=SAMPLE_RATE)
    chat.end_turn()
    chat.new_turn("assistant")
    chunks: list[bytes] = []
    buf: list = []  # pending audio tokens awaiting decode
    with torch.inference_mode():
        for token in model.generate_interleaved(
            **chat,
            max_new_tokens=2048,
            audio_temperature=0.8,
            audio_top_k=4,
        ):
            # Single-element tensors are text tokens and are discarded;
            # multi-element tensors are audio code tokens.
            if token.numel() == 1:
                continue  # text token
            buf.append(token)
            # Decode incrementally once enough tokens are buffered.
            if len(buf) >= CHUNK_SIZE:
                pcm = decode_chunk(buf)
                if pcm:
                    chunks.append(pcm)
                buf.clear()
    # flush remaining — decode_chunk stacks buf[:-1], so a single leftover
    # token cannot be decoded on its own; require at least two.
    if len(buf) > 1:
        pcm = decode_chunk(buf)
        if pcm:
            chunks.append(pcm)
    return chunks
# WebSocket
@app.websocket("/ws/s2s")
async def websocket_s2s(websocket: WebSocket):
    """Real-time speech-to-speech WebSocket endpoint.

    Runs two cooperating tasks until the client stops or disconnects:
      * receiver()         — drains incoming frames (raw int16 PCM bytes) into
                             a queue; a JSON text message {"type": "stop"} or a
                             disconnect event ends it and enqueues a None sentinel.
      * vad_and_generate() — applies energy-based VAD to queued frames; when an
                             utterance ends (>= VAD_SILENCE_FRAMES of trailing
                             silence after >= VAD_MIN_SPEECH_FRAMES of speech),
                             runs the blocking model generation in the default
                             thread executor and streams a WAV header followed
                             by PCM chunks back to the client.

    Status messages sent as JSON text: {"type": "ready"} once at start,
    {"type": "generating"} before each response, {"type": "done"} after it.
    """
    await websocket.accept()
    print("[WS] client connected")
    # Fix: asyncio.get_event_loop() inside a coroutine is deprecated since
    # Python 3.10; get_running_loop() is the correct, equivalent call here.
    loop = asyncio.get_running_loop()
    audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
    generating = False

    async def receiver():
        """Forward incoming binary frames to the queue; None sentinel on exit."""
        try:
            while True:
                try:
                    msg = await websocket.receive()
                except RuntimeError:
                    # receive() after the connection is closed raises RuntimeError.
                    break
                if msg.get("type") == "websocket.disconnect":
                    break
                if "bytes" in msg:
                    await audio_queue.put(msg["bytes"])
                elif "text" in msg:
                    if json.loads(msg["text"]).get("type") == "stop":
                        break
        finally:
            # Always unblock the consumer, even on error or disconnect.
            await audio_queue.put(None)

    async def vad_and_generate():
        """Consume PCM frames, detect end of utterance, stream the reply."""
        nonlocal generating
        speech_frames: list[np.ndarray] = []
        silence_count = 0
        speech_count = 0
        in_speech = False
        await websocket.send_text(json.dumps({"type": "ready"}))
        while True:
            frame_bytes = await audio_queue.get()
            if frame_bytes is None:  # receiver finished
                break
            frame = np.frombuffer(frame_bytes, dtype=np.int16)
            active = is_speech(frame)
            if active:
                silence_count = 0
                speech_count += 1
                in_speech = True
                speech_frames.append(frame)
            elif in_speech:
                # Trailing silence still belongs to the utterance until the
                # silence threshold is reached.
                silence_count += 1
                speech_frames.append(frame)
                if silence_count >= VAD_SILENCE_FRAMES and speech_count >= VAD_MIN_SPEECH_FRAMES:
                    if not generating:
                        generating = True
                        # Normalize int16 samples to float32 in [-1, 1].
                        utterance = np.concatenate(speech_frames).astype(np.float32) / 32767.0
                        speech_frames = []
                        silence_count = 0
                        speech_count = 0
                        in_speech = False
                        try:
                            await websocket.send_text(json.dumps({"type": "generating"}))
                            await websocket.send_bytes(wav_header())
                            # Blocking model call runs off the event loop.
                            chunks = await loop.run_in_executor(None, run_generation, utterance)
                            for chunk in chunks:
                                await websocket.send_bytes(chunk)
                            await websocket.send_text(json.dumps({"type": "done"}))
                        except Exception as e:
                            print(f"[WS] send error: {e}")
                        finally:
                            generating = False

    try:
        await asyncio.gather(receiver(), vad_and_generate())
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WS] error: {e}")
    finally:
        print("[WS] client disconnected")
@app.get("/health")
async def health():
    """Liveness probe: reports service status and the inference device."""
    payload = {"status": "ok", "device": DEVICE}
    return payload