drrobot9 commited on
Commit
259c3a6
·
verified ·
1 Parent(s): 2e2a280

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +25 -22
app/main.py CHANGED
@@ -1,5 +1,3 @@
1
- # app/main.py
2
-
3
  import asyncio
4
  import json
5
  import torch
@@ -29,26 +27,25 @@ print("[BOOT] Model loaded")
29
  app = FastAPI(title="LFM2.5 Speech-to-Speech", version="3.0")
30
 
31
 
32
- def wav_header(sample_rate: int, channels: int = 1, bits: int = 16) -> bytes:
33
- byte_rate = sample_rate * channels * bits // 8
34
- block_align = channels * bits // 8
35
  return (
36
  b"RIFF"
37
- + (b"\xff\xff\xff\xff")
38
  + b"WAVEfmt "
39
  + (16).to_bytes(4, "little")
40
  + (1).to_bytes(2, "little")
41
- + channels.to_bytes(2, "little")
42
- + sample_rate.to_bytes(4, "little")
43
  + byte_rate.to_bytes(4, "little")
44
  + block_align.to_bytes(2, "little")
45
  + bits.to_bytes(2, "little")
46
  + b"data"
47
- + (b"\xff\xff\xff\xff")
48
  )
49
 
50
 
51
-
52
  async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
53
 
54
  chat = ChatState(processor)
@@ -58,7 +55,7 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
58
  chat.end_turn()
59
 
60
  chat.new_turn("user")
61
- chat.add_audio(audio_np)
62
  chat.end_turn()
63
 
64
  chat.new_turn("assistant")
@@ -68,14 +65,12 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
68
  audio_buffer = []
69
 
70
  with torch.inference_mode():
71
-
72
  for token in model.generate_interleaved(
73
  **chat,
74
  max_new_tokens=4096,
75
  audio_temperature=0.8,
76
  audio_top_k=4,
77
  ):
78
-
79
  if token.numel() == 1:
80
  continue
81
 
@@ -85,7 +80,6 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
85
  audio_buffer.append(token)
86
 
87
  if len(audio_buffer) >= CHUNK_SIZE:
88
-
89
  audio_codes = (
90
  torch.stack(audio_buffer, dim=1)
91
  .unsqueeze(0)
@@ -100,27 +94,38 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
100
 
101
  waveform = waveform.squeeze().cpu().numpy()
102
  waveform = np.clip(waveform, -1.0, 1.0)
103
-
104
  audio_int16 = (waveform * 32767).astype(np.int16)
105
 
106
  await websocket.send_bytes(audio_int16.tobytes())
107
-
108
  audio_buffer.clear()
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  await websocket.send_text(json.dumps({"type": "done"}))
111
 
112
 
113
  @app.websocket("/ws/s2s")
114
  async def websocket_s2s(websocket: WebSocket):
115
-
116
  await websocket.accept()
117
 
118
  try:
119
-
120
  audio_bytes = bytearray()
121
 
122
  while True:
123
-
124
  message = await websocket.receive()
125
 
126
  if "text" in message:
@@ -129,11 +134,9 @@ async def websocket_s2s(websocket: WebSocket):
129
  if payload["type"] == "start":
130
  audio_bytes.clear()
131
 
132
- if payload["type"] == "end":
133
-
134
  audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
135
  audio_np /= 32767.0
136
-
137
  await generate_response(websocket, audio_np)
138
 
139
  elif "bytes" in message:
 
 
 
1
  import asyncio
2
  import json
3
  import torch
 
27
  app = FastAPI(title="LFM2.5 Speech-to-Speech", version="3.0")
28
 
29
 
30
def wav_header(sr: int = 24000, ch: int = 1, bits: int = 16) -> bytes:
    """Build a 44-byte canonical RIFF/WAVE (PCM) file header.

    Both the RIFF chunk size and the data chunk size are written as
    0xFFFFFFFF — a sentinel for "unknown length", used when streaming
    audio whose total size is not known up front.

    Args:
        sr: Sample rate in Hz (default 24000).
        ch: Number of channels (default 1, mono).
        bits: Bits per sample (default 16, i.e. PCM int16).

    Returns:
        The header as ``bytes`` (always 44 bytes long).
    """
    byte_rate = sr * ch * bits // 8          # bytes of audio per second
    block_align = ch * bits // 8             # bytes per sample frame
    return (
        b"RIFF"
        + b"\xff\xff\xff\xff"                # RIFF chunk size: unknown (streaming)
        + b"WAVEfmt "                        # WAVE form type + "fmt " subchunk id
        + (16).to_bytes(4, "little")         # fmt subchunk size (16 for PCM)
        + (1).to_bytes(2, "little")          # audio format 1 = PCM
        + ch.to_bytes(2, "little")
        + sr.to_bytes(4, "little")
        + byte_rate.to_bytes(4, "little")
        + block_align.to_bytes(2, "little")
        + bits.to_bytes(2, "little")
        + b"data"
        + b"\xff\xff\xff\xff"                # data chunk size: unknown (streaming)
    )
47
 
48
 
 
49
  async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
50
 
51
  chat = ChatState(processor)
 
55
  chat.end_turn()
56
 
57
  chat.new_turn("user")
58
+ chat.add_audio(audio_np, sampling_rate=SAMPLE_RATE)
59
  chat.end_turn()
60
 
61
  chat.new_turn("assistant")
 
65
  audio_buffer = []
66
 
67
  with torch.inference_mode():
 
68
  for token in model.generate_interleaved(
69
  **chat,
70
  max_new_tokens=4096,
71
  audio_temperature=0.8,
72
  audio_top_k=4,
73
  ):
 
74
  if token.numel() == 1:
75
  continue
76
 
 
80
  audio_buffer.append(token)
81
 
82
  if len(audio_buffer) >= CHUNK_SIZE:
 
83
  audio_codes = (
84
  torch.stack(audio_buffer, dim=1)
85
  .unsqueeze(0)
 
94
 
95
  waveform = waveform.squeeze().cpu().numpy()
96
  waveform = np.clip(waveform, -1.0, 1.0)
 
97
  audio_int16 = (waveform * 32767).astype(np.int16)
98
 
99
  await websocket.send_bytes(audio_int16.tobytes())
 
100
  audio_buffer.clear()
101
 
102
+ # flush remaining
103
+ if len(audio_buffer) > 1:
104
+ audio_codes = (
105
+ torch.stack(audio_buffer, dim=1)
106
+ .unsqueeze(0)
107
+ .to(DEVICE)
108
+ )
109
+ try:
110
+ waveform = processor.decode(audio_codes)
111
+ waveform = waveform.squeeze().cpu().numpy()
112
+ waveform = np.clip(waveform, -1.0, 1.0)
113
+ audio_int16 = (waveform * 32767).astype(np.int16)
114
+ await websocket.send_bytes(audio_int16.tobytes())
115
+ except Exception:
116
+ pass
117
+
118
  await websocket.send_text(json.dumps({"type": "done"}))
119
 
120
 
121
  @app.websocket("/ws/s2s")
122
  async def websocket_s2s(websocket: WebSocket):
 
123
  await websocket.accept()
124
 
125
  try:
 
126
  audio_bytes = bytearray()
127
 
128
  while True:
 
129
  message = await websocket.receive()
130
 
131
  if "text" in message:
 
134
  if payload["type"] == "start":
135
  audio_bytes.clear()
136
 
137
+ elif payload["type"] == "end":
 
138
  audio_np = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
139
  audio_np /= 32767.0
 
140
  await generate_response(websocket, audio_np)
141
 
142
  elif "bytes" in message: