drrobot9 committed on
Commit
85a874c
·
verified ·
1 Parent(s): 5de1f52

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +15 -25
app/main.py CHANGED
@@ -1,18 +1,14 @@
1
-
2
-
3
  import asyncio
4
  import json
5
  import torch
6
  import numpy as np
7
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
- from pydantic import BaseModel
9
  from liquid_audio import (
10
  LFM2AudioModel,
11
  LFM2AudioProcessor,
12
  ChatState,
13
  )
14
 
15
-
16
  HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
  SAMPLE_RATE = 24_000
@@ -25,26 +21,18 @@ else:
25
 
26
  torch.backends.cuda.matmul.allow_tf32 = True
27
 
28
-
29
-
30
-
31
 
32
  processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
33
- model = LFM2AudioModel.from_pretrained(
34
- HF_REPO,
35
- torch_dtype=DTYPE,
36
- ).to(DEVICE).eval()
37
-
38
- print(f"[BOOT] LFM2.5 Loaded on {DEVICE}")
39
-
40
 
 
41
 
42
 
43
 
44
  app = FastAPI(title="LFM2.5 WebSocket TTS", version="2.0.0")
45
 
46
 
47
- # WAV HEADER
48
 
49
 
50
  def wav_header(sample_rate: int, channels: int = 1, bits: int = 16) -> bytes:
@@ -66,8 +54,7 @@ def wav_header(sample_rate: int, channels: int = 1, bits: int = 16) -> bytes:
66
  )
67
 
68
 
69
- # STREAM CORE
70
-
71
 
72
  async def stream_lfm_tts(websocket: WebSocket, text: str):
73
  chat = ChatState(processor)
@@ -96,7 +83,7 @@ async def stream_lfm_tts(websocket: WebSocket, text: str):
96
  if data.get("type") == "stop":
97
  stop_flag = True
98
  break
99
- except:
100
  stop_flag = True
101
 
102
  listener_task = asyncio.create_task(listen_for_stop())
@@ -123,7 +110,6 @@ async def stream_lfm_tts(websocket: WebSocket, text: str):
123
  .unsqueeze(0)
124
  .to(DEVICE)
125
  )
126
-
127
  waveform = processor.decode(audio_codes)
128
  waveform = waveform.squeeze().cpu().numpy()
129
  waveform = np.clip(waveform, -1.0, 1.0)
@@ -132,7 +118,7 @@ async def stream_lfm_tts(websocket: WebSocket, text: str):
132
  await websocket.send_bytes(audio_int16.tobytes())
133
  audio_buffer.clear()
134
 
135
- # flush
136
  if not stop_flag and len(audio_buffer) > 1:
137
  audio_codes = (
138
  torch.stack(audio_buffer[:-1], dim=1)
@@ -143,7 +129,6 @@ async def stream_lfm_tts(websocket: WebSocket, text: str):
143
  waveform = waveform.squeeze().cpu().numpy()
144
  waveform = np.clip(waveform, -1.0, 1.0)
145
  audio_int16 = (waveform * 32767.0).astype(np.int16)
146
-
147
  await websocket.send_bytes(audio_int16.tobytes())
148
 
149
  await websocket.send_text(json.dumps({"type": "done"}))
@@ -152,13 +137,11 @@ async def stream_lfm_tts(websocket: WebSocket, text: str):
152
  listener_task.cancel()
153
 
154
 
155
- # WEBSOCKET ENDPOINT
156
-
157
 
158
  @app.websocket("/ws/tts")
159
  async def websocket_tts(websocket: WebSocket):
160
  await websocket.accept()
161
-
162
  try:
163
  while True:
164
  message = await websocket.receive_text()
@@ -176,4 +159,11 @@ async def websocket_tts(websocket: WebSocket):
176
  await stream_lfm_tts(websocket, text)
177
 
178
  except WebSocketDisconnect:
179
- print("Client disconnected")
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
  import json
3
  import torch
4
  import numpy as np
5
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 
6
  from liquid_audio import (
7
  LFM2AudioModel,
8
  LFM2AudioProcessor,
9
  ChatState,
10
  )
11
 
 
12
  HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  SAMPLE_RATE = 24_000
 
21
 
22
  torch.backends.cuda.matmul.allow_tf32 = True
23
 
24
+ print(f"[BOOT] Loading model on {DEVICE} with dtype {DTYPE}...")
 
 
25
 
26
  processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
27
+ model = LFM2AudioModel.from_pretrained(HF_REPO).to(dtype=DTYPE, device=DEVICE).eval()
 
 
 
 
 
 
28
 
29
+ print(f"[BOOT] LFM2.5 Loaded on {DEVICE}")
30
 
31
 
32
 
33
  app = FastAPI(title="LFM2.5 WebSocket TTS", version="2.0.0")
34
 
35
 
 
36
 
37
 
38
  def wav_header(sample_rate: int, channels: int = 1, bits: int = 16) -> bytes:
 
54
  )
55
 
56
 
57
+ # Stream core
 
58
 
59
  async def stream_lfm_tts(websocket: WebSocket, text: str):
60
  chat = ChatState(processor)
 
83
  if data.get("type") == "stop":
84
  stop_flag = True
85
  break
86
+ except Exception:
87
  stop_flag = True
88
 
89
  listener_task = asyncio.create_task(listen_for_stop())
 
110
  .unsqueeze(0)
111
  .to(DEVICE)
112
  )
 
113
  waveform = processor.decode(audio_codes)
114
  waveform = waveform.squeeze().cpu().numpy()
115
  waveform = np.clip(waveform, -1.0, 1.0)
 
118
  await websocket.send_bytes(audio_int16.tobytes())
119
  audio_buffer.clear()
120
 
121
+ # flush remaining
122
  if not stop_flag and len(audio_buffer) > 1:
123
  audio_codes = (
124
  torch.stack(audio_buffer[:-1], dim=1)
 
129
  waveform = waveform.squeeze().cpu().numpy()
130
  waveform = np.clip(waveform, -1.0, 1.0)
131
  audio_int16 = (waveform * 32767.0).astype(np.int16)
 
132
  await websocket.send_bytes(audio_int16.tobytes())
133
 
134
  await websocket.send_text(json.dumps({"type": "done"}))
 
137
  listener_task.cancel()
138
 
139
 
140
+ # WebSocket endpoint
 
141
 
142
  @app.websocket("/ws/tts")
143
  async def websocket_tts(websocket: WebSocket):
144
  await websocket.accept()
 
145
  try:
146
  while True:
147
  message = await websocket.receive_text()
 
159
  await stream_lfm_tts(websocket, text)
160
 
161
  except WebSocketDisconnect:
162
+ print("[WS] Client disconnected")
163
+
164
+
165
+
166
+
167
@app.get("/health")
async def health():
    """Liveness probe: reports service status and the active compute device."""
    payload = {"status": "ok", "device": DEVICE}
    return payload