zsolnai commited on
Commit
102e36f
Β·
1 Parent(s): 276657d

Fix claude mistake v5

Browse files
Files changed (1) hide show
  1. app.py +268 -118
app.py CHANGED
@@ -2,90 +2,59 @@ import os
2
  import tempfile
3
 
4
  import gradio as gr
5
-
6
- # Note: Added numpy/soundfile import which might be needed by TTS/Whisper internally
7
  import numpy as np
8
  import soundfile as sf
9
  import torch
 
 
10
 
11
  # --- Device Setup (Explicitly set to CPU) ---
12
  device = "cpu"
13
 
14
- # --- STT Setup (using Hugging Face's transformers pipeline for Whisper) ---
15
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
16
-
17
  STT_MODEL_NAME = "openai/whisper-tiny.en"
18
  stt_pipe = pipeline("automatic-speech-recognition", model=STT_MODEL_NAME, device=device)
19
 
20
- # --- LLM Setup (using Hugging Face's transformers for text generation) ---
21
  LLM_MODEL_NAME = "microsoft/DialoGPT-medium"
22
  chatbot_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
23
  chatbot_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME)
24
  chatbot_model.to(device)
25
 
26
- # --- TTS Setup (using coqui-ai/TTS) ---
27
- from TTS.api import TTS
28
-
29
  TTS_MODEL_NAME = "tts_models/en/ljspeech/tacotron2-DDC"
30
  tts_model = TTS(model_name=TTS_MODEL_NAME, progress_bar=False)
31
 
32
 
33
- def speech_to_text(audio_file_path):
34
- """Performs Speech-to-Text using the Whisper model."""
35
- if audio_file_path is None:
36
- return "Please upload an audio file or record your voice."
37
- try:
38
- result = stt_pipe(audio_file_path)
39
- return result["text"]
40
- except Exception as e:
41
- return f"Error during STT: {e}"
42
-
43
-
44
- def text_to_speech(text):
45
- """Performs Text-to-Speech using the Coqui TTS model."""
46
- if not text:
47
- return None, "Please enter text for synthesis."
48
- try:
49
- # Create a temporary file for each request
50
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
51
- output_path = temp_file.name
52
- temp_file.close()
53
-
54
- # Generate the speech (slow on CPU)
55
- tts_model.tts_to_file(
56
- text=text,
57
- file_path=output_path,
58
- )
59
- return output_path, "Speech synthesis complete. (Completed slowly on CPU)"
60
- except Exception as e:
61
- # Clean up temp file on failure
62
- if os.path.exists(output_path):
63
- os.remove(output_path)
64
- return None, f"Error during TTS: {e}"
65
 
66
 
67
  def chat_with_bot(message, history, chat_history_ids=None):
68
- """Chat with the conversational AI model using DialoGPT."""
 
 
 
69
  if not message or not message.strip():
70
- # If message is empty, return current history and state
71
- return history, chat_history_ids
 
72
 
73
  try:
74
- # Move inputs to CPU (required for consistent CPU-only operation)
75
  new_input_ids = chatbot_tokenizer.encode(
76
  message + chatbot_tokenizer.eos_token, return_tensors="pt"
77
  ).to(device)
78
 
79
- # Append the new user input tokens to the chat history
80
  if chat_history_ids is not None:
81
- # Ensure history is on the correct device before concatenation
82
- bot_input_ids = torch.cat(
83
- [chat_history_ids.to(device), new_input_ids], dim=-1
84
- )
85
  else:
86
  bot_input_ids = new_input_ids
87
 
88
- # Generate a response
89
  chat_history_ids = chatbot_model.generate(
90
  bot_input_ids,
91
  max_length=1000,
@@ -96,23 +65,92 @@ def chat_with_bot(message, history, chat_history_ids=None):
96
  top_p=0.95,
97
  )
98
 
99
- # Decode the response
100
  response = chatbot_tokenizer.decode(
101
- # Select the response part only
102
- chat_history_ids[:, bot_input_ids.shape[-1] :][0],
103
- skip_special_tokens=True,
104
  )
105
 
106
- # CRITICAL FIX: Append to history in the Gradio Chatbot (list of lists/tuples) format
107
  history.append((message, response))
108
 
109
- # Return the updated history for display and the new state for the next turn
110
- return history, chat_history_ids
111
 
112
  except Exception as e:
113
- # Append error message to history using the correct format
114
  history.append((message, f"Error: {e}"))
115
- return history, chat_history_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
 
