Devakumar868 committed on
Commit
c78b630
·
verified ·
1 Parent(s): 6e6580b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -82
app.py CHANGED
@@ -12,10 +12,15 @@ import time
12
  from datetime import datetime
13
  import os
14
  import warnings
15
- from datasets import load_dataset
16
 
17
- # Import Dia TTS model
18
- from dia.model import Dia
 
 
 
 
 
 
19
 
20
  warnings.filterwarnings("ignore")
21
 
@@ -60,18 +65,36 @@ class MayaAI:
60
  )
61
  print("βœ… Emotion recognition loaded")
62
 
63
- # Load Dia TTS Model (The REAL Dia from Nari Labs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  try:
65
- self.dia_model = Dia.from_pretrained(
66
- "nari-labs/Dia-1.6B",
67
- compute_dtype="float16" if self.device == "cuda" else "float32"
68
- )[11][13][15]
69
- print("βœ… Dia TTS loaded successfully from Nari Labs")
70
- self.use_dia = True
71
- except Exception as e:
72
- print(f"⚠️ Dia loading failed: {e}")
73
- # Fallback to SpeechT5 with FIXED dtypes
74
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
 
75
 
76
  self.tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
77
  self.tts_model = SpeechT5ForTextToSpeech.from_pretrained(
@@ -90,12 +113,9 @@ class MayaAI:
90
  dtype=torch.float32
91
  ).unsqueeze(0).to(self.device)
92
  print("βœ… SpeechT5 TTS loaded as fallback")
93
- self.use_dia = False
94
-
95
- # Conversation storage
96
- self.conversations = {}
97
- self.call_active = False
98
-
99
  def transcribe_with_whisper(self, audio_path):
100
  """Transcribe using Whisper with FORCED English"""
101
  try:
@@ -189,7 +209,7 @@ class MayaAI:
189
  with torch.no_grad():
190
  outputs = self.llm_model.generate(
191
  input_ids=inputs.input_ids,
192
- attention_mask=inputs.attention_mask, # FIX: Explicit attention mask
193
  max_new_tokens=80,
194
  temperature=0.7,
195
  do_sample=True,
@@ -213,84 +233,92 @@ class MayaAI:
213
  return f"{emotion_prompts.get(emotion, 'I understand.')} Could you tell me more about that?"
214
 
215
  def synthesize_with_dia(self, text, emotion):
216
- """Generate natural emotional speech using Dia TTS"""[11][13][15]
217
  try:
218
  if not text or len(text.strip()) == 0:
219
  return None
220
 
221
  if self.use_dia:
222
- # Use Dia TTS with proper speaker tags and emotional context
223
- # Add emotional markers based on Dia's supported non-verbal tags
 
 
224
  if emotion == "happy":
225
- emotional_text = f"[S1] {text} (laughs)"[11][15]
226
  elif emotion == "sad":
227
- emotional_text = f"[S1] {text} (sighs)"[11][15]
228
  elif emotion == "excited":
229
- emotional_text = f"[S1] {text}!"
230
  elif emotion == "angry":
231
- emotional_text = f"[S1] {text} (clears throat)"[11][15]
232
  elif emotion == "surprised":
233
- emotional_text = f"[S1] {text} (gasps)"[11][15]
234
  else:
235
- emotional_text = f"[S1] {text}"[11][15]
236
 
237
- # Add natural breathing for longer text (Dia feature)
238
- if len(emotional_text.split()) > 15:
239
- words = emotional_text.split()
240
- mid_point = len(words) // 2
241
- emotional_text = " ".join(words[:mid_point]) + " (inhales) " + " ".join(words[mid_point:])
242
-
243
- # Generate using Dia model
244
  output = self.dia_model.generate(
245
  emotional_text,
246
  use_torch_compile=True if self.device == "cuda" else False,
247
  verbose=False
248
- )[11][18]
 
 
 
249
 
250
  return output
251
  else:
252
- # Use SpeechT5 fallback with emotional context
253
- clean_text = text.replace("[", "").replace("]", "").strip()
254
- if len(clean_text) > 200:
255
- clean_text = clean_text[:200] + "..."
256
-
257
- # Add emotional inflection through punctuation
258
- if emotion == "happy":
259
- clean_text = clean_text.replace(".", "!")
260
- elif emotion == "excited":
261
- clean_text = clean_text + "!"
262
- elif emotion == "sad":
263
- clean_text = clean_text.replace("!", ".")
264
-
265
- inputs = self.tts_processor(text=clean_text, return_tensors="pt")
266
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
267
-
268
- with torch.no_grad():
269
- speech = self.tts_model.generate_speech(
270
- inputs["input_ids"],
271
- self.speaker_embeddings,
272
- vocoder=self.vocoder
273
- )
274
-
275
- if isinstance(speech, torch.Tensor):
276
- speech = speech.cpu().numpy().astype(np.float32)
277
-
278
- return speech
 
 
 
 
 
 
 
279
 
 
280
  except Exception as e:
281
- print(f"TTS error: {e}")
282
  return None
283
 
284
  def start_call(self):
285
  """Start a new call session"""
286
  self.call_active = True
 
287
  greeting = "Hello! I'm Maya, your AI conversation partner. I'm here to chat with you naturally and understand your emotions. How are you feeling today?"
288
 
289
  greeting_audio = self.synthesize_with_dia(greeting, "happy")
290
 
291
- # Dia outputs at 44100 Hz sample rate
292
- sample_rate = 44100 if self.use_dia else 22050
293
- return greeting, (sample_rate, greeting_audio) if greeting_audio is not None else None, "πŸ“ž Call started! Maya is greeting you..."
294
 
295
  def end_call(self, user_id="default"):
296
  """End call and clear conversation"""
@@ -301,7 +329,7 @@ class MayaAI:
301
  farewell = "Thank you for chatting with me! It was wonderful talking with you. Have a great day!"
302
  farewell_audio = self.synthesize_with_dia(farewell, "happy")
303
 
304
- sample_rate = 44100 if self.use_dia else 22050
305
  return farewell, (sample_rate, farewell_audio) if farewell_audio is not None else None, "πŸ“ž Call ended. Conversation cleared!"
306
 
307
  def process_conversation(self, audio_input, user_id="default"):
@@ -329,7 +357,7 @@ class MayaAI:
329
  transcription, emotion, self.conversations[user_id]
330
  )
331
 
332
- # Step 4: Dia TTS with natural emotional speech
333
  response_audio = self.synthesize_with_dia(response_text, emotion)
334
 
335
  # Step 5: Update conversation history
@@ -344,13 +372,13 @@ class MayaAI:
344
 
345
  self.conversations[user_id].append(conversation_entry)
346
 
347
- # Keep last 1000 exchanges as specified
348
  if len(self.conversations[user_id]) > 1000:
349
  self.conversations[user_id] = self.conversations[user_id][-1000:]
350
 
351
  history = self.format_conversation_history(user_id)
352
 
353
- sample_rate = 44100 if self.use_dia else 22050
354
  return transcription, (sample_rate, response_audio) if response_audio is not None else None, history
355
 
356
  except Exception as e:
@@ -372,9 +400,9 @@ class MayaAI:
372
  return "\n".join(history)
373
 
374
  # Initialize Maya AI
375
- print("πŸš€ Starting Maya AI with Dia TTS...")
376
  maya = MayaAI()
377
- print("βœ… Maya AI ready with natural emotional speech!")
378
 
379
  # Gradio Interface Functions
380
  def start_call_handler():
@@ -386,17 +414,17 @@ def end_call_handler():
386
  def process_audio_handler(audio):
387
  return maya.process_conversation(audio)
388
 
389
- # Create Gradio Interface
390
  with gr.Blocks(
391
- title="Maya AI - Dia TTS Sesame Killer",
392
  theme=gr.themes.Soft()
393
  ) as demo:
394
 
395
  gr.Markdown("""
396
- # 🎀 Maya AI - Dia TTS Sesame Killer
397
- *Powered by Nari Labs Dia TTS: Ultra-realistic dialogue with natural breathing, laughter, and emotional speech*
398
 
399
- **Features:** βœ… Dia Natural TTS βœ… English-only ASR βœ… Emotion Recognition βœ… FREE Models βœ… Human-like Speech with Non-verbals
400
  """)
401
 
402
  with gr.Row():
@@ -416,7 +444,7 @@ with gr.Blocks(
416
  process_btn = gr.Button("🎯 Process Audio", variant="primary")
417
 
418
  with gr.Column(scale=2):
419
- gr.Markdown("### πŸ’¬ Natural Dia Conversation")
420
 
421
  transcription_output = gr.Textbox(
422
  label="πŸ“ What you said (English)",
@@ -425,13 +453,13 @@ with gr.Blocks(
425
  )
426
 
427
  audio_output = gr.Audio(
428
- label="πŸ”Š Maya's Dia Response (Natural with Breathing & Emotions)",
429
  interactive=False,
430
  autoplay=True
431
  )
432
 
433
  conversation_display = gr.Textbox(
434
- label="πŸ’­ Live Conversation (FREE & Natural Dia TTS)",
435
  lines=15,
436
  interactive=False,
437
  show_copy_button=True
 
12
  from datetime import datetime
13
  import os
14
  import warnings
 
15
 
16
+ # Import Dia model correctly[2]
17
+ try:
18
+ from dia.model import Dia
19
+ DIA_AVAILABLE = True
20
+ print("βœ… Dia model imported successfully")
21
+ except ImportError as e:
22
+ print(f"⚠️ Dia import failed: {e}")
23
+ DIA_AVAILABLE = False
24
 
25
  warnings.filterwarnings("ignore")
26
 
 
65
  )
66
  print("βœ… Emotion recognition loaded")
67
 
68
+ # Load REAL Dia TTS Model[2]
69
+ if DIA_AVAILABLE:
70
+ try:
71
+ # Load Dia model with correct parameters[2]
72
+ self.dia_model = Dia.from_pretrained(
73
+ "nari-labs/Dia-1.6B",
74
+ compute_dtype="float16" if self.device == "cuda" else "float32",
75
+ device=self.device
76
+ )
77
+ print("βœ… Dia TTS loaded (Ultra-realistic dialogue generation)")
78
+ self.use_dia = True
79
+ except Exception as e:
80
+ print(f"⚠️ Dia loading failed: {e}")
81
+ self.use_dia = False
82
+ self._load_fallback_tts()
83
+ else:
84
+ print("⚠️ Dia not available, using fallback TTS")
85
+ self.use_dia = False
86
+ self._load_fallback_tts()
87
+
88
+ # Conversation storage
89
+ self.conversations = {}
90
+ self.call_active = False
91
+ self.speaker_turn = 1 # Track speaker turns for Dia[2]
92
+
93
+ def _load_fallback_tts(self):
94
+ """Load fallback TTS if Dia is not available"""
95
  try:
 
 
 
 
 
 
 
 
 
96
  from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
97
+ from datasets import load_dataset
98
 
99
  self.tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
100
  self.tts_model = SpeechT5ForTextToSpeech.from_pretrained(
 
113
  dtype=torch.float32
114
  ).unsqueeze(0).to(self.device)
115
  print("βœ… SpeechT5 TTS loaded as fallback")
116
+ except Exception as e:
117
+ print(f"❌ Fallback TTS loading failed: {e}")
118
+
 
 
 
119
  def transcribe_with_whisper(self, audio_path):
120
  """Transcribe using Whisper with FORCED English"""
121
  try:
 
209
  with torch.no_grad():
210
  outputs = self.llm_model.generate(
211
  input_ids=inputs.input_ids,
212
+ attention_mask=inputs.attention_mask,
213
  max_new_tokens=80,
214
  temperature=0.7,
215
  do_sample=True,
 
233
  return f"{emotion_prompts.get(emotion, 'I understand.')} Could you tell me more about that?"
234
 
235
  def synthesize_with_dia(self, text, emotion):
236
+ """Generate ultra-realistic dialogue using Dia[2]"""
237
  try:
238
  if not text or len(text.strip()) == 0:
239
  return None
240
 
241
  if self.use_dia:
242
+ # Format text for Dia with proper speaker tags[2]
243
+ speaker_tag = f"[S{self.speaker_turn}]"
244
+
245
+ # Add emotional non-verbals based on emotion[2]
246
  if emotion == "happy":
247
+ emotional_text = f"{speaker_tag} {text} (laughs)"
248
  elif emotion == "sad":
249
+ emotional_text = f"{speaker_tag} {text} (sighs)"
250
  elif emotion == "excited":
251
+ emotional_text = f"{speaker_tag} {text}!"
252
  elif emotion == "angry":
253
+ emotional_text = f"{speaker_tag} {text} (frustrated tone)"
254
  elif emotion == "surprised":
255
+ emotional_text = f"{speaker_tag} {text} (gasps)"
256
  else:
257
+ emotional_text = f"{speaker_tag} {text}"
258
 
259
+ # Generate with Dia[2]
 
 
 
 
 
 
260
  output = self.dia_model.generate(
261
  emotional_text,
262
  use_torch_compile=True if self.device == "cuda" else False,
263
  verbose=False
264
+ )
265
+
266
+ # Toggle speaker for next turn[2]
267
+ self.speaker_turn = 2 if self.speaker_turn == 1 else 1
268
 
269
  return output
270
  else:
271
+ # Fallback to SpeechT5
272
+ return self._synthesize_with_fallback(text, emotion)
273
+
274
+ except Exception as e:
275
+ print(f"Dia TTS error: {e}")
276
+ return self._synthesize_with_fallback(text, emotion)
277
+
278
+ def _synthesize_with_fallback(self, text, emotion):
279
+ """Fallback TTS synthesis"""
280
+ try:
281
+ clean_text = text.replace("[", "").replace("]", "").strip()
282
+ if len(clean_text) > 200:
283
+ clean_text = clean_text[:200] + "..."
284
+
285
+ # Add emotional inflection through punctuation
286
+ if emotion == "happy":
287
+ clean_text = clean_text.replace(".", "!")
288
+ elif emotion == "excited":
289
+ clean_text = clean_text + "!"
290
+ elif emotion == "sad":
291
+ clean_text = clean_text.replace("!", ".")
292
+
293
+ inputs = self.tts_processor(text=clean_text, return_tensors="pt")
294
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
295
+
296
+ with torch.no_grad():
297
+ speech = self.tts_model.generate_speech(
298
+ inputs["input_ids"],
299
+ self.speaker_embeddings,
300
+ vocoder=self.vocoder
301
+ )
302
+
303
+ if isinstance(speech, torch.Tensor):
304
+ speech = speech.cpu().numpy().astype(np.float32)
305
 
306
+ return speech
307
  except Exception as e:
308
+ print(f"Fallback TTS error: {e}")
309
  return None
310
 
311
  def start_call(self):
312
  """Start a new call session"""
313
  self.call_active = True
314
+ self.speaker_turn = 1 # Reset speaker turn[2]
315
  greeting = "Hello! I'm Maya, your AI conversation partner. I'm here to chat with you naturally and understand your emotions. How are you feeling today?"
316
 
317
  greeting_audio = self.synthesize_with_dia(greeting, "happy")
318
 
319
+ # Dia outputs at 24kHz, fallback at 22050Hz[2]
320
+ sample_rate = 24000 if self.use_dia else 22050
321
+ return greeting, (sample_rate, greeting_audio) if greeting_audio is not None else None, "πŸ“ž Call started! Maya is greeting you with ultra-realistic speech..."
322
 
323
  def end_call(self, user_id="default"):
324
  """End call and clear conversation"""
 
329
  farewell = "Thank you for chatting with me! It was wonderful talking with you. Have a great day!"
330
  farewell_audio = self.synthesize_with_dia(farewell, "happy")
331
 
332
+ sample_rate = 24000 if self.use_dia else 22050
333
  return farewell, (sample_rate, farewell_audio) if farewell_audio is not None else None, "πŸ“ž Call ended. Conversation cleared!"
334
 
335
  def process_conversation(self, audio_input, user_id="default"):
 
357
  transcription, emotion, self.conversations[user_id]
358
  )
359
 
360
+ # Step 4: Ultra-realistic TTS with Dia[2]
361
  response_audio = self.synthesize_with_dia(response_text, emotion)
362
 
363
  # Step 5: Update conversation history
 
372
 
373
  self.conversations[user_id].append(conversation_entry)
374
 
375
+ # Keep last 1000 exchanges as requested[5]
376
  if len(self.conversations[user_id]) > 1000:
377
  self.conversations[user_id] = self.conversations[user_id][-1000:]
378
 
379
  history = self.format_conversation_history(user_id)
380
 
381
+ sample_rate = 24000 if self.use_dia else 22050
382
  return transcription, (sample_rate, response_audio) if response_audio is not None else None, history
383
 
384
  except Exception as e:
 
400
  return "\n".join(history)
401
 
402
  # Initialize Maya AI
403
+ print("πŸš€ Starting Maya AI with REAL Dia TTS...")
404
  maya = MayaAI()
405
+ print("βœ… Maya AI ready with ultra-realistic dialogue generation!")
406
 
407
  # Gradio Interface Functions
408
  def start_call_handler():
 
414
  def process_audio_handler(audio):
415
  return maya.process_conversation(audio)
416
 
417
+ # Create Gradio Interface[7]
418
  with gr.Blocks(
419
+ title="Maya AI - Dia-Powered Sesame Killer",
420
  theme=gr.themes.Soft()
421
  ) as demo:
422
 
423
  gr.Markdown("""
424
+ # 🎀 Maya AI - Dia-Powered Sesame Killer
425
+ *Ultra-realistic dialogue generation with Dia TTS - Natural breathing, laughter, and human-like responses*
426
 
427
+ **Features:** βœ… Real Dia TTS βœ… English-only ASR βœ… Emotion Recognition βœ… FREE LLM βœ… Ultra-realistic Speech
428
  """)
429
 
430
  with gr.Row():
 
444
  process_btn = gr.Button("🎯 Process Audio", variant="primary")
445
 
446
  with gr.Column(scale=2):
447
+ gr.Markdown("### πŸ’¬ Ultra-Realistic Conversation")
448
 
449
  transcription_output = gr.Textbox(
450
  label="πŸ“ What you said (English)",
 
453
  )
454
 
455
  audio_output = gr.Audio(
456
+ label="πŸ”Š Maya's Ultra-Realistic Response (Dia TTS)",
457
  interactive=False,
458
  autoplay=True
459
  )
460
 
461
  conversation_display = gr.Textbox(
462
+ label="πŸ’­ Live Conversation (FREE & Ultra-Realistic)",
463
  lines=15,
464
  interactive=False,
465
  show_copy_button=True