drrobot9 committed on
Commit
c56c006
·
verified ·
1 Parent(s): 6b6a9ba

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +13 -18
app/main.py CHANGED
@@ -47,7 +47,6 @@ def wav_header(sr=24000, ch=1, bits=16):
47
 
48
 
49
  async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
50
-
51
  chat = ChatState(processor)
52
 
53
  chat.new_turn("system")
@@ -56,7 +55,7 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
56
 
57
  chat.new_turn("user")
58
  audio_tensor = torch.from_numpy(audio_np[np.newaxis, :]).to(dtype=torch.float32)
59
- chat.add_audio(audio_tensor, sampling_rate=SAMPLE_RATE)
60
  chat.end_turn()
61
 
62
  chat.new_turn("assistant")
@@ -72,13 +71,12 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
72
  audio_temperature=0.8,
73
  audio_top_k=4,
74
  ):
 
75
  if token.numel() == 1:
76
  continue
77
 
78
- token_id = token.item()
79
-
80
- if processor.audio_token_start <= token_id <= processor.audio_token_end:
81
- audio_buffer.append(token)
82
 
83
  if len(audio_buffer) >= CHUNK_SIZE:
84
  audio_codes = (
@@ -86,19 +84,16 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
86
  .unsqueeze(0)
87
  .to(DEVICE)
88
  )
89
-
90
  try:
91
  waveform = processor.decode(audio_codes)
92
- except Exception:
 
 
 
 
 
 
93
  audio_buffer.clear()
94
- continue
95
-
96
- waveform = waveform.squeeze().cpu().numpy()
97
- waveform = np.clip(waveform, -1.0, 1.0)
98
- audio_int16 = (waveform * 32767).astype(np.int16)
99
-
100
- await websocket.send_bytes(audio_int16.tobytes())
101
- audio_buffer.clear()
102
 
103
  # flush remaining
104
  if len(audio_buffer) > 1:
@@ -113,8 +108,8 @@ async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
113
  waveform = np.clip(waveform, -1.0, 1.0)
114
  audio_int16 = (waveform * 32767).astype(np.int16)
115
  await websocket.send_bytes(audio_int16.tobytes())
116
- except Exception:
117
- pass
118
 
119
  await websocket.send_text(json.dumps({"type": "done"}))
120
 
 
47
 
48
 
49
  async def generate_response(websocket: WebSocket, audio_np: np.ndarray):
 
50
  chat = ChatState(processor)
51
 
52
  chat.new_turn("system")
 
55
 
56
  chat.new_turn("user")
57
  audio_tensor = torch.from_numpy(audio_np[np.newaxis, :]).to(dtype=torch.float32)
58
+ chat.add_audio(audio_tensor, sampling_rate=SAMPLE_RATE)
59
  chat.end_turn()
60
 
61
  chat.new_turn("assistant")
 
71
  audio_temperature=0.8,
72
  audio_top_k=4,
73
  ):
74
+ # numel()==1 means text token
75
  if token.numel() == 1:
76
  continue
77
 
78
+ # multi-element tensor = audio codes chunk
79
+ audio_buffer.append(token)
 
 
80
 
81
  if len(audio_buffer) >= CHUNK_SIZE:
82
  audio_codes = (
 
84
  .unsqueeze(0)
85
  .to(DEVICE)
86
  )
 
87
  try:
88
  waveform = processor.decode(audio_codes)
89
+ waveform = waveform.squeeze().cpu().numpy()
90
+ waveform = np.clip(waveform, -1.0, 1.0)
91
+ audio_int16 = (waveform * 32767).astype(np.int16)
92
+ await websocket.send_bytes(audio_int16.tobytes())
93
+ except Exception as e:
94
+ print(f"[WARN] decode error: {e}")
95
+ finally:
96
  audio_buffer.clear()
 
 
 
 
 
 
 
 
97
 
98
  # flush remaining
99
  if len(audio_buffer) > 1:
 
108
  waveform = np.clip(waveform, -1.0, 1.0)
109
  audio_int16 = (waveform * 32767).astype(np.int16)
110
  await websocket.send_bytes(audio_int16.tobytes())
111
+ except Exception as e:
112
+ print(f"[WARN] flush decode error: {e}")
113
 
114
  await websocket.send_text(json.dumps({"type": "done"}))
115