118
  # --- Gradio Interface ---
@@ -125,26 +163,153 @@ custom_css = """
125
  height: 400px;
126
  }
127
  """
128
- # CRITICAL FIX: The 'css' argument must be passed to launch() (it's correct here)
 
129
  with gr.Blocks() as demo:
130
- gr.Markdown("# πŸ—£οΈ STT, TTS & Chat App (CPU Only)")
131
  gr.Markdown(
132
- "**NOTE:** This app is running on CPU-only hardware. Speech-to-Text (Whisper) is fast, but **Text-to-Speech (Coqui TTS) and Chat will be slow**."
133
  )
134
 
135
- # Hidden state to store chat history IDs (PyTorch Tensor)
136
- chat_state = gr.State(value=None)
137
 
138
- # Create tabs for different features
139
  with gr.Tabs():
140
- # Tab 1: Chat Interface
141
- with gr.TabItem("πŸ’¬ Chat"):
142
- gr.Markdown("## Chat with AI Assistant")
 
143
  gr.Markdown(
144
- "Have a conversation with the DialoGPT model. It remembers context from your conversation!"
 
 
 
 
 
145
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- # Initialized to an empty list, which Gradio's Chatbot expects
148
  chatbot = gr.Chatbot(
149
  label="Conversation", elem_classes=["chatbot"], value=[]
150
  )
@@ -153,63 +318,48 @@ with gr.Blocks() as demo:
153
  placeholder="Type your message here and press Enter...",
154
  lines=2,
155
  )
 
156
  with gr.Row():
157
  submit_btn = gr.Button("Send", variant="primary")
158
  clear_btn = gr.Button("Clear Chat")
159
 
160
- # Functionality to handle chat submission
161
- # The 'chatbot' component provides the 'history' (list of tuples)
162
- fn_call = submit_btn.click(
163
- chat_with_bot,
164
- inputs=[msg, chatbot, chat_state],
165
- outputs=[chatbot, chat_state],
166
- # Clear the message box after the main function runs
167
  ).then(lambda: "", None, msg)
168
 
169
- # Ensure msg.submit does the same thing as the submit button
170
- msg.submit(
171
- chat_with_bot,
172
- inputs=[msg, chatbot, chat_state],
173
- outputs=[chatbot, chat_state],
 
174
  ).then(lambda: "", None, msg)
175
 
176
- # Clear button resets both the displayed history and the token history state
177
- clear_btn.click(lambda: ([], None), None, [chatbot, chat_state])
178
-
179
- # Tab 2: STT
180
- with gr.TabItem("🎀 Speech-to-Text"):
181
- with gr.Row():
182
- with gr.Column():
183
- gr.Markdown("## 🎀 Speech-to-Text (STT)")
184
- audio_input = gr.Audio(
185
- sources=["microphone", "upload"],
186
- type="filepath",
187
- label="Input Audio (Mic or Upload)",
188
- )
189
- stt_button = gr.Button("Convert Speech to Text")
190
- with gr.Column():
191
- stt_output = gr.Textbox(label="Transcribed Text", lines=3)
192
-
193
- stt_button.click(fn=speech_to_text, inputs=audio_input, outputs=stt_output)
194
 
195
- # Tab 3: TTS
196
- with gr.TabItem("πŸ”Š Text-to-Speech"):
197
- with gr.Row():
198
- with gr.Column():
199
- gr.Markdown("## πŸ”Š Text-to-Speech (TTS)")
200
- text_input = gr.Textbox(
201
- label="Text to Synthesize",
202
- lines=3,
203
- value="Hello there, this is a demonstration of the text to speech model.",
204
- )
205
- tts_button = gr.Button("Synthesize Speech (Will be slow)")
206
- with gr.Column():
207
- audio_output = gr.Audio(label="Synthesized Audio")
208
- tts_status = gr.Textbox(elem_id="status", label="Status")
209
 
210
- tts_button.click(
211
- fn=text_to_speech, inputs=text_input, outputs=[audio_output, tts_status]
 
 
 
 
 
 
 
 
 
 
 
212
  )
213
 
214
- # Pass the 'css' argument to launch()
215
  demo.launch(css=custom_css)
 
2
  import tempfile
3
 
4
  import gradio as gr
 
 
5
  import numpy as np
6
  import soundfile as sf
7
  import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
9
+ from TTS.api import TTS
10
 
11
  # --- Device Setup (Explicitly set to CPU) ---
12
  device = "cpu"
13
 
