File size: 6,348 Bytes
44ae209
 
 
 
 
89f06ea
f79b1a9
44ae209
 
 
4e7f8bc
 
44ae209
4e7f8bc
44ae209
 
89f06ea
 
 
44ae209
f79b1a9
44ae209
f79b1a9
4e7f8bc
44ae209
f79b1a9
 
44ae209
f79b1a9
44ae209
f79b1a9
 
 
44ae209
f79b1a9
89f06ea
 
 
 
44ae209
 
 
f79b1a9
89f06ea
f79b1a9
89f06ea
 
f79b1a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44ae209
 
89f06ea
 
 
 
 
44ae209
 
6b6a9ba
c56c006
44ae209
 
 
f79b1a9
 
4e7f8bc
 
 
f79b1a9
4e7f8bc
 
 
 
89f06ea
f79b1a9
 
 
 
 
 
4e7f8bc
89f06ea
f79b1a9
 
 
 
44ae209
f79b1a9
259c3a6
44ae209
89f06ea
44ae209
4e7f8bc
 
44ae209
f79b1a9
4e7f8bc
f79b1a9
89f06ea
 
4e7f8bc
f79b1a9
 
 
 
 
 
 
 
 
 
 
 
89f06ea
f79b1a9
 
89f06ea
f79b1a9
 
 
89f06ea
 
 
 
f79b1a9
 
4e7f8bc
f79b1a9
 
 
 
 
89f06ea
f79b1a9
 
 
 
 
 
 
89f06ea
 
 
f79b1a9
89f06ea
 
 
 
 
 
 
 
f79b1a9
89f06ea
f79b1a9
 
89f06ea
f79b1a9
89f06ea
 
 
 
 
f79b1a9
 
 
 
4e7f8bc
f79b1a9
 
 
 
4e7f8bc
85a874c
 
f79b1a9
 
89f06ea
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import asyncio
import json
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState

# ---- Runtime configuration (module level — evaluated once at import) ----
HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLE_RATE = 24000  # Hz — output rate passed to the client and to chat.add_audio
CHUNK_SIZE = 20      # audio tokens buffered before each incremental decode

# Prefer bf16 on capable GPUs; fall back to fp32 everywhere else (incl. CPU).
DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere+ for speed

# Energy-based VAD tuning: RMS threshold (normalized to [-1, 1]) and frame counts
# for end-of-utterance / minimum-utterance detection in vad_and_generate().
VAD_SILENCE_THRESHOLD = 0.01
VAD_SILENCE_FRAMES    = 30
VAD_MIN_SPEECH_FRAMES = 10

# Model load happens at import time, before the ASGI server accepts traffic.
print(f"[BOOT] Loading model on {DEVICE}...")
processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
model     = LFM2AudioModel.from_pretrained(HF_REPO).to(device=DEVICE, dtype=DTYPE).eval()
print("[BOOT] Model loaded")

app = FastAPI(title="LFM2.5 Real-Time S2S", version="4.0")


#  Helpers

def wav_header(sr=SAMPLE_RATE, ch=1, bits=16) -> bytes:
    """Build a 44-byte PCM WAV header suitable for an open-ended stream.

    The RIFF and data chunk sizes are set to 0xFFFFFFFF ("unknown") because
    the total length is not known up front while streaming.
    """
    byte_rate   = sr * ch * bits // 8
    block_align = ch * bits // 8
    unknown_size = b"\xff\xff\xff\xff"  # placeholder: stream length not known yet
    pieces = [
        b"RIFF", unknown_size, b"WAVE",
        b"fmt ",
        (16).to_bytes(4, "little"),   # fmt chunk size
        (1).to_bytes(2, "little"),    # audio format: PCM
        ch.to_bytes(2, "little"),
        sr.to_bytes(4, "little"),
        byte_rate.to_bytes(4, "little"),
        block_align.to_bytes(2, "little"),
        bits.to_bytes(2, "little"),
        b"data", unknown_size,
    ]
    return b"".join(pieces)


def decode_chunk(buf: list) -> bytes | None:
    """Decode buffered audio tokens into 16-bit PCM bytes.

    Tokens are passed straight to the processor (no offset subtraction).
    Returns None if decoding fails for any reason.
    """
    try:
        # NOTE(review): the last buffered token is intentionally excluded
        # (`buf[:-1]`) — presumably an end/partial marker; confirm with the
        # liquid_audio token layout.
        stacked  = torch.stack(buf[:-1], dim=1).unsqueeze(0).to(DEVICE)
        waveform = processor.decode(stacked).squeeze().cpu().numpy()
        waveform = np.clip(waveform, -1.0, 1.0)
        return (waveform * 32767).astype(np.int16).tobytes()
    except Exception as e:
        # Best-effort: a failed chunk is logged and dropped, not fatal.
        print(f"[WARN] decode: {e}")
        return None


def is_speech(pcm_int16: np.ndarray) -> bool:
    """Return True when the frame's RMS energy (normalized to [-1, 1])
    exceeds the VAD silence threshold. Empty frames count as silence."""
    if len(pcm_int16) == 0:
        return False
    mean_square = np.mean(pcm_int16.astype(np.float32) ** 2)
    normalized_rms = np.sqrt(mean_square) / 32767.0
    return normalized_rms > VAD_SILENCE_THRESHOLD


