tudeplom committed on
Commit
78f7e4b
·
verified ·
1 Parent(s): d44e8cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -44
app.py CHANGED
@@ -6,20 +6,27 @@ import torch
6
  import os
7
  import uvicorn
8
 
 
9
  app = FastAPI()
10
 
 
11
  HF_API_KEY = os.getenv("HF_API_KEY")
12
  if not HF_API_KEY:
13
  raise ValueError("❌ Thiếu HF_API_KEY!")
14
 
 
15
  client = InferenceClient(token=HF_API_KEY)
 
 
16
  TEMP_DIR = "temp"
17
  os.makedirs(TEMP_DIR, exist_ok=True)
18
 
 
19
  STT_MODEL = "openai/whisper-tiny.en"
20
  TTS_MODEL = "facebook/mms-tts-eng"
21
  LLAMA_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
22
 
 
23
  try:
24
  llama_pipeline = pipeline(
25
  "text-generation",
@@ -32,13 +39,14 @@ except Exception as e:
32
  print(f"❌ Lỗi tải LLaMA: {e}")
33
  raise
34
 
 
35
  HTML_CONTENT = """
36
  <!DOCTYPE html>
37
  <html lang="vi">
38
  <head>
39
  <meta charset="UTF-8">
40
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
41
- <title>Chatbot TTS, STT & LLaMA 3</title>
42
  <style>
43
  body { font-family: Arial, sans-serif; margin: 0; padding: 20px; background: #f0f0f0; }
44
  .chat-container { max-width: 600px; margin: auto; }
@@ -52,14 +60,13 @@ HTML_CONTENT = """
52
  </head>
53
  <body>
54
  <div class="chat-container">
55
- <h1>Chatbot TTS, STT & LLaMA 3</h1>
56
  <div class="chat-box" id="chatBox"></div>
57
  <div class="input-area">
58
- <input type="text" id="textInput" placeholder="Nhập văn bản hoặc hỏi LLaMA">
59
- <button onclick="sendText()">Gửi TTS</button>
60
- <button onclick="askLlama()">Hỏi LLaMA</button>
61
- <button id="recordButton" onclick="startRecording()">Bắt đầu ghi âm</button>
62
- <button id="stopButton" onclick="stopRecording()" disabled>Dừng ghi âm</button>
63
  </div>
64
  <audio id="audioPlayer" controls style="display: none;"></audio>
65
  </div>
@@ -68,47 +75,55 @@ HTML_CONTENT = """
68
  let mediaRecorder;
69
  let audioChunks = [];
70
 
71
- async function sendText() {
 
72
  const text = document.getElementById('textInput').value;
73
  if (!text) return;
74
  addMessage('Bạn: ' + text);
75
- const response = await fetch('/tts', {
76
- method: 'POST',
77
- headers: { 'Content-Type': 'application/json' },
78
- body: JSON.stringify({ text: text })
79
- });
80
- const blob = await response.blob();
81
- const url = URL.createObjectURL(blob);
82
- const audio = document.getElementById('audioPlayer');
83
- audio.src = url;
84
- audio.style.display = 'block';
85
- audio.play();
86
- addMessage('Bot: Đã tạo âm thanh!');
87
- }
88
 
89
- async function askLlama() {
90
- const text = document.getElementById('textInput').value;
91
- if (!text) return;
92
- addMessage('Bạn: ' + text);
93
- const response = await fetch('/llama', {
94
- method: 'POST',
95
- headers: { 'Content-Type': 'application/json' },
96
- body: JSON.stringify({ prompt: text })
97
- });
98
- const data = await response.json();
99
- if (data.text) {
100
- addMessage('LLaMA: ' + data.text);
101
- } else {
102
- addMessage('LLaMA: Lỗi - ' + (data.error || 'Không có phản hồi'));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  }
104
  }
105
 
 
106
  async function startRecording() {
107
  const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
108
  mediaRecorder = new MediaRecorder(stream);
109
  audioChunks = [];
110
  mediaRecorder.ondataavailable = event => audioChunks.push(event.data);
111
- mediaRecorder.onstop = sendAudio;
112
  mediaRecorder.start();
113
  document.getElementById('recordButton').disabled = true;
114
  document.getElementById('stopButton').disabled = false;
@@ -121,17 +136,51 @@ HTML_CONTENT = """
121
  document.getElementById('stopButton').disabled = true;
122
  }