14
+ # --- Model Initialization ---
15
+ # STT
 
16
  STT_MODEL_NAME = "openai/whisper-tiny.en"
17
  stt_pipe = pipeline("automatic-speech-recognition", model=STT_MODEL_NAME, device=device)
18
 
19
+ # LLM (Chatbot)
20
  LLM_MODEL_NAME = "microsoft/DialoGPT-medium"
21
  chatbot_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
22
  chatbot_model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_NAME)
23
  chatbot_model.to(device)
24
 
25
+ # TTS
 
 
26
  TTS_MODEL_NAME = "tts_models/en/ljspeech/tacotron2-DDC"
27
  tts_model = TTS(model_name=TTS_MODEL_NAME, progress_bar=False)
28
 
29
 
30
+ # --- Core Functions ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  def chat_with_bot(message, history, chat_history_ids=None):
34
+ """
35
+ Chat with the conversational AI model using DialoGPT.
36
+ Returns: (updated_history, updated_chat_ids, response_text)
37
+ """
38
  if not message or not message.strip():
39
+ # Add an empty entry to history to maintain the structure expected by Gradio
40
+ history.append(("", ""))
41
+ return history, chat_history_ids, ""
42
 
43
  try:
44
+ # 1. Encode user message and move to CPU
45
  new_input_ids = chatbot_tokenizer.encode(
46
  message + chatbot_tokenizer.eos_token, return_tensors="pt"
47
  ).to(device)
48
 
49
+ # 2. Prepare full input IDs (previous history + new message)
50
  if chat_history_ids is not None:
51
+ # Ensure history tensor is on CPU before concatenation
52
+ chat_history_ids = chat_history_ids.to(device)
53
+ bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
 
54
  else:
55
  bot_input_ids = new_input_ids
56
 
57
+ # 3. Generate response
58
  chat_history_ids = chatbot_model.generate(
59
  bot_input_ids,
60
  max_length=1000,
 
65
  top_p=0.95,
66
  )
67
 
68
+ # 4. Decode response
69
  response = chatbot_tokenizer.decode(
70
+ chat_history_ids[:, bot_input_ids.shape[-1] :][0], skip_special_tokens=True
 
 
71
  )
72
 
73
+ # CRITICAL FIX: Append to history in the Gradio Chatbot (list of tuples) format
74
  history.append((message, response))
75
 
76
+ return history, chat_history_ids, response
 
77
 
78
  except Exception as e:
79
+ # CRITICAL FIX: Append error to history in the Gradio Chatbot (list of tuples) format
80
  history.append((message, f"Error: {e}"))
81
+ return history, chat_history_ids, f"Error: {e}"
82
+
83
+
84
+ def text_to_speech_from_chat(chat_response):
85
+ """Takes the chat response and converts it to speech."""
86
+ if not chat_response or chat_response.startswith("Error"):
87
+ return None, "No valid response to synthesize."
88
+
89
+ output_path = None
90
+ try:
91
+ # Create a temporary file
92
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
93
+ output_path = temp_file.name
94
+ temp_file.close()
95
+
96
+ # Generate the speech (slow on CPU)
97
+ tts_model.tts_to_file(
98
+ text=chat_response,
99
+ file_path=output_path,
100
+ )
101
+ return output_path, "Speech synthesis complete. (Completed slowly on CPU)"
102
+
103
+ except Exception as e:
104
+ # Clean up temp file on failure
105
+ if output_path and os.path.exists(output_path):
106
+ os.remove(output_path)
107
+ return None, f"Error during TTS: {e}"
108
+
109
+
110
+ def speech_to_text_and_chat(audio_file_path, history, chat_history_ids):
111
+ """Performs STT, then Chatbot generation, returning the final response text and audio."""
112
+ if audio_file_path is None:
113
+ return (
114
+ "Please upload an audio file or record your voice.",
115
+ history,
116
+ chat_history_ids,
117
+ "",
118
+ None,
119
+ "Awaiting audio input.",
120
+ )
121
+
122
+ # 1. STT
123
+ try:
124
+ result = stt_pipe(audio_file_path)
125
+ transcribed_text = result["text"]
126
+ except Exception as e:
127
+ return (
128
+ f"Error during STT: {e}",
129
+ history,
130
+ chat_history_ids,
131
+ "",
132
+ None,
133
+ f"Error during STT: {e}",
134
+ )
135
+
136
+ # 2. Chatbot
137
+ # The third returned value, last_response_text, is the pure text response.
138
+ updated_history, updated_chat_ids, last_response_text = chat_with_bot(
139
+ transcribed_text, history, chat_history_ids
140
+ )
141
+
142
+ # 3. TTS
143
+ audio_path, status_text = text_to_speech_from_chat(last_response_text)
144
+
145
+ # Returns: transcription, history, chat_ids, response_text, audio_path, status
146
+ return (
147
+ transcribed_text,
148
+ updated_history,
149
+ updated_chat_ids,
150
+ last_response_text,
151
+ audio_path,
152
+ status_text,
153
+ )
154
 
