tudeplom commited on
Commit
63db033
·
verified ·
1 Parent(s): 09e278b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -48
app.py CHANGED
@@ -88,24 +88,17 @@ HTML_CONTENT = """
88
  body: JSON.stringify({ prompt: text })
89
  });
90
 
91
- if (response.ok) {
92
- const blob = await response.blob();
93
- const url = URL.createObjectURL(blob);
 
94
  const audio = document.getElementById('audioPlayer');
95
  audio.src = url;
96
  audio.style.display = 'block';
97
  audio.play();
98
-
99
- const textResponse = await fetch('/llama', {
100
- method: 'POST',
101
- headers: { 'Content-Type': 'application/json' },
102
- body: JSON.stringify({ prompt: text })
103
- });
104
- const textData = await textResponse.json();
105
- addMessage('Bot: ' + (textData.text || 'Tôi không hiểu bạn nói gì.'));
106
  } else {
107
- const errorData = await response.json();
108
- addMessage('Bot: Lỗi - ' + (errorData.error || 'Không có phản hồi'));
109
  }
110
  } catch (e) {
111
  addMessage('Bot: Lỗi kết nối - ' + e.message);
@@ -141,32 +134,18 @@ HTML_CONTENT = """
141
  body: formData
142
  });
143
 
144
- if (response.ok) {
145
- const blob = await response.blob();
146
- const url = URL.createObjectURL(blob);
 
147
  const audio = document.getElementById('audioPlayer');
148
  audio.src = url;
149
  audio.style.display = 'block';
150
  audio.play();
151
-
152
- const sttResponse = await fetch('/stt', {
153
- method: 'POST',
154
- body: formData
155
- });
156
- const sttData = await sttResponse.json();
157
- if (sttData.text) {
158
- addMessage('Bạn: ' + sttData.text);
159
- const llamaResponse = await fetch('/llama', {
160
- method: 'POST',
161
- headers: { 'Content-Type': 'application/json' },
162
- body: JSON.stringify({ prompt: sttData.text })
163
- });
164
- const llamaData = await llamaResponse.json();
165
- addMessage('Bot: ' + (llamaData.text || 'Tôi không hiểu bạn nói gì.'));
166
- }
167
  } else {
168
- const errorData = await response.json();
169
- addMessage('Bot: Lỗi - ' + (errorData.error || 'Không có phản hồi'));
170
  }
171
  } catch (e) {
172
  addMessage('Bot: Lỗi kết nối - ' + e.message);
@@ -198,7 +177,10 @@ async def text_to_speech(text: str):
198
  audio = client.text_to_speech(model=TTS_MODEL, text=text)
199
  with open(output_path, "wb") as f:
200
  f.write(audio)
201
- return FileResponse(output_path, media_type="audio/wav", filename="output.wav")
 
 
 
202
  except Exception as e:
203
  print(f"❌ Lỗi TTS: {e}")
204
  return {"error": str(e)}
@@ -221,7 +203,6 @@ async def generate_text(prompt: str):
221
  print(f"🔄 Đang xử lý LLaMA với model {LLAMA_MODEL}...")
222
  output = llama_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
223
  print(f"Output từ LLaMA: {output}")
224
- # Nếu output rỗng hoặc không có ý nghĩa, trả về mặc định
225
  if not output or output.strip() == prompt.strip():
226
  output = "Tôi không hiểu bạn nói gì."
227
  return {"text": output}
@@ -235,16 +216,12 @@ async def chat(prompt: str):
235
  # Gửi tới LLaMA
236
  llama_output = llama_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
237
  print(f"LLaMA output: {llama_output}")
238
- # Xử lý nếu LLaMA không trả về kết quả hợp lệ
239
  if not llama_output or llama_output.strip() == prompt.strip():
240
  llama_output = "Tôi không hiểu bạn nói gì."
241
-
242
  # Tạo TTS từ output của LLaMA
243
- output_path = os.path.join(TEMP_DIR, "output.wav")
244
- audio = client.text_to_speech(model=TTS_MODEL, text=llama_output)
245
- with open(output_path, "wb") as f:
246
- f.write(audio)
247
- return FileResponse(output_path, media_type="audio/wav", filename="output.wav")
248
  except Exception as e:
249
  print(f"❌ Lỗi chat: {e}")
250
  return {"error": str(e)}
@@ -256,6 +233,8 @@ async def audio_chat(file: UploadFile = File(...)):
256
  audio_data = await file.read()
257
  stt_output = client.automatic_speech_recognition(model=STT_MODEL, data=audio_data).get("text", "")
258
  print(f"STT output: {stt_output}")
 
 
259
 
260
  # LLaMA: Sinh câu trả lời
261
  llama_output = llama_pipeline(stt_output, max_new_tokens=100)[0]["generated_text"]
@@ -264,11 +243,8 @@ async def audio_chat(file: UploadFile = File(...)):
264
  llama_output = "Tôi không hiểu bạn nói gì."
265
 
266
  # TTS: Chuyển câu trả lời thành âm thanh
267
- output_path = os.path.join(TEMP_DIR, "output.wav")
268
- audio = client.text_to_speech(model=TTS_MODEL, text=llama_output)
269
- with open(output_path, "wb") as f:
270
- f.write(audio)
271
- return FileResponse(output_path, media_type="audio/wav", filename="output.wav")
272
  except Exception as e:
273
  print(f"❌ Lỗi audio_chat: {e}")
274
  return {"error": str(e)}
 
88
  body: JSON.stringify({ prompt: text })
89
  });
90
 
91
+ const data = await response.json();
92
+ if (data.audio) {
93
+ const audioBlob = new Blob([new Uint8Array(atob(data.audio).split('').map(c => c.charCodeAt(0)))], { type: 'audio/wav' });
94
+ const url = URL.createObjectURL(audioBlob);
95
  const audio = document.getElementById('audioPlayer');
96
  audio.src = url;
97
  audio.style.display = 'block';
98
  audio.play();
99
+ addMessage('Bot: ' + data.text);
 
 
 
 
 
 
 
100
  } else {
101
+ addMessage('Bot: Lỗi - ' + (data.error || 'Không có phản hồi'));
 
102
  }
103
  } catch (e) {
104
  addMessage('Bot: Lỗi kết nối - ' + e.message);
 
134
  body: formData
135
  });
136
 
137
+ const data = await response.json();
138
+ if (data.audio) {
139
+ const audioBlob = new Blob([new Uint8Array(atob(data.audio).split('').map(c => c.charCodeAt(0)))], { type: 'audio/wav' });
140
+ const url = URL.createObjectURL(audioBlob);
141
  const audio = document.getElementById('audioPlayer');
142
  audio.src = url;
143
  audio.style.display = 'block';
144
  audio.play();
145
+ addMessage('Bạn: ' + data.input);
146
+ addMessage('Bot: ' + data.text);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  } else {
148
+ addMessage('Bot: Lỗi - ' + (data.error || 'Không có phản hồi'));
 
149
  }
150
  } catch (e) {
151
  addMessage('Bot: Lỗi kết nối - ' + e.message);
 
177
  audio = client.text_to_speech(model=TTS_MODEL, text=text)
178
  with open(output_path, "wb") as f:
179
  f.write(audio)
180
+ with open(output_path, "rb") as f:
181
+ audio_data = f.read()
182
+ import base64
183
+ return base64.b64encode(audio_data).decode('utf-8')
184
  except Exception as e:
185
  print(f"❌ Lỗi TTS: {e}")
186
  return {"error": str(e)}
 
203
  print(f"🔄 Đang xử lý LLaMA với model {LLAMA_MODEL}...")
204
  output = llama_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
205
  print(f"Output từ LLaMA: {output}")
 
206
  if not output or output.strip() == prompt.strip():
207
  output = "Tôi không hiểu bạn nói gì."
208
  return {"text": output}
 
216
  # Gửi tới LLaMA
217
  llama_output = llama_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
218
  print(f"LLaMA output: {llama_output}")
 
219
  if not llama_output or llama_output.strip() == prompt.strip():
220
  llama_output = "Tôi không hiểu bạn nói gì."
221
+
222
  # Tạo TTS từ output của LLaMA
223
+ audio_base64 = await text_to_speech(llama_output)
224
+ return {"text": llama_output, "audio": audio_base64}
 
 
 
225
  except Exception as e:
226
  print(f"❌ Lỗi chat: {e}")
227
  return {"error": str(e)}
 
233
  audio_data = await file.read()
234
  stt_output = client.automatic_speech_recognition(model=STT_MODEL, data=audio_data).get("text", "")
235
  print(f"STT output: {stt_output}")
236
+ if not stt_output:
237
+ stt_output = "Không nghe được gì."
238
 
239
  # LLaMA: Sinh câu trả lời
240
  llama_output = llama_pipeline(stt_output, max_new_tokens=100)[0]["generated_text"]
 
243
  llama_output = "Tôi không hiểu bạn nói gì."
244
 
245
  # TTS: Chuyển câu trả lời thành âm thanh
246
+ audio_base64 = await text_to_speech(llama_output)
247
+ return {"input": stt_output, "text": llama_output, "audio": audio_base64}
 
 
 
248
  except Exception as e:
249
  print(f"❌ Lỗi audio_chat: {e}")
250
  return {"error": str(e)}