anaspro committed on
Commit
cd37326
·
1 Parent(s): 9e78ee7
Files changed (2) hide show
  1. app.py +39 -1
  2. test_tts.py +40 -0
app.py CHANGED
@@ -8,6 +8,9 @@ import av
8
  import gradio as gr
9
  import spaces
10
  import torch
 
 
 
11
  from transformers import AutoModelForImageTextToText, AutoProcessor
12
  from transformers.generation.streamers import TextIteratorStreamer
13
 
@@ -154,9 +157,33 @@ def process_history(history: list[dict]) -> list[dict]:
154
  return messages
155
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  @spaces.GPU()
158
  @torch.inference_mode()
159
- def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
160
  if not validate_media_constraints(message):
161
  yield ""
162
  return
@@ -200,6 +227,16 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
200
  output += delta
201
  yield output
202
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  # Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
205
  examples = [
@@ -221,6 +258,7 @@ demo = gr.ChatInterface(
221
  additional_inputs=[
222
  gr.Textbox(label="System Prompt", value="انت ذكاء صناعي يتحدث باللهجة العراقية بس ما تستخدم فصحى ابدا"),
223
  gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
 
224
  ],
225
  title="Shako IRAQI AI",
226
  examples=examples,
 
8
  import gradio as gr
9
  import spaces
10
  import torch
11
+ from gtts import gTTS
12
+ import io
13
+ import base64
14
  from transformers import AutoModelForImageTextToText, AutoProcessor
15
  from transformers.generation.streamers import TextIteratorStreamer
16
 
 
157
  return messages
158
 
159
 
160
+ def generate_speech(text: str, lang: str = 'ar') -> tuple[str, str]:
161
+ """Generate speech from text using Google TTS and return audio file path and base64 data."""
162
+ try:
163
+ # Create TTS object
164
+ tts = gTTS(text=text, lang=lang, slow=False)
165
+
166
+ # Save to temporary file
167
+ temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
168
+ temp_audio_file.close()
169
+
170
+ tts.save(temp_audio_file.name)
171
+
172
+ # Also create base64 version for direct playback
173
+ audio_buffer = io.BytesIO()
174
+ tts.write_to_fp(audio_buffer)
175
+ audio_buffer.seek(0)
176
+ audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
177
+
178
+ return temp_audio_file.name, f"data:audio/mp3;base64,{audio_base64}"
179
+ except Exception as e:
180
+ print(f"TTS Error: {e}")
181
+ return None, None
182
+
183
+
184
  @spaces.GPU()
185
  @torch.inference_mode()
186
+ def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, enable_voice: bool = False) -> Iterator[tuple[str, str | None]]:
187
  if not validate_media_constraints(message):
188
  yield ""
189
  return
 
227
  output += delta
228
  yield output
229
 
230
+ # Generate voice if enabled
231
+ if enable_voice and output.strip():
232
+ _, audio_data = generate_speech(output.strip(), lang='ar')
233
+ if audio_data:
234
+ yield {"text": output, "audio": audio_data}
235
+ else:
236
+ yield output
237
+ else:
238
+ yield output
239
+
240
 
241
  # Examples for the chat interface (with additional inputs: system_prompt, max_new_tokens)
242
  examples = [
 
258
  additional_inputs=[
259
  gr.Textbox(label="System Prompt", value="انت ذكاء صناعي يتحدث باللهجة العراقية بس ما تستخدم فصحى ابدا"),
260
  gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
261
+ gr.Checkbox(label="Enable Voice Output", value=False),
262
  ],
263
  title="Shako IRAQI AI",
264
  examples=examples,
test_tts.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ from gtts import gTTS
4
+ import io
5
+ import base64
6
+ import tempfile
7
+
8
+ def generate_speech(text: str, lang: str = 'ar') -> tuple[str, str]:
9
+ """Generate speech from text using Google TTS and return audio file path and base64 data."""
10
+ try:
11
+ # Create TTS object
12
+ tts = gTTS(text=text, lang=lang, slow=False)
13
+
14
+ # Save to temporary file
15
+ temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
16
+ temp_audio_file.close()
17
+
18
+ tts.save(temp_audio_file.name)
19
+
20
+ # Also create base64 version for direct playback
21
+ audio_buffer = io.BytesIO()
22
+ tts.write_to_fp(audio_buffer)
23
+ audio_buffer.seek(0)
24
+ audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
25
+
26
+ return temp_audio_file.name, f"data:audio/mp3;base64,{audio_base64}"
27
+ except Exception as e:
28
+ print(f"TTS Error: {e}")
29
+ return None, None
30
+
31
if __name__ == "__main__":
    # Manual smoke test: synthesize a short Arabic phrase and report
    # whether both the file path and the inline audio data came back.
    sample_text = "مرحبا، هذا اختبار للصوت"
    result_path, result_data = generate_speech(sample_text)
    if not (result_path and result_data):
        print("TTS test failed!")
    else:
        print(f"Audio file created: {result_path}")
        print(f"Audio data length: {len(result_data)}")
        print("TTS test successful!")