155
 
156
  # --- Gradio Interface ---
 
163
  height: 400px;
164
  }
165
  """
166
+
167
+ # CRITICAL FIX: Removed css argument from gr.Blocks()
168
  with gr.Blocks() as demo:
169
+ gr.Markdown("# πŸ—£οΈ Integrated Voice Assistant (CPU Only)")
170
  gr.Markdown(
171
+ "**NOTE:** This app is running on CPU-only hardware. The full voice flow will be slow due to **Text-to-Speech**."
172
  )
173
 
174
+ # The global chat state can be used if tabs share history, or use local states per tab
175
+ global_chat_state = gr.State(value=None)
176
 
 
177
  with gr.Tabs():
178
+
179
+ # --- NEW FULL VOICE CHAT TAB (STT -> CHAT -> TTS) ---
180
+ with gr.TabItem("πŸ—£οΈ Voice Assistant"):
181
+ gr.Markdown("## Talk to the AI Assistant")
182
  gr.Markdown(
183
+ "Speak into the microphone. Your speech will be transcribed, sent to the chatbot, and the chatbot's text response will be converted to audio."
184
+ )
185
+
186
+ # States specific to this tab
187
+ voice_chat_history = gr.Chatbot(
188
+ label="Conversation Log", elem_classes=["chatbot"], value=[]
189
  )
190
+ voice_chat_state = gr.State(value=None) # Chat state IDs for this tab
191
+
192
+ with gr.Row():
193
+ audio_in = gr.Audio(
194
+ sources=["microphone", "upload"],
195
+ type="filepath",
196
+ label="Input Audio (Mic or Upload)",
197
+ )
198
+ voice_audio_out = gr.Audio(label="AI Voice Response", autoplay=True)
199
+
200
+ voice_transcription = gr.Textbox(label="User Transcription", lines=2)
201
+ voice_response_text = gr.Textbox(label="AI Response (Text)", lines=2)
202
+
203
+ with gr.Row():
204
+ run_btn = gr.Button("Transcribe, Chat & Speak", variant="primary")
205
+ clear_voice_btn = gr.Button("Clear Conversation")
206
+
207
+ voice_status = gr.Textbox(elem_id="status", label="Status")
208
+
209
+ # Chain the functions together
210
+ run_btn.click(
211
+ fn=speech_to_text_and_chat,
212
+ inputs=[audio_in, voice_chat_history, voice_chat_state],
213
+ outputs=[
214
+ voice_transcription,
215
+ voice_chat_history,
216
+ voice_chat_state,
217
+ voice_response_text,
218
+ voice_audio_out,
219
+ voice_status,
220
+ ],
221
+ )
222
+
223
+ clear_voice_btn.click(
224
+ lambda: (None, [], None, "", None, ""),
225
+ None,
226
+ [
227
+ audio_in,
228
+ voice_chat_history,
229
+ voice_chat_state,
230
+ voice_response_text,
231
+ voice_audio_out,
232
+ voice_status,
233
+ ],
234
+ )
235
+
236
+ # --- EXISTING CHAT -> TTS TAB ---
237
+ with gr.TabItem("πŸ’¬ Chat β†’ Voice Output"):
238
+ gr.Markdown("## πŸ’¬ Chat with Voice Output")
239
+
240
+ tts_chatbot = gr.Chatbot(
241
+ label="Conversation", elem_classes=["chatbot"], value=[]
242
+ )
243
+ tts_msg = gr.Textbox(
244
+ label="Your Message",
245
+ placeholder="Type your message here and press Enter...",
246
+ lines=2,
247
+ )
248
+ tts_chat_state = gr.State(value=None)
249
+
250
+ with gr.Row():
251
+ tts_submit_btn = gr.Button("Send & Speak", variant="primary")
252
+ tts_clear_btn = gr.Button("Clear Chat")
253
+
254
+ with gr.Row():
255
+ with gr.Column():
256
+ tts_response_text = gr.Textbox(label="AI Response (Text)", lines=3)
257
+ with gr.Column():
258
+ tts_audio_output = gr.Audio(label="AI Response (Audio)")
259
+ tts_status = gr.Textbox(elem_id="status", label="Status")
260
+
261
+ def chat_and_speak(message, history, chat_ids):
262
+ """Send message to chat and convert response to speech."""
263
+ # 1. Chatbot
264
+ updated_history, updated_ids, last_response = chat_with_bot(
265
+ message, history, chat_ids
266
+ )
267
+
268
+ # 2. TTS
269
+ audio_path, status = text_to_speech_from_chat(last_response)
270
+
271
+ return updated_history, updated_ids, last_response, audio_path, status
272
+
273
+ tts_submit_btn.click(
274
+ fn=chat_and_speak,
275
+ inputs=[tts_msg, tts_chatbot, tts_chat_state],
276
+ outputs=[
277
+ tts_chatbot,
278
+ tts_chat_state,
279
+ tts_response_text,
280
+ tts_audio_output,
281
+ tts_status,
282
+ ],
283
+ ).then(lambda: "", None, tts_msg)
284
+
285
+ tts_msg.submit(
286
+ fn=chat_and_speak,
287
+ inputs=[tts_msg, tts_chatbot, tts_chat_state],
288
+ outputs=[
289
+ tts_chatbot,
290
+ tts_chat_state,
291
+ tts_response_text,
292
+ tts_audio_output,
293
+ tts_status,
294
+ ],
295
+ ).then(lambda: "", None, tts_msg)
296
+
297
+ tts_clear_btn.click(
298
+ lambda: ([], None, "", None, "Awaiting input."),
299
+ None,
300
+ [
301
+ tts_chatbot,
302
+ tts_chat_state,
303
+ tts_response_text,
304
+ tts_audio_output,
305
+ tts_status,
306
+ ],
307
+ )
308
+
309
+ # --- EXISTING TEXT CHAT ONLY TAB ---
310
+ with gr.TabItem("πŸ’¬ Text Chat Only"):
311
+ gr.Markdown("## Chat with AI Assistant")
312
 
 
313
  chatbot = gr.Chatbot(
314
  label="Conversation", elem_classes=["chatbot"], value=[]
315
  )
 
318
  placeholder="Type your message here and press Enter...",
319
  lines=2,
320
  )
321
+
322
  with gr.Row():
323
  submit_btn = gr.Button("Send", variant="primary")
324
  clear_btn = gr.Button("Clear Chat")
325
 
326
+ # Use the global state for the text-only chat
327
+ fn_call = msg.submit(
328
+ lambda message, history, chat_state: chat_with_bot(
329
+ message, history, chat_state
330
+ )[:2],
331
+ inputs=[msg, chatbot, global_chat_state],
332
+ outputs=[chatbot, global_chat_state],
333
  ).then(lambda: "", None, msg)
334
 
335
+ submit_btn.click(
336
+ lambda message, history, chat_state: chat_with_bot(
337
+ message, history, chat_state
338
+ )[:2],
339
+ inputs=[msg, chatbot, global_chat_state],
340
+ outputs=[chatbot, global_chat_state],
341
  ).then(lambda: "", None, msg)
342
 
343
+ clear_btn.click(lambda: ([], None), None, [chatbot, global_chat_state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
+ # --- EXISTING STANDALONE TTS TAB ---
346
+ with gr.TabItem("πŸ”Š Text-to-Speech Only"):
347
+ gr.Markdown("## πŸ”Š Text-to-Speech (TTS)")
 
 
 
 
 
 
 
 
 
 
 
348
 
349
+ standalone_text_input = gr.Textbox(
350
+ label="Text to Synthesize",
351
+ lines=3,
352
+ value="Hello there, this is a demonstration of the text to speech model.",
353
+ )
354
+ standalone_tts_button = gr.Button("Synthesize Speech (Will be slow)")
355
+ standalone_audio_output = gr.Audio(label="Synthesized Audio")
356
+ standalone_tts_status = gr.Textbox(elem_id="status", label="Status")
357
+
358
+ standalone_tts_button.click(
359
+ fn=text_to_speech_from_chat,
360
+ inputs=standalone_text_input,
361
+ outputs=[standalone_audio_output, standalone_tts_status],
362
  )
363
 
364
+ # CRITICAL FIX: Passed css argument to demo.launch()
365
  demo.launch(css=custom_css)