def run_generation(audio_np: np.ndarray) -> list[bytes]:
    """Blocking generation pass — run off the event loop via run_in_executor.

    Builds a three-turn chat (system persona, user audio, assistant), streams
    interleaved tokens from the model, and returns the decoded PCM chunks.
    """
    chat = ChatState(processor)

    # System turn: fixed persona prompt.
    chat.new_turn("system")
    chat.add_text(
        "You are a helpful real-time voice assistant called chioma. "
        "Respond naturally and concisely with audio. "
        "When asked who built you, say Kelvin Jackson, an AI Engineer."
    )
    chat.end_turn()

    # User turn: the captured utterance as a (1, samples) float32 tensor.
    chat.new_turn("user")
    waveform = torch.from_numpy(audio_np[np.newaxis, :]).to(dtype=torch.float32)
    chat.add_audio(waveform, sampling_rate=SAMPLE_RATE)
    chat.end_turn()

    chat.new_turn("assistant")

    pcm_chunks: list[bytes] = []
    pending: list = []
    with torch.inference_mode():
        for token in model.generate_interleaved(
            **chat,
            max_new_tokens=2048,
            audio_temperature=0.8,
            audio_top_k=4,
        ):
            # Single-element tokens are text; only multi-element audio
            # tokens are buffered for decoding.
            if token.numel() == 1:
                continue
            pending.append(token)
            if len(pending) >= CHUNK_SIZE:
                pcm = decode_chunk(pending)
                if pcm:
                    pcm_chunks.append(pcm)
                pending.clear()

    # Flush the tail — decode_chunk drops one token, so two are the minimum.
    if len(pending) > 1:
        pcm = decode_chunk(pending)
        if pcm:
            pcm_chunks.append(pcm)

    return pcm_chunks


# WebSocket 

@app.websocket("/ws/s2s")
async def websocket_s2s(websocket: WebSocket):
    """Real-time speech-to-speech endpoint.

    Protocol: the client streams raw int16 PCM frames as binary messages.
    An energy-based VAD accumulates speech; once enough trailing silence is
    seen, the utterance is sent through the model and the reply is streamed
    back as a WAV header followed by PCM chunks, bracketed by JSON status
    messages ("ready" / "generating" / "done").
    """
    await websocket.accept()
    print("[WS] client connected")

    # FIX: asyncio.get_event_loop() is deprecated inside coroutines (3.10+);
    # get_running_loop() is the correct call here and never creates a loop.
    loop = asyncio.get_running_loop()
    audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
    generating = False

    async def receiver():
        """Pump incoming frames into audio_queue; None is the end sentinel."""
        try:
            while True:
                try:
                    msg = await websocket.receive()
                except RuntimeError:
                    # receive() after disconnect raises RuntimeError.
                    break
                if msg.get("type") == "websocket.disconnect":
                    break
                if "bytes" in msg:
                    await audio_queue.put(msg["bytes"])
                elif "text" in msg:
                    if json.loads(msg["text"]).get("type") == "stop":
                        break
        finally:
            await audio_queue.put(None)

    async def vad_and_generate():
        """Consume frames, detect end of utterance, and run generation."""
        nonlocal generating
        speech_frames: list[np.ndarray] = []
        silence_count = 0
        speech_count  = 0
        in_speech     = False

        await websocket.send_text(json.dumps({"type": "ready"}))

        while True:
            frame_bytes = await audio_queue.get()
            if frame_bytes is None:
                break

            frame  = np.frombuffer(frame_bytes, dtype=np.int16)
            active = is_speech(frame)

            if active:
                silence_count = 0
                speech_count += 1
                in_speech = True
                speech_frames.append(frame)
            elif in_speech:
                # Keep trailing silence in the utterance so the model hears
                # a natural ending.
                silence_count += 1
                speech_frames.append(frame)

                if silence_count >= VAD_SILENCE_FRAMES and speech_count >= VAD_MIN_SPEECH_FRAMES:
                    # NOTE(review): generation is awaited inline in this same
                    # coroutine, so `generating` can never be True here; the
                    # flag only matters if generation is ever made concurrent.
                    if not generating:
                        generating = True
                        utterance = np.concatenate(speech_frames).astype(np.float32) / 32767.0
                        speech_frames = []
                        silence_count = 0
                        speech_count  = 0
                        in_speech     = False

                        try:
                            await websocket.send_text(json.dumps({"type": "generating"}))
                            await websocket.send_bytes(wav_header())
                            # Run the blocking model pass off the event loop.
                            chunks = await loop.run_in_executor(None, run_generation, utterance)
                            for chunk in chunks:
                                await websocket.send_bytes(chunk)
                            await websocket.send_text(json.dumps({"type": "done"}))
                        except Exception as e:
                            print(f"[WS] send error: {e}")
                        finally:
                            generating = False

    try:
        await asyncio.gather(receiver(), vad_and_generate())
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WS] error: {e}")
    finally:
        print("[WS] client disconnected")




@app.get("/health")
async def health():
    """Liveness probe: reports service status and the inference device."""
    payload = {"status": "ok", "device": DEVICE}
    return payload