drrobot9 committed on
Commit
f79b1a9
·
verified ·
1 Parent(s): c56c006

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +165 -88
app/main.py CHANGED
@@ -3,11 +3,7 @@ import json
3
  import torch
4
  import numpy as np
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
- from liquid_audio import (
7
- LFM2AudioModel,
8
- LFM2AudioProcessor,
9
- ChatState,
10
- )
11
 
12
  HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -17,131 +13,212 @@ CHUNK_SIZE = 20
17
  DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
18
  torch.backends.cuda.matmul.allow_tf32 = True
19
 
20
- print(f"[BOOT] Loading model on {DEVICE}...")
 
 
 
21
 
 
22
  processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
23
- model = LFM2AudioModel.from_pretrained(HF_REPO).to(device=DEVICE, dtype=DTYPE).eval()
24
-
25
  print("[BOOT] Model loaded")
26
 
27
- app = FastAPI(title="LFM2.5 Speech-to-Speech", version="3.0")
 
28
 
 
29
 
30
- def wav_header(sr=24000, ch=1, bits=16):
31
- byte_rate = sr * ch * bits // 8
32
- block_align = ch * bits // 8
33
  return (
34
- b"RIFF"
35
- + b"\xff\xff\xff\xff"
36
- + b"WAVEfmt "
37
- + (16).to_bytes(4, "little")
38
- + (1).to_bytes(2, "little")
39
- + ch.to_bytes(2, "little")
40
- + sr.to_bytes(4, "little")
41
- + byte_rate.to_bytes(4, "little")
42
- + block_align.to_bytes(2, "little")
43
- + bits.to_bytes(2, "little")
44
- + b"data"
45
- + b"\xff\xff\xff\xff"
46
  )
47
 
48
 
49
- async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  chat = ChatState(processor)
51
-
52
  chat.new_turn("system")
53
- chat.add_text("Respond conversationally with audio.")
54
  chat.end_turn()
55
-
56
  chat.new_turn("user")
57
  audio_tensor = torch.from_numpy(audio_np[np.newaxis, :]).to(dtype=torch.float32)
58
  chat.add_audio(audio_tensor, sampling_rate=SAMPLE_RATE)
59
  chat.end_turn()
60
-
61
  chat.new_turn("assistant")
62
 
63
- await websocket.send_bytes(wav_header())
64
-
65
- audio_buffer = []
66
-
67
  with torch.inference_mode():
68
  for token in model.generate_interleaved(
69
  **chat,
70
- max_new_tokens=4096,
71
  audio_temperature=0.8,
72
  audio_top_k=4,
73
  ):
74
- # numel()==1 means text token
75
  if token.numel() == 1:
76
  continue
 
 
 
 
 
 
77
 
78
- # multi-element tensor = audio codes chunk
79
- audio_buffer.append(token)
 
 
80
 
81
- if len(audio_buffer) >= CHUNK_SIZE:
82
- audio_codes = (
83
- torch.stack(audio_buffer, dim=1)
84
- .unsqueeze(0)
85
- .to(DEVICE)
86
- )
87
- try:
88
- waveform = processor.decode(audio_codes)
89
- waveform = waveform.squeeze().cpu().numpy()
90
- waveform = np.clip(waveform, -1.0, 1.0)
91
- audio_int16 = (waveform * 32767).astype(np.int16)
92
- await websocket.send_bytes(audio_int16.tobytes())
93
- except Exception as e:
94
- print(f"[WARN] decode error: {e}")
95
- finally:
96
- audio_buffer.clear()
97
-
98
- # flush remaining
99
- if len(audio_buffer) > 1:
100
- audio_codes = (
101
- torch.stack(audio_buffer, dim=1)
102
- .unsqueeze(0)
103
- .to(DEVICE)
104
- )
105
- try:
106
- waveform = processor.decode(audio_codes)
107
- waveform = waveform.squeeze().cpu().numpy()
108
- waveform = np.clip(waveform, -1.0, 1.0)
109
- audio_int16 = (waveform * 32767).astype(np.int16)
110
- await websocket.send_bytes(audio_int16.tobytes())
111
- except Exception as e:
112
- print(f"[WARN] flush decode error: {e}")
113
 
114
- await websocket.send_text(json.dumps({"type": "done"}))
115
 
 
116
 
117
  @app.websocket("/ws/s2s")
118
  async def websocket_s2s(websocket: WebSocket):
119
  await websocket.accept()
 
120
 
121
- try:
122
- audio_bytes = bytearray()
123
 
124
- while True:
125
- message = await websocket.receive()
126
-
127
- if "text" in message:
128
- payload = json.loads(message["text"])
129
-
130
- if payload["type"] == "start":
131
- audio_bytes.clear()
132
 
133
- elif payload["type"] == "end":
134
- audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
135
- audio_np /= 32767.0
136
- await generate_response(websocket, audio_np)
137
-
138
- elif "bytes" in message:
139
- audio_bytes.extend(message["bytes"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  except WebSocketDisconnect:
 
 
 
 
142
  print("[WS] client disconnected")
143
 
144
 
145
  @app.get("/health")
146
  async def health():
147
- return {"status": "ok", "device": DEVICE}
 
 
 
 
 
 
 
 
 
3
  import torch
4
  import numpy as np
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
6
+ from liquid_audio import LFM2AudioModel, LFM2AudioProcessor, ChatState
 
 
 
 
7
 
8
  HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"
9
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
13
  DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
14
  torch.backends.cuda.matmul.allow_tf32 = True
15
 
16
+ # VAD settings
17
+ VAD_SILENCE_THRESHOLD = 0.01 # RMS below this = silencE
18
+ VAD_SILENCE_FRAMES = 30 # ~600ms of silence at 160-sample frames
19
+ VAD_MIN_SPEECH_FRAMES = 10 # ignore very short blips
20
 
21
+ print(f"[BOOT] Loading model on {DEVICE}...")
22
  processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
23
+ model = LFM2AudioModel.from_pretrained(HF_REPO).to(device=DEVICE, dtype=DTYPE).eval()
 
24
  print("[BOOT] Model loaded")
25
 
26
+ app = FastAPI(title="LFM2.5 Real-Time S2S", version="4.0")
27
+
28
 
29
+ # Helpers
30
 
31
+ def wav_header(sr=SAMPLE_RATE, ch=1, bits=16) -> bytes:
32
+ br = sr * ch * bits // 8
33
+ ba = ch * bits // 8
34
  return (
35
+ b"RIFF" + b"\xff\xff\xff\xff" + b"WAVEfmt "
36
+ + (16).to_bytes(4,"little") + (1).to_bytes(2,"little")
37
+ + ch.to_bytes(2,"little") + sr.to_bytes(4,"little")
38
+ + br.to_bytes(4,"little") + ba.to_bytes(2,"little")
39
+ + bits.to_bytes(2,"little") + b"data" + b"\xff\xff\xff\xff"
 
 
 
 
 
 
 
40
  )
41
 
42
 
43
+ def decode_chunk(buf: list) -> bytes | None:
44
+ try:
45
+ codes = torch.stack(buf, dim=1).unsqueeze(0).to(DEVICE)
46
+ codes = codes - processor.audio_token_start
47
+ if codes.min() < 0:
48
+ return None
49
+ wf = processor.decode(codes).squeeze().cpu().numpy()
50
+ wf = np.clip(wf, -1.0, 1.0)
51
+ return (wf * 32767).astype(np.int16).tobytes()
52
+ except Exception as e:
53
+ print(f"[WARN] decode: {e}")
54
+ return None
55
+
56
+
57
+ def is_speech(pcm_int16: np.ndarray) -> bool:
58
+ """Simple energy-based VAD."""
59
+ if len(pcm_int16) == 0:
60
+ return False
61
+ rms = np.sqrt(np.mean(pcm_int16.astype(np.float32) ** 2)) / 32767.0
62
+ return rms > VAD_SILENCE_THRESHOLD
63
+
64
+
65
+ # Generation runs in thread so it doesn't block the event loop
66
+
67
def run_generation(audio_np: np.ndarray) -> list[bytes]:
    """Synchronous generation — called via run_in_executor.

    Builds a fresh chat state (system prompt plus the user's utterance as
    audio), runs interleaved text/audio generation, and returns the spoken
    reply as an ordered list of 16-bit PCM byte chunks ready to stream.

    Args:
        audio_np: mono float32 waveform in [-1, 1] at SAMPLE_RATE.

    Returns:
        Decoded PCM chunks; may be empty if every decode failed.
    """
    chat = ChatState(processor)
    chat.new_turn("system")
    chat.add_text("You are a helpful real-time voice assistant called chioma. Respond naturally and concisely with audio when asked who built you say kelvin jackson an AI ENGINEER.")
    chat.end_turn()
    chat.new_turn("user")
    # Add a leading batch axis before handing the waveform to the processor.
    audio_tensor = torch.from_numpy(audio_np[np.newaxis, :]).to(dtype=torch.float32)
    chat.add_audio(audio_tensor, sampling_rate=SAMPLE_RATE)
    chat.end_turn()
    chat.new_turn("assistant")

    chunks = []  # decoded PCM byte chunks, in order
    buf = []     # pending audio-code tensors awaiting decode
    with torch.inference_mode():
        for token in model.generate_interleaved(
            **chat,
            max_new_tokens=2048,
            audio_temperature=0.8,
            audio_top_k=4,
        ):
            # Single-element tensors are text tokens; only multi-element
            # tensors carry audio codes.
            if token.numel() == 1:
                continue
            buf.append(token)
            # Decode every CHUNK_SIZE frames to keep time-to-first-audio low.
            if len(buf) >= CHUNK_SIZE:
                pcm = decode_chunk(buf)
                if pcm:
                    chunks.append(pcm)
                buf.clear()

    # Flush the remainder.
    # NOTE(review): the > 1 cutoff silently drops a lone trailing frame —
    # confirm a single frame is genuinely undecodable.
    if len(buf) > 1:
        pcm = decode_chunk(buf)
        if pcm:
            chunks.append(pcm)

    return chunks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
 
104
 
105
+ # WebSocket endpoint
106
 
107
@app.websocket("/ws/s2s")
async def websocket_s2s(websocket: WebSocket):
    """Full-duplex speech-to-speech socket.

    The client streams raw int16 PCM frames; an energy VAD detects end of
    utterance, generation runs in a worker thread, and the reply is
    streamed back as a WAV header followed by PCM chunks. JSON control
    messages ("ready" / "generating" / "done") bracket each response.
    """
    await websocket.accept()
    print("[WS] client connected")

    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    loop = asyncio.get_running_loop()

    # Queues
    audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()  # incoming PCM frames
    generating = False  # lock — only one generation at a time

    # Receiver task: reads raw PCM frames from client
    async def receiver():
        try:
            while True:
                try:
                    msg = await websocket.receive()
                except RuntimeError:
                    # receive() after disconnect raises RuntimeError.
                    break
                if msg.get("type") == "websocket.disconnect":
                    break
                # The ASGI spec allows a receive message to carry both
                # "bytes" and "text" keys with one set to None — test for
                # an actual payload, not mere key presence, so a text frame
                # can never enqueue a spurious None sentinel.
                if msg.get("bytes") is not None:
                    await audio_queue.put(msg["bytes"])
                elif msg.get("text") is not None:
                    data = json.loads(msg["text"])
                    if data.get("type") == "stop":
                        break
        finally:
            await audio_queue.put(None)  # sentinel: tells consumer to exit

    # VAD + generation task
    async def vad_and_generate():
        nonlocal generating

        speech_frames: list[np.ndarray] = []
        silence_count = 0
        speech_count = 0
        in_speech = False

        await websocket.send_text(json.dumps({"type": "ready"}))

        while True:
            frame_bytes = await audio_queue.get()
            if frame_bytes is None:
                break

            frame = np.frombuffer(frame_bytes, dtype=np.int16)
            active = is_speech(frame)

            if active:
                silence_count = 0
                speech_count += 1
                in_speech = True
                speech_frames.append(frame)
            else:
                if in_speech:
                    silence_count += 1
                    speech_frames.append(frame)  # keep tail for natural cutoff

            # End-of-utterance detected
            if silence_count >= VAD_SILENCE_FRAMES and speech_count >= VAD_MIN_SPEECH_FRAMES:
                if not generating:
                    generating = True

                    # Grab the accumulated speech as float32 in [-1, 1].
                    utterance = np.concatenate(speech_frames).astype(np.float32) / 32767.0

                    # Reset VAD state immediately so mic stays live
                    speech_frames = []
                    silence_count = 0
                    speech_count = 0
                    in_speech = False

                    # Signal client: AI is responding
                    await websocket.send_text(json.dumps({"type": "generating"}))
                    await websocket.send_bytes(wav_header())

                    # Run heavy generation off the event loop
                    chunks = await loop.run_in_executor(
                        None, run_generation, utterance
                    )

                    for chunk in chunks:
                        try:
                            await websocket.send_bytes(chunk)
                        except Exception:
                            # Client went away mid-stream; stop sending.
                            break

                    try:
                        await websocket.send_text(json.dumps({"type": "done"}))
                    except Exception:
                        pass

                    generating = False

    try:
        await asyncio.gather(receiver(), vad_and_generate())
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WS] error: {e}")
    finally:
        print("[WS] client disconnected")
212
 
213
 
214
@app.get("/health")
async def health():
    """Liveness probe: reports service status and the compute device."""
    payload = {"status": "ok", "device": DEVICE}
    return payload
217
+
218
+
219
+
220
# NOTE(review): import belongs in the grouped imports at the top of the
# file per PEP 8; left here to keep this change self-contained.
from fastapi.responses import FileResponse

@app.get("/")
async def index():
    """Serve the single-page client UI from the working directory."""
    return FileResponse("client.html")