drrobot9 committed on
Commit
4e7f8bc
·
verified ·
1 Parent(s): 85a874c

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +76 -98
app/main.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import asyncio
2
  import json
3
  import torch
@@ -11,157 +13,133 @@ from liquid_audio import (
11
 
12
  HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
- SAMPLE_RATE = 24_000
15
- CHUNK_SIZE = 6
16
-
17
- if DEVICE == "cuda" and torch.cuda.is_bf16_supported():
18
- DTYPE = torch.bfloat16
19
- else:
20
- DTYPE = torch.float32
21
 
 
22
  torch.backends.cuda.matmul.allow_tf32 = True
23
 
24
- print(f"[BOOT] Loading model on {DEVICE} with dtype {DTYPE}...")
25
 
26
  processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
27
- model = LFM2AudioModel.from_pretrained(HF_REPO).to(dtype=DTYPE, device=DEVICE).eval()
28
-
29
- print(f"[BOOT] LFM2.5 Loaded on {DEVICE}")
30
-
31
 
 
32
 
33
- app = FastAPI(title="LFM2.5 WebSocket TTS", version="2.0.0")
34
 
35
 
36
-
37
-
38
- def wav_header(sample_rate: int, channels: int = 1, bits: int = 16) -> bytes:
39
- byte_rate = sample_rate * channels * bits // 8
40
- block_align = channels * bits // 8
41
  return (
42
  b"RIFF"
43
- + (b"\xff\xff\xff\xff")
44
  + b"WAVEfmt "
45
  + (16).to_bytes(4, "little")
46
  + (1).to_bytes(2, "little")
47
- + channels.to_bytes(2, "little")
48
- + sample_rate.to_bytes(4, "little")
49
  + byte_rate.to_bytes(4, "little")
50
  + block_align.to_bytes(2, "little")
51
  + bits.to_bytes(2, "little")
52
  + b"data"
53
- + (b"\xff\xff\xff\xff")
54
  )
55
 
56
 
57
- # Stream core
58
 
59
- async def stream_lfm_tts(websocket: WebSocket, text: str):
60
  chat = ChatState(processor)
61
 
62
  chat.new_turn("system")
63
- chat.add_text("Respond with interleaved text and audio.")
64
  chat.end_turn()
65
 
66
  chat.new_turn("user")
67
- chat.add_text(text)
68
  chat.end_turn()
69
 
70
  chat.new_turn("assistant")
71
 
72
- await websocket.send_bytes(wav_header(SAMPLE_RATE))
73
 
74
  audio_buffer = []
75
- stop_flag = False
76
-
77
- async def listen_for_stop():
78
- nonlocal stop_flag
79
- try:
80
- while True:
81
- msg = await websocket.receive_text()
82
- data = json.loads(msg)
83
- if data.get("type") == "stop":
84
- stop_flag = True
85
- break
86
- except Exception:
87
- stop_flag = True
88
-
89
- listener_task = asyncio.create_task(listen_for_stop())
90
 
91
- try:
92
- with torch.inference_mode():
93
- for token in model.generate_interleaved(
94
- **chat,
95
- max_new_tokens=4096,
96
- audio_temperature=0.8,
97
- audio_top_k=4,
98
- ):
99
- if stop_flag:
100
- break
101
-
102
- if token.numel() == 1:
103
- continue
104
 
 
105
  audio_buffer.append(token)
106
 
107
- if len(audio_buffer) >= CHUNK_SIZE:
108
- audio_codes = (
109
- torch.stack(audio_buffer, dim=1)
110
- .unsqueeze(0)
111
- .to(DEVICE)
112
- )
113
- waveform = processor.decode(audio_codes)
114
- waveform = waveform.squeeze().cpu().numpy()
115
- waveform = np.clip(waveform, -1.0, 1.0)
116
- audio_int16 = (waveform * 32767.0).astype(np.int16)
117
 
118
- await websocket.send_bytes(audio_int16.tobytes())
 
 
119
  audio_buffer.clear()
 
 
 
 
120
 
121
- # flush remaining
122
- if not stop_flag and len(audio_buffer) > 1:
123
- audio_codes = (
124
- torch.stack(audio_buffer[:-1], dim=1)
125
- .unsqueeze(0)
126
- .to(DEVICE)
127
- )
128
- waveform = processor.decode(audio_codes)
129
- waveform = waveform.squeeze().cpu().numpy()
130
- waveform = np.clip(waveform, -1.0, 1.0)
131
- audio_int16 = (waveform * 32767.0).astype(np.int16)
132
- await websocket.send_bytes(audio_int16.tobytes())
133
 
134
- await websocket.send_text(json.dumps({"type": "done"}))
135
 
136
- finally:
137
- listener_task.cancel()
138
 
 
139
 
140
- # WebSocket endpoint
141
 
142
- @app.websocket("/ws/tts")
143
- async def websocket_tts(websocket: WebSocket):
 
144
  await websocket.accept()
 
145
  try:
 
 
 
146
  while True:
147
- message = await websocket.receive_text()
148
- payload = json.loads(message)
149
-
150
- if payload.get("type") == "start":
151
- text = payload.get("text", "").strip()
152
- if not text:
153
- await websocket.send_text(json.dumps({
154
- "type": "error",
155
- "message": "Text is empty"
156
- }))
157
- continue
158
 
159
- await stream_lfm_tts(websocket, text)
160
 
161
- except WebSocketDisconnect:
162
- print("[WS] Client disconnected")
 
 
 
 
 
163
 
 
 
164
 
 
 
 
 
 
 
 
165
 
166
 
167
  @app.get("/health")
 
1
+ # app/main.py
2
+
3
  import asyncio
4
  import json
5
  import torch
 
13
 
14
  HF_REPO = "LiquidAI/LFM2.5-Audio-1.5B"
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ SAMPLE_RATE = 24000
17
+ CHUNK_SIZE = 20
 
 
 
 
 
18
 
19
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
20
  torch.backends.cuda.matmul.allow_tf32 = True
21
 
22
+ print(f"[BOOT] Loading model on {DEVICE}...")
23
 
24
  processor = LFM2AudioProcessor.from_pretrained(HF_REPO)
25
+ model = LFM2AudioModel.from_pretrained(HF_REPO).to(device=DEVICE, dtype=DTYPE).eval()
 
 
 
26
 
27
+ print("[BOOT] Model loaded")
28
 
29
+ app = FastAPI(title="LFM2.5 Speech-to-Speech", version="3.0")
30
 
31
 
32
+ def wav_header(sr=24000, ch=1, bits=16):
33
+ byte_rate = sr * ch * bits // 8
34
+ block_align = ch * bits // 8
 
 
35
  return (
36
  b"RIFF"
37
+ + b"\xff\xff\xff\xff"
38
  + b"WAVEfmt "
39
  + (16).to_bytes(4, "little")
40
  + (1).to_bytes(2, "little")
41
+ + ch.to_bytes(2, "little")
42
+ + sr.to_bytes(4, "little")
43
  + byte_rate.to_bytes(4, "little")
44
  + block_align.to_bytes(2, "little")
45
  + bits.to_bytes(2, "little")
46
  + b"data"
47
+ + b"\xff\xff\xff\xff"
48
  )
49
 
50
 
51
+ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
52
 
 
53
  chat = ChatState(processor)
54
 
55
  chat.new_turn("system")
56
+ chat.add_text("Respond conversationally with audio.")
57
  chat.end_turn()
58
 
59
  chat.new_turn("user")
60
+ chat.add_audio(audio_np, sample_rate=SAMPLE_RATE)
61
  chat.end_turn()
62
 
63
  chat.new_turn("assistant")
64
 
65
+ await websocket.send_bytes(wav_header())
66
 
67
  audio_buffer = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
+ with torch.inference_mode():
70
+
71
+ for token in model.generate_interleaved(
72
+ **chat,
73
+ max_new_tokens=4096,
74
+ audio_temperature=0.8,
75
+ audio_top_k=4,
76
+ ):
77
+
78
+ if token.numel() == 1:
79
+ continue
80
+
81
+ token_id = token.item()
82
 
83
+ if processor.audio_token_start <= token_id <= processor.audio_token_end:
84
  audio_buffer.append(token)
85
 
86
+ if len(audio_buffer) >= CHUNK_SIZE:
87
+
88
+ audio_codes = (
89
+ torch.stack(audio_buffer, dim=1)
90
+ .unsqueeze(0)
91
+ .to(DEVICE)
92
+ )
 
 
 
93
 
94
+ try:
95
+ waveform = processor.decode(audio_codes)
96
+ except Exception:
97
  audio_buffer.clear()
98
+ continue
99
+
100
+ waveform = waveform.squeeze().cpu().numpy()
101
+ waveform = np.clip(waveform, -1.0, 1.0)
102
 
103
+ audio_int16 = (waveform * 32767).astype(np.int16)
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ await websocket.send_bytes(audio_int16.tobytes())
106
 
107
+ audio_buffer.clear()
 
108
 
109
+ await websocket.send_text(json.dumps({"type": "done"}))
110
 
 
111
 
112
@app.websocket("/ws/s2s")
async def websocket_s2s(websocket: WebSocket):
    """Speech-to-speech websocket endpoint.

    Client protocol:
      * text frame {"type": "start"} — reset the input audio buffer,
      * binary frames — append raw PCM16 audio (assumed little-endian mono
        at SAMPLE_RATE — TODO confirm with the client),
      * text frame {"type": "end"} — run generation and stream audio back.
    """
    await websocket.accept()
    try:
        audio_bytes = bytearray()

        while True:
            message = await websocket.receive()

            # NOTE(review): raw receive() does not raise WebSocketDisconnect
            # (only receive_text/receive_bytes do), so a disconnect message
            # must be detected explicitly — otherwise the next receive()
            # raises RuntimeError on the closed socket.
            if message.get("type") == "websocket.disconnect":
                break

            if "text" in message:
                payload = json.loads(message["text"])
                # .get avoids a KeyError on frames without a "type" field.
                msg_type = payload.get("type")

                if msg_type == "start":
                    audio_bytes.clear()

                if msg_type == "end":
                    # PCM16 -> float32 in [-1, 1]; snapshot the bytearray so
                    # the ndarray does not alias a buffer mutated later.
                    audio_np = np.frombuffer(bytes(audio_bytes), dtype=np.int16).astype(np.float32)
                    audio_np /= 32767.0

                    await generate_response(websocket, audio_np)

            elif "bytes" in message:
                audio_bytes.extend(message["bytes"])

    except WebSocketDisconnect:
        print("[WS] client disconnected")
143
 
144
 
145
  @app.get("/health")