Devakumar868 committed
Commit d4c442c · verified · 1 parent: c78b630

Update app.py

Files changed (1): app.py +0 -496
app.py CHANGED
@@ -1,496 +0,0 @@
- import gradio as gr
- import torch
- import numpy as np
- import librosa
- from transformers import (
-     pipeline, AutoTokenizer, AutoModelForCausalLM,
-     WhisperProcessor, WhisperForConditionalGeneration
- )
- import soundfile as sf
- import json
- import time
- from datetime import datetime
- import os
- import warnings
-
- # Import Dia model correctly
- try:
-     from dia.model import Dia
-     DIA_AVAILABLE = True
-     print("✅ Dia model imported successfully")
- except ImportError as e:
-     print(f"⚠️ Dia import failed: {e}")
-     DIA_AVAILABLE = False
-
- warnings.filterwarnings("ignore")
-
- class MayaAI:
-     def __init__(self):
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         print(f"🚀 Initializing Maya AI on {self.device}")
-
-         # Load Whisper ASR with FORCED English
-         self.asr_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
-         self.asr_model = WhisperForConditionalGeneration.from_pretrained(
-             "openai/whisper-large-v3",
-             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
-         ).to(self.device)
-
-         # FORCE English transcription
-         self.asr_model.config.forced_decoder_ids = self.asr_processor.get_decoder_prompt_ids(
-             language="english",
-             task="transcribe"
-         )
-         print("✅ Whisper ASR loaded with FORCED English")
-
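# Note (an assumption about newer transformers releases, not something this
# file relies on): Whisper's generate() also accepts the forcing directly, e.g.
#   predicted_ids = asr_model.generate(input_features, language="en", task="transcribe")
# which expresses the same intent as the forced_decoder_ids set above.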
-         # Load FREE LLM with FIXED attention mask
-         self.llm_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
-         # FIX: Set pad_token to eos_token to avoid attention mask warnings
-         if self.llm_tokenizer.pad_token is None:
-             self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
-
-         self.llm_model = AutoModelForCausalLM.from_pretrained(
-             "microsoft/DialoGPT-large",
-             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
-             device_map="auto",
-             pad_token_id=self.llm_tokenizer.eos_token_id
-         )
-         print("✅ DialoGPT-Large loaded with FIXED attention masks")
-
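# A minimal sketch of what the pad-token fix enables (the example strings are
# hypothetical, token counts illustrative): with pad_token set, batched
# tokenization yields an explicit attention_mask instead of a warning.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
tok.pad_token = tok.eos_token
batch = tok(["Hi", "Hello there"], padding=True, return_tensors="pt")
print(batch.attention_mask)  # 1 marks real tokens, 0 marks right-side padding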
-         # Load Emotion Recognition
-         self.emotion_model = pipeline(
-             "audio-classification",
-             model="superb/wav2vec2-base-superb-er",
-             device=self.device
-         )
-         print("✅ Emotion recognition loaded")
-
-         # Load REAL Dia TTS Model
-         if DIA_AVAILABLE:
-             try:
-                 # Load Dia model with correct parameters
-                 self.dia_model = Dia.from_pretrained(
-                     "nari-labs/Dia-1.6B",
-                     compute_dtype="float16" if self.device == "cuda" else "float32",
-                     device=self.device
-                 )
-                 print("✅ Dia TTS loaded (Ultra-realistic dialogue generation)")
-                 self.use_dia = True
-             except Exception as e:
-                 print(f"⚠️ Dia loading failed: {e}")
-                 self.use_dia = False
-                 self._load_fallback_tts()
-         else:
-             print("⚠️ Dia not available, using fallback TTS")
-             self.use_dia = False
-             self._load_fallback_tts()
-
-         # Conversation storage
-         self.conversations = {}
-         self.call_active = False
-         self.speaker_turn = 1  # Track speaker turns for Dia's [S1]/[S2] tags
-
-     def _load_fallback_tts(self):
-         """Load fallback TTS if Dia is not available"""
-         try:
-             from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
-             from datasets import load_dataset
-
-             self.tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
-             self.tts_model = SpeechT5ForTextToSpeech.from_pretrained(
-                 "microsoft/speecht5_tts",
-                 torch_dtype=torch.float32
-             ).to(self.device)
-             self.vocoder = SpeechT5HifiGan.from_pretrained(
-                 "microsoft/speecht5_hifigan",
-                 torch_dtype=torch.float32
-             ).to(self.device)
-
-             # Load female speaker embeddings
-             embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-             self.speaker_embeddings = torch.tensor(
-                 embeddings_dataset[7306]["xvector"],
-                 dtype=torch.float32
-             ).unsqueeze(0).to(self.device)
-             print("✅ SpeechT5 TTS loaded as fallback")
-         except Exception as e:
-             print(f"❌ Fallback TTS loading failed: {e}")
-
-     def transcribe_with_whisper(self, audio_path):
-         """Transcribe using Whisper with FORCED English"""
-         try:
-             if audio_path is None:
-                 return "No audio provided"
-
-             # Load and preprocess audio
-             audio, sr = librosa.load(audio_path, sr=16000, mono=True)
-
-             # Process with Whisper (the English forcing happens in generate(),
-             # so the feature extractor takes no language argument)
-             inputs = self.asr_processor(
-                 audio,
-                 sampling_rate=16000,
-                 return_tensors="pt"
-             )
-             # Match the model's dtype (float16 on CUDA) to avoid a dtype mismatch
-             input_features = inputs.input_features.to(self.device, dtype=self.asr_model.dtype)
-
-             with torch.no_grad():
-                 predicted_ids = self.asr_model.generate(
-                     input_features,
-                     max_new_tokens=150,
-                     do_sample=False,
-                     forced_decoder_ids=self.asr_model.config.forced_decoder_ids
-                 )
-
-             transcription = self.asr_processor.batch_decode(
-                 predicted_ids,
-                 skip_special_tokens=True
-             )[0]
-
-             return transcription.strip()
-
-         except Exception as e:
-             return f"Transcription error: {str(e)}"
-
-     def recognize_emotion_from_audio(self, audio_path):
-         """Recognize emotion using superb model"""
-         try:
-             if audio_path is None:
-                 return "neutral"
-
-             result = self.emotion_model(audio_path)
-             emotion_label = result[0]["label"].lower()
-
-             # Map abbreviated labels to human emotions
-             emotion_map = {
-                 "ang": "angry", "hap": "happy", "exc": "excited",
-                 "sad": "sad", "fru": "frustrated", "fea": "fearful",
-                 "sur": "surprised", "neu": "neutral", "dis": "disgusted"
-             }
-
-             return emotion_map.get(emotion_label, emotion_label)
-         except Exception:
-             return "neutral"
-
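# For reference, the audio-classification pipeline returns ranked label/score
# dicts; superb/wav2vec2-base-superb-er uses abbreviated IEMOCAP-style labels,
# which is what the map above expands (scores below are made up):
#   [{"label": "neu", "score": 0.71}, {"label": "hap", "score": 0.15},
#    {"label": "ang", "score": 0.09}, {"label": "sad", "score": 0.05}]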
-     def generate_with_free_llm(self, text, emotion, history):
-         """Generate response using FREE LLM with FIXED attention masks"""
-         try:
-             # Emotional context prompting
-             emotion_prompts = {
-                 "angry": "I understand you're frustrated. Let me help calm this situation.",
-                 "sad": "I can hear the sadness in your voice. I'm here to support you.",
-                 "happy": "Your joy is infectious! I love your positive energy.",
-                 "excited": "Your enthusiasm is amazing! Tell me more!",
-                 "fearful": "I sense your concern. Let's work through this together.",
-                 "surprised": "That sounds unexpected! What happened?",
-                 "neutral": "I'm listening carefully. Please continue."
-             }
-
-             emotion_context = emotion_prompts.get(emotion, "I'm here to help.")
-
-             # Build conversation context from the last two exchanges
-             context_text = ""
-             if history:
-                 for entry in history[-2:]:
-                     context_text += f"User: {entry.get('user_input', '')}\nMaya: {entry.get('ai_response', '')}\n"
-
-             prompt = f"{context_text}User: {text}\nMaya:"
-
-             # Tokenize input with PROPER attention mask
-             inputs = self.llm_tokenizer(
-                 prompt,
-                 return_tensors="pt",
-                 truncation=True,
-                 max_length=1024,
-                 padding=True,
-                 add_special_tokens=True
-             ).to(self.device)
-
-             # Generate response with PROPER attention mask
-             with torch.no_grad():
-                 outputs = self.llm_model.generate(
-                     input_ids=inputs.input_ids,
-                     attention_mask=inputs.attention_mask,
-                     max_new_tokens=80,
-                     temperature=0.7,
-                     do_sample=True,
-                     pad_token_id=self.llm_tokenizer.pad_token_id,
-                     eos_token_id=self.llm_tokenizer.eos_token_id
-                 )
-
-             # Decode only the newly generated tokens
-             response = self.llm_tokenizer.decode(
-                 outputs[0][inputs.input_ids.shape[1]:],
-                 skip_special_tokens=True
-             ).strip()
-
-             # Clean up response
-             if not response or len(response) < 5:
-                 return emotion_context
-
-             return response
-
-         except Exception:
-             # emotion_prompts may be unbound if the failure happened before its
-             # definition, so use a fixed fallback here
-             return "I understand. Could you tell me more about that?"
-
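# For reference, the prompt assembled above for a user with one prior exchange
# looks like this (content hypothetical); DialoGPT continues after the final
# "Maya:" marker:
#   User: Hi Maya
#   Maya: Hello! How are you doing today?
#   User: I'm a bit tired, honestly
#   Maya: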
-     def synthesize_with_dia(self, text, emotion):
-         """Generate ultra-realistic dialogue using Dia"""
-         try:
-             if not text or len(text.strip()) == 0:
-                 return None
-
-             if self.use_dia:
-                 # Format text for Dia with proper speaker tags
-                 speaker_tag = f"[S{self.speaker_turn}]"
-
-                 # Add emotional non-verbals based on emotion
-                 if emotion == "happy":
-                     emotional_text = f"{speaker_tag} {text} (laughs)"
-                 elif emotion == "sad":
-                     emotional_text = f"{speaker_tag} {text} (sighs)"
-                 elif emotion == "excited":
-                     emotional_text = f"{speaker_tag} {text}!"
-                 elif emotion == "angry":
-                     emotional_text = f"{speaker_tag} {text} (frustrated tone)"
-                 elif emotion == "surprised":
-                     emotional_text = f"{speaker_tag} {text} (gasps)"
-                 else:
-                     emotional_text = f"{speaker_tag} {text}"
-
-                 # Generate with Dia
-                 output = self.dia_model.generate(
-                     emotional_text,
-                     use_torch_compile=(self.device == "cuda"),
-                     verbose=False
-                 )
-
-                 # Toggle speaker for next turn
-                 self.speaker_turn = 2 if self.speaker_turn == 1 else 1
-
-                 return output
-             else:
-                 # Fallback to SpeechT5
-                 return self._synthesize_with_fallback(text, emotion)
-
-         except Exception as e:
-             print(f"Dia TTS error: {e}")
-             return self._synthesize_with_fallback(text, emotion)
-
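# For reference, Dia consumes speaker-tagged dialogue text with optional
# non-verbal cues in parentheses, e.g.:
#   "[S1] Hey, how was your day? [S2] Pretty good, actually. (laughs)"
# The speaker_turn toggle above alternates the [S1]/[S2] tags between replies.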
-     def _synthesize_with_fallback(self, text, emotion):
-         """Fallback TTS synthesis"""
-         try:
-             clean_text = text.replace("[", "").replace("]", "").strip()
-             if len(clean_text) > 200:
-                 clean_text = clean_text[:200] + "..."
-
-             # Add emotional inflection through punctuation
-             if emotion == "happy":
-                 clean_text = clean_text.replace(".", "!")
-             elif emotion == "excited":
-                 clean_text = clean_text + "!"
-             elif emotion == "sad":
-                 clean_text = clean_text.replace("!", ".")
-
-             inputs = self.tts_processor(text=clean_text, return_tensors="pt")
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 speech = self.tts_model.generate_speech(
-                     inputs["input_ids"],
-                     self.speaker_embeddings,
-                     vocoder=self.vocoder
-                 )
-
-             if isinstance(speech, torch.Tensor):
-                 speech = speech.cpu().numpy().astype(np.float32)
-
-             return speech
-         except Exception as e:
-             print(f"Fallback TTS error: {e}")
-             return None
-
-     def start_call(self):
-         """Start a new call session"""
-         self.call_active = True
-         self.speaker_turn = 1  # Reset speaker turn
-         greeting = "Hello! I'm Maya, your AI conversation partner. I'm here to chat with you naturally and understand your emotions. How are you feeling today?"
-
-         greeting_audio = self.synthesize_with_dia(greeting, "happy")
-
-         # Dia outputs at 24kHz, fallback at 22050Hz
-         sample_rate = 24000 if self.use_dia else 22050
-         return greeting, (sample_rate, greeting_audio) if greeting_audio is not None else None, "📞 Call started! Maya is greeting you with ultra-realistic speech..."
-
-     def end_call(self, user_id="default"):
-         """End call and clear conversation"""
-         self.call_active = False
-         if user_id in self.conversations:
-             self.conversations[user_id] = []
-
-         farewell = "Thank you for chatting with me! It was wonderful talking with you. Have a great day!"
-         farewell_audio = self.synthesize_with_dia(farewell, "happy")
-
-         sample_rate = 24000 if self.use_dia else 22050
-         return farewell, (sample_rate, farewell_audio) if farewell_audio is not None else None, "📞 Call ended. Conversation cleared!"
-
-     def process_conversation(self, audio_input, user_id="default"):
-         """Main conversation processing pipeline"""
-         if not self.call_active:
-             return "Please start a call first by clicking the 'Start Call' button", None, "No active call"
-
-         if audio_input is None:
-             return "Please record some audio", None, "No audio input"
-
-         start_time = time.time()
-
-         if user_id not in self.conversations:
-             self.conversations[user_id] = []
-
-         try:
-             # Step 1: ASR with FORCED English
-             transcription = self.transcribe_with_whisper(audio_input)
-
-             # Step 2: Emotion recognition
-             emotion = self.recognize_emotion_from_audio(audio_input)
-
-             # Step 3: FREE LLM generation with FIXED attention masks
-             response_text = self.generate_with_free_llm(
-                 transcription, emotion, self.conversations[user_id]
-             )
-
-             # Step 4: Ultra-realistic TTS with Dia
-             response_audio = self.synthesize_with_dia(response_text, emotion)
-
-             # Step 5: Update conversation history
-             processing_time = time.time() - start_time
-             conversation_entry = {
-                 "timestamp": datetime.now().strftime("%H:%M:%S"),
-                 "user_input": transcription,
-                 "user_emotion": emotion,
-                 "ai_response": response_text,
-                 "processing_time": processing_time
-             }
-
-             self.conversations[user_id].append(conversation_entry)
-
-             # Keep last 1000 exchanges as requested
-             if len(self.conversations[user_id]) > 1000:
-                 self.conversations[user_id] = self.conversations[user_id][-1000:]
-
-             history = self.format_conversation_history(user_id)
-
-             sample_rate = 24000 if self.use_dia else 22050
-             return transcription, (sample_rate, response_audio) if response_audio is not None else None, history
-
-         except Exception as e:
-             return f"Processing error: {str(e)}", None, "Error in processing"
-
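# For reference, gr.Audio accepts output as a (sample_rate, numpy_array) tuple,
# which is the shape returned above; a hypothetical half-second 440 Hz stub:
import numpy as np
t = np.linspace(0, 0.5, 11025, endpoint=False)
demo_clip = (22050, np.sin(2 * np.pi * 440 * t).astype(np.float32))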
-     def format_conversation_history(self, user_id):
-         """Format conversation history for display"""
-         if user_id not in self.conversations or not self.conversations[user_id]:
-             return "No conversation history yet."
-
-         history = []
-         for i, entry in enumerate(self.conversations[user_id][-10:], 1):
-             history.append(f"**Exchange {i}** ({entry['timestamp']})")
-             history.append(f"🎤 **You** ({entry['user_emotion']}): {entry['user_input']}")
-             history.append(f"🤖 **Maya**: {entry['ai_response']}")
-             history.append(f"⏱️ *{entry['processing_time']:.2f}s*")
-             history.append("---")
-
-         return "\n".join(history)
-
- # Initialize Maya AI
- print("🚀 Starting Maya AI with REAL Dia TTS...")
- maya = MayaAI()
- print("✅ Maya AI ready with ultra-realistic dialogue generation!")
-
- # Gradio Interface Functions
- def start_call_handler():
-     return maya.start_call()
-
- def end_call_handler():
-     return maya.end_call()
-
- def process_audio_handler(audio):
-     return maya.process_conversation(audio)
-
- # Create Gradio Interface
- with gr.Blocks(
-     title="Maya AI - Dia-Powered Sesame Killer",
-     theme=gr.themes.Soft()
- ) as demo:
-
-     gr.Markdown("""
-     # 🎤 Maya AI - Dia-Powered Sesame Killer
-     *Ultra-realistic dialogue generation with Dia TTS - Natural breathing, laughter, and human-like responses*
-
-     **Features:** ✅ Real Dia TTS ✅ English-only ASR ✅ Emotion Recognition ✅ FREE LLM ✅ Ultra-realistic Speech
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             gr.Markdown("### 📞 Call Controls")
-
-             start_call_btn = gr.Button("📞 Start Call", variant="primary", size="lg")
-             end_call_btn = gr.Button("📞 End Call", variant="stop", size="lg")
-
-             gr.Markdown("### 🎙️ Voice Input")
-             audio_input = gr.Audio(
-                 sources=["microphone"],
-                 type="filepath",
-                 label="Record your message in English"
-             )
-
-             process_btn = gr.Button("🎯 Process Audio", variant="primary")
-
-         with gr.Column(scale=2):
-             gr.Markdown("### 💬 Ultra-Realistic Conversation")
-
-             transcription_output = gr.Textbox(
-                 label="📝 What you said (English)",
-                 lines=2,
-                 interactive=False
-             )
-
-             audio_output = gr.Audio(
-                 label="🔊 Maya's Ultra-Realistic Response (Dia TTS)",
-                 interactive=False,
-                 autoplay=True
-             )
-
-             conversation_display = gr.Textbox(
-                 label="💭 Live Conversation (FREE & Ultra-Realistic)",
-                 lines=15,
-                 interactive=False,
-                 show_copy_button=True
-             )
-
-     # Event Handlers
-     start_call_btn.click(
-         fn=start_call_handler,
-         outputs=[transcription_output, audio_output, conversation_display]
-     )
-
-     end_call_btn.click(
-         fn=end_call_handler,
-         outputs=[transcription_output, audio_output, conversation_display]
-     )
-
-     process_btn.click(
-         fn=process_audio_handler,
-         inputs=[audio_input],
-         outputs=[transcription_output, audio_output, conversation_display]
-     )
-
-     audio_input.stop_recording(
-         fn=process_audio_handler,
-         inputs=[audio_input],
-         outputs=[transcription_output, audio_output, conversation_display]
-     )
-
- if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True
-     )