123
 
124
- async function sendAudio() {
125
  const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
126
  const formData = new FormData();
127
  formData.append('file', audioBlob, 'recording.wav');
128
- const response = await fetch('/stt', {
129
- method: 'POST',
130
- body: formData
131
- });
132
- const data = await response.json();
133
- if (data.text) addMessage('Bot: ' + data.text);
134
- else addMessage('Bot: Lỗi - ' + data.error);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  }
136
 
137
  function addMessage(message) {
@@ -187,6 +236,47 @@ async def generate_text(prompt: str):
187
  print(f"❌ Lỗi LLaMA: {e}")
188
  return {"error": str(e)}
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  if __name__ == "__main__":
191
  print("🚀 Khởi động FastAPI Server...")
192
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
6
  import os
7
  import uvicorn
8
 
9
+ # Khởi tạo FastAPI
10
  app = FastAPI()
11
 
12
+ # Lấy API key từ biến môi trường
13
  HF_API_KEY = os.getenv("HF_API_KEY")
14
  if not HF_API_KEY:
15
  raise ValueError("❌ Thiếu HF_API_KEY!")
16
 
17
+ # Khởi tạo Hugging Face Client cho TTS/STT
18
  client = InferenceClient(token=HF_API_KEY)
19
+
20
+ # Tạo thư mục lưu file tạm
21
  TEMP_DIR = "temp"
22
  os.makedirs(TEMP_DIR, exist_ok=True)
23
 
24
+ # Mô hình TTS, STT và LLaMA
25
  STT_MODEL = "openai/whisper-tiny.en"
26
  TTS_MODEL = "facebook/mms-tts-eng"
27
  LLAMA_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
28
 
29
+ # Tải pipeline LLaMA 3
30
  try:
31
  llama_pipeline = pipeline(
32
  "text-generation",
 
39
  print(f"❌ Lỗi tải LLaMA: {e}")
40
  raise
41
 
42
+ # Giao diện HTML
43
  HTML_CONTENT = """
44
  <!DOCTYPE html>
45
  <html lang="vi">
46
  <head>
47
  <meta charset="UTF-8">
48
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
49
+ <title>Chatbot Tự Động</title>
50
  <style>
51
  body { font-family: Arial, sans-serif; margin: 0; padding: 20px; background: #f0f0f0; }
52
  .chat-container { max-width: 600px; margin: auto; }
 
60
  </head>
61
  <body>
62
  <div class="chat-container">
63
+ <h1>Chatbot Tự Động</h1>
64
  <div class="chat-box" id="chatBox"></div>
65
  <div class="input-area">
66
+ <input type="text" id="textInput" placeholder="Nhập câu hỏi hoặc nhấn Enter" onkeypress="if(event.key === 'Enter') sendChat()">
67
+ <button onclick="sendChat()">Gửi</button>
68
+ <button id="recordButton" onclick="startRecording()">Ghi âm</button>
69
+ <button id="stopButton" onclick="stopRecording()" disabled>Dừng</button>
 
70
  </div>
71
  <audio id="audioPlayer" controls style="display: none;"></audio>
72
  </div>
 
75
  let mediaRecorder;
76
  let audioChunks = [];
77
 
78
+ // Hàm gửi chat (văn bản -> LLaMA -> TTS)
79
+ async function sendChat() {
80
  const text = document.getElementById('textInput').value;
81
  if (!text) return;
82
  addMessage('Bạn: ' + text);
83
+ document.getElementById('textInput').value = ''; // Xóa input sau khi gửi
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ try {
86
+ // Gửi tới endpoint /chat (tích hợp LLaMA và TTS)
87
+ const response = await fetch('/chat', {
88
+ method: 'POST',
89
+ headers: { 'Content-Type': 'application/json' },
90
+ body: JSON.stringify({ prompt: text })
91
+ });
92
+
93
+ if (response.ok) {
94
+ const blob = await response.blob();
95
+ const url = URL.createObjectURL(blob);
96
+ const audio = document.getElementById('audioPlayer');
97
+ audio.src = url;
98
+ audio.style.display = 'block';
99
+ audio.play();
100
+
101
+ // Lấy văn bản từ LLaMA để hiển thị
102
+ const textResponse = await fetch('/llama', {
103
+ method: 'POST',
104
+ headers: { 'Content-Type': 'application/json' },
105
+ body: JSON.stringify({ prompt: text })
106
+ });
107
+ const textData = await textResponse.json();
108
+ if (textData.text) {
109
+ addMessage('Bot: ' + textData.text);
110
+ }
111
+ } else {
112
+ const errorData = await response.json();
113
+ addMessage('Bot: Lỗi - ' + (errorData.error || 'Không có phản hồi'));
114
+ }
115
+ } catch (e) {
116
+ addMessage('Bot: Lỗi kết nối - ' + e.message);
117
  }
118
  }
119
 
120
+ // Ghi âm (STT -> LLaMA -> TTS)
121
  async function startRecording() {
122
  const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
123
  mediaRecorder = new MediaRecorder(stream);
124
  audioChunks = [];
125
  mediaRecorder.ondataavailable = event => audioChunks.push(event.data);
126
+ mediaRecorder.onstop = processAudio;
127
  mediaRecorder.start();
128
  document.getElementById('recordButton').disabled = true;
129
  document.getElementById('stopButton').disabled = false;
 
136
  document.getElementById('stopButton').disabled = true;
137
  }
138
 
139
+ async function processAudio() {
140
  const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
141
  const formData = new FormData();
142
  formData.append('file', audioBlob, 'recording.wav');
143
+
144
+ try {
145
+ // Gửi tới endpoint /audio_chat (STT -> LLaMA -> TTS)
146
+ const response = await fetch('/audio_chat', {
147
+ method: 'POST',
148
+ body: formData
149
+ });
150
+
151
+ if (response.ok) {
152
+ const blob = await response.blob();
153
+ const url = URL.createObjectURL(blob);
154
+ const audio = document.getElementById('audioPlayer');
155
+ audio.src = url;
156
+ audio.style.display = 'block';
157
+ audio.play();
158
+
159
+ // Lấy văn bản STT và LLaMA để hiển thị
160
+ const sttResponse = await fetch('/stt', {
161
+ method: 'POST',
162
+ body: formData
163
+ });
164
+ const sttData = await sttResponse.json();
165
+ if (sttData.text) {
166
+ addMessage('Bạn: ' + sttData.text);
167
+ const llamaResponse = await fetch('/llama', {
168
+ method: 'POST',
169
+ headers: { 'Content-Type': 'application/json' },
170
+ body: JSON.stringify({ prompt: sttData.text })
171
+ });
172
+ const llamaData = await llamaResponse.json();
173
+ if (llamaData.text) {
174
+ addMessage('Bot: ' + llamaData.text);
175
+ }
176
+ }
177
+ } else {
178
+ const errorData = await response.json();
179
+ addMessage('Bot: Lỗi - ' + (errorData.error || 'Không có phản hồi'));
180
+ }
181
+ } catch (e) {
182
+ addMessage('Bot: Lỗi kết nối - ' + e.message);
183
+ }
184
  }
185
 
186
  function addMessage(message) {
 
236
  print(f"❌ Lỗi LLaMA: {e}")
237
  return {"error": str(e)}
238
 
239
# Integrated endpoint: text -> LLaMA -> TTS.
# These names are needed when the signature below is evaluated, so they are
# imported right here rather than inside the function body.
from typing import Optional

from fastapi import Body


@app.post("/chat")
async def chat(
    prompt: Optional[str] = None,
    payload: Optional[dict] = Body(default=None),
):
    """Generate a LLaMA reply for a prompt and return it as synthesized speech.

    The prompt can be supplied in either of two ways:

    * as the ``prompt`` query parameter (the original calling convention), or
    * as a JSON body of the form ``{"prompt": "..."}``, which is what the
      page's ``sendChat()`` script posts.

    Args:
        prompt: Prompt text taken from the query string, if present.
        payload: Parsed JSON request body, if one was sent.

    Returns:
        An ``audio/wav`` response carrying the TTS audio on success. On
        failure, a JSON ``{"error": ...}`` body with a 4xx/5xx status so the
        browser's ``response.ok`` check can tell audio from an error.
    """
    # Local imports keep this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.responses import JSONResponse, Response

    # Fall back to the JSON body when no query parameter was given.
    if not prompt and isinstance(payload, dict):
        prompt = payload.get("prompt")
    if not isinstance(prompt, str) or not prompt.strip():
        return JSONResponse(status_code=400, content={"error": "❌ Thiếu prompt!"})

    try:
        # NOTE(review): this call is synchronous and will block the event loop
        # while the model generates - consider moving it to a worker thread.
        # NOTE(review): text-generation pipelines normally return the prompt
        # followed by the completion; pass return_full_text=False if only the
        # reply should be spoken - confirm the desired behavior.
        llama_output = llama_pipeline(prompt, max_new_tokens=100)[0]["generated_text"]
        print(f"LLaMA output: {llama_output}")

        # Synthesize speech from the LLaMA output.
        audio = client.text_to_speech(model=TTS_MODEL, text=llama_output)
    except Exception as e:
        print(f"❌ Lỗi chat: {e}")
        # A non-2xx status is required here: returning a bare dict would be
        # sent as HTTP 200 and the client would try to play the JSON as audio.
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Stream the bytes straight from memory instead of going through a shared
    # temp/output.wav file, which concurrent requests would overwrite.
    return Response(
        content=audio,
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="output.wav"'},
    )
256
+
257
# Integrated endpoint: STT -> LLaMA -> TTS.
@app.post("/audio_chat")
async def audio_chat(file: UploadFile = File(...)):
    """Transcribe an uploaded recording, answer it with LLaMA, and speak the answer.

    Args:
        file: The uploaded audio recording (multipart form field ``file``).

    Returns:
        An ``audio/wav`` response carrying the TTS audio on success. On
        failure, a JSON ``{"error": ...}`` body with a 4xx/5xx status so the
        browser's ``response.ok`` check can tell audio from an error.
    """
    # Local imports keep this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.responses import JSONResponse, Response

    try:
        # STT: convert speech to text. The audio payload is passed
        # positionally because the client method does not take a ``data``
        # keyword.
        audio_data = await file.read()
        stt_result = client.automatic_speech_recognition(audio_data, model=STT_MODEL)
        # Depending on the huggingface_hub version the result is either a
        # plain string or a dict-like object with a "text" entry.
        if isinstance(stt_result, str):
            stt_output = stt_result
        else:
            stt_output = stt_result.get("text", "")
        print(f"STT output: {stt_output}")

        # Nothing was recognized: do not feed an empty prompt to the model.
        if not stt_output or not stt_output.strip():
            return JSONResponse(
                status_code=400,
                content={"error": "❌ Không nhận dạng được giọng nói!"},
            )

        # LLaMA: generate the answer.
        # NOTE(review): this call is synchronous and will block the event loop
        # while the model generates - consider moving it to a worker thread.
        llama_output = llama_pipeline(stt_output, max_new_tokens=100)[0]["generated_text"]
        print(f"LLaMA output: {llama_output}")

        # TTS: convert the answer to audio.
        audio = client.text_to_speech(model=TTS_MODEL, text=llama_output)
    except Exception as e:
        print(f"❌ Lỗi audio_chat: {e}")
        # A non-2xx status is required here: returning a bare dict would be
        # sent as HTTP 200 and the client would try to play the JSON as audio.
        return JSONResponse(status_code=500, content={"error": str(e)})

    # Stream the bytes straight from memory instead of going through a shared
    # temp/output.wav file, which concurrent requests would overwrite.
    return Response(
        content=audio,
        media_type="audio/wav",
        headers={"Content-Disposition": 'attachment; filename="output.wav"'},
    )
279
+
280
  if __name__ == "__main__":
281
  print("🚀 Khởi động FastAPI Server...")
282
  uvicorn.run(app, host="0.0.0.0", port=7860)