# Hugging Face Spaces status at the time this file was captured: Runtime error
import json
import os
import time
import warnings
from datetime import datetime

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from transformers import (
    pipeline, AutoTokenizer, AutoModelForCausalLM,
    WhisperProcessor, WhisperForConditionalGeneration
)

# Dia (nari-labs) is an optional TTS dependency. Its import chain pulls in
# torch extensions and an audio codec, which can fail with errors other than
# ImportError (e.g. RuntimeError/AttributeError on version mismatches), so any
# failure here degrades to the SpeechT5 fallback instead of crashing startup.
try:
    from dia.model import Dia
    DIA_AVAILABLE = True
    print("β Dia model imported successfully")
except Exception as e:
    print(f"β οΈ Dia import failed: {e}")
    DIA_AVAILABLE = False

# NOTE(review): this silences *all* warnings process-wide, including
# deprecation warnings from transformers/gradio; consider narrowing it.
warnings.filterwarnings("ignore")
class MayaAI:
    """Voice conversation pipeline: Whisper ASR -> emotion recognition -> DialoGPT -> TTS.

    One instance owns every model plus the per-user conversation history.
    NOTE(review): the Gradio handlers in this file share a single instance and
    the "default" user_id, so concurrent browser sessions share call state and
    history.
    """

    # Native output sample rates of the two TTS back-ends. Dia decodes through
    # a 44.1 kHz audio codec; SpeechT5 + HiFi-GAN emits 16 kHz audio. Tagging
    # audio with the wrong rate makes Gradio play it pitch-shifted.
    DIA_SAMPLE_RATE = 44100
    SPEECHT5_SAMPLE_RATE = 16000

    # DialoGPT (GPT-2) position limit, and the reply budget reserved inside it.
    LLM_CONTEXT_TOKENS = 1024
    LLM_MAX_NEW_TOKENS = 80

    # Upper bound on stored exchanges per user.
    MAX_HISTORY = 1000

    # Abbreviated emotion labels -> readable names (unknown labels pass through).
    EMOTION_LABEL_MAP = {
        "ang": "angry", "hap": "happy", "exc": "excited",
        "sad": "sad", "fru": "frustrated", "fea": "fearful",
        "sur": "surprised", "neu": "neutral", "dis": "disgusted"
    }

    # Canned, emotion-aware replies used when the LLM output is unusable.
    EMOTION_PROMPTS = {
        "angry": "I understand you're frustrated. Let me help calm this situation.",
        "sad": "I can hear the sadness in your voice. I'm here to support you.",
        "happy": "Your joy is infectious! I love your positive energy.",
        "excited": "Your enthusiasm is amazing! Tell me more!",
        "fearful": "I sense your concern. Let's work through this together.",
        "surprised": "That sounds unexpected! What happened?",
        "neutral": "I'm listening carefully. Please continue."
    }

    # Suffix appended to Maya's line for Dia, keyed by the user's emotion.
    # Only cues from Dia's documented non-verbal list are used; other emotions
    # (including "angry") get no cue so Maya answers calmly.
    DIA_EMOTION_CUES = {
        "happy": " (laughs)",
        "sad": " (sighs)",
        "excited": "!",
        "surprised": " (gasps)",
    }

    def __init__(self):
        """Load every model onto the best available device and reset call state."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"π Initializing Maya AI on {self.device}")
        self._load_asr()
        self._load_llm()
        self._load_emotion_model()
        self._load_tts()
        # Conversation storage: user_id -> list of exchange dicts.
        self.conversations = {}
        self.call_active = False
        # Kept for backward compatibility; Maya always speaks as Dia speaker [S1].
        self.speaker_turn = 1

    def _load_asr(self):
        """Load Whisper large-v3 (fp16 on CUDA, fp32 on CPU)."""
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.asr_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-large-v3",
            torch_dtype=dtype
        ).to(self.device)
        # English is forced per call via generate(language=..., task=...).
        print("β Whisper ASR loaded with FORCED English")

    def _load_llm(self):
        """Load DialoGPT-large with a pad token so attention masks are well-defined."""
        self.llm_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
        # GPT-2 tokenizers ship without a pad token; reuse EOS.
        if self.llm_tokenizer.pad_token is None:
            self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/DialoGPT-large",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto",
            pad_token_id=self.llm_tokenizer.eos_token_id
        )
        print("β DialoGPT-Large loaded with FIXED attention masks")

    def _load_emotion_model(self):
        """Load the audio emotion-classification pipeline."""
        self.emotion_model = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-er",
            device=self.device
        )
        print("β Emotion recognition loaded")

    def _load_tts(self):
        """Load Dia if importable and loadable; otherwise load the SpeechT5 fallback."""
        self.use_dia = False
        self.tts_model = None
        if not DIA_AVAILABLE:
            print("β οΈ Dia not available, using fallback TTS")
            self._load_fallback_tts()
            return
        try:
            self.dia_model = Dia.from_pretrained(
                "nari-labs/Dia-1.6B",
                compute_dtype="float16" if self.device == "cuda" else "float32",
                # Dia's signature takes a torch.device, not a plain string.
                device=torch.device(self.device)
            )
        except Exception as e:
            print(f"β οΈ Dia loading failed: {e}")
            self._load_fallback_tts()
            return
        print("β Dia TTS loaded (Ultra-realistic dialogue generation)")
        self.use_dia = True

    def _load_fallback_tts(self):
        """Load the SpeechT5 TTS fallback.

        Best-effort: on failure the error is printed and ``self.tts_model``
        is left as None, which ``_synthesize_with_fallback`` checks.
        """
        try:
            from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
            from datasets import load_dataset
            self.tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            self.tts_model = SpeechT5ForTextToSpeech.from_pretrained(
                "microsoft/speecht5_tts",
                torch_dtype=torch.float32
            ).to(self.device)
            self.vocoder = SpeechT5HifiGan.from_pretrained(
                "microsoft/speecht5_hifigan",
                torch_dtype=torch.float32
            ).to(self.device)
            # Speaker x-vector; row 7306 is the voice the original author picked.
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            self.speaker_embeddings = torch.tensor(
                embeddings_dataset[7306]["xvector"],
                dtype=torch.float32
            ).unsqueeze(0).to(self.device)
            print("β SpeechT5 TTS loaded as fallback")
        except Exception as e:
            self.tts_model = None
            print(f"β Fallback TTS loading failed: {e}")

    def transcribe_with_whisper(self, audio_path):
        """Transcribe an audio file to English text.

        Returns the stripped transcription, "No audio provided" for a None
        path, or "Transcription error: ..." if anything fails.
        """
        if audio_path is None:
            return "No audio provided"
        try:
            return self._transcribe(audio_path)
        except Exception as e:
            return f"Transcription error: {str(e)}"

    def _transcribe(self, audio_path):
        """Transcribe ``audio_path`` with Whisper, forcing English; raises on failure."""
        audio, _ = librosa.load(audio_path, sr=16000, mono=True)
        inputs = self.asr_processor(audio, sampling_rate=16000, return_tensors="pt")
        # The feature extractor emits float32. On CUDA the model is float16, and
        # mixing the two makes the first convolution raise, so cast explicitly.
        features = inputs.input_features.to(self.device, dtype=self.asr_model.dtype)
        with torch.no_grad():
            predicted_ids = self.asr_model.generate(
                features,
                max_new_tokens=150,
                do_sample=False,
                language="english",
                task="transcribe"
            )
        transcription = self.asr_processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]
        return transcription.strip()

    def recognize_emotion_from_audio(self, audio_path):
        """Return a readable emotion name for the audio; "neutral" on None input or error."""
        if audio_path is None:
            return "neutral"
        try:
            result = self.emotion_model(audio_path)
            emotion_label = result[0]["label"].lower()
        except Exception as e:
            # Best-effort: emotion only flavours the reply, so never fail the turn.
            print(f"Emotion recognition error: {e}")
            return "neutral"
        return self.EMOTION_LABEL_MAP.get(emotion_label, emotion_label)

    def _build_llm_input_ids(self, text, history):
        """Encode the last two exchanges plus ``text`` as a DialoGPT prompt.

        DialoGPT's model card formats multi-turn input as turns each terminated
        by the EOS token, so that format is used here. The prompt is truncated
        from the left so the newest user turn is kept and the generated reply
        still fits inside ``LLM_CONTEXT_TOKENS``.
        """
        eos = self.llm_tokenizer.eos_token
        turns = []
        for entry in (history or [])[-2:]:
            turns.append(entry.get("user_input", ""))
            turns.append(entry.get("ai_response", ""))
        turns.append(text)
        prompt = "".join(turn + eos for turn in turns)
        input_ids = self.llm_tokenizer.encode(prompt, return_tensors="pt")
        budget = self.LLM_CONTEXT_TOKENS - self.LLM_MAX_NEW_TOKENS
        # With device_map="auto" the model decides its own placement.
        return input_ids[:, -budget:].to(self.llm_model.device)

    def generate_with_free_llm(self, text, emotion, history):
        """Generate Maya's reply to ``text`` with DialoGPT.

        Args:
            text: the user's transcribed utterance.
            emotion: emotion name used to choose a canned fallback reply.
            history: list of prior exchange dicts (``user_input``/``ai_response``).

        Returns:
            The generated reply. If it is shorter than 5 characters, the
            emotion prompt is returned instead. On error, the emotion prompt
            plus a follow-up question is returned.
        """
        emotion_context = self.EMOTION_PROMPTS.get(emotion, "I'm here to help.")
        try:
            input_ids = self._build_llm_input_ids(text, history)
            # Single unpadded sequence: every position is real input.
            attention_mask = torch.ones_like(input_ids)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=self.LLM_MAX_NEW_TOKENS,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.llm_tokenizer.pad_token_id,
                    eos_token_id=self.llm_tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens.
            response = self.llm_tokenizer.decode(
                outputs[0][input_ids.shape[1]:],
                skip_special_tokens=True
            ).strip()
        except Exception as e:
            print(f"LLM generation error: {e}")
            return f"{self.EMOTION_PROMPTS.get(emotion, 'I understand.')} Could you tell me more about that?"
        if len(response) < 5:
            return emotion_context
        return response

    @classmethod
    def _build_dia_prompt(cls, text, emotion):
        """Format ``text`` for Dia as a single [S1] line plus an optional emotion cue.

        Each generate() call is an independent utterance, and Dia's guidance
        is that input should begin with [S1]. Maya therefore always speaks as
        [S1] rather than alternating speakers between calls.
        """
        return f"[S1] {text}{cls.DIA_EMOTION_CUES.get(emotion, '')}"

    def _synthesize(self, text, emotion):
        """Synthesize ``text`` and return ``(sample_rate, waveform)``, or None.

        The sample rate always matches the engine that actually produced the
        audio, including when Dia fails at runtime and SpeechT5 takes over.
        """
        if not text or not text.strip():
            return None
        if self.use_dia:
            try:
                audio = self.dia_model.generate(
                    self._build_dia_prompt(text, emotion),
                    use_torch_compile=self.device == "cuda",
                    verbose=False
                )
            except Exception as e:
                print(f"Dia TTS error: {e}")
            else:
                if audio is None:
                    return None
                return self.DIA_SAMPLE_RATE, audio
        audio = self._synthesize_with_fallback(text, emotion)
        if audio is None:
            return None
        return self.SPEECHT5_SAMPLE_RATE, audio

    def synthesize_with_dia(self, text, emotion):
        """Synthesize speech with Dia, falling back to SpeechT5.

        Returns only the waveform, or None. Use ``_synthesize`` when the
        matching sample rate is also needed.
        """
        result = self._synthesize(text, emotion)
        return None if result is None else result[1]

    def _synthesize_with_fallback(self, text, emotion):
        """Synthesize with SpeechT5; returns a float32 numpy waveform or None."""
        if getattr(self, "tts_model", None) is None:
            print("Fallback TTS error: SpeechT5 is not loaded")
            return None
        try:
            clean_text = text.replace("[", "").replace("]", "").strip()
            # Keep the prompt short for SpeechT5.
            if len(clean_text) > 200:
                clean_text = clean_text[:200] + "..."
            # Add emotional inflection through punctuation.
            if emotion == "happy":
                clean_text = clean_text.replace(".", "!")
            elif emotion == "excited":
                clean_text = clean_text + "!"
            elif emotion == "sad":
                clean_text = clean_text.replace("!", ".")
            inputs = self.tts_processor(text=clean_text, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                speech = self.tts_model.generate_speech(
                    inputs["input_ids"],
                    self.speaker_embeddings,
                    vocoder=self.vocoder
                )
            if isinstance(speech, torch.Tensor):
                speech = speech.cpu().numpy().astype(np.float32)
            return speech
        except Exception as e:
            print(f"Fallback TTS error: {e}")
            return None

    def start_call(self):
        """Open a call. Returns (greeting text, (sample_rate, audio) or None, status)."""
        self.call_active = True
        self.speaker_turn = 1
        greeting = "Hello! I'm Maya, your AI conversation partner. I'm here to chat with you naturally and understand your emotions. How are you feeling today?"
        return greeting, self._synthesize(greeting, "happy"), "π Call started! Maya is greeting you with ultra-realistic speech..."

    def end_call(self, user_id="default"):
        """Close the call and clear ``user_id``'s history. Returns (text, audio or None, status)."""
        self.call_active = False
        if user_id in self.conversations:
            self.conversations[user_id] = []
        farewell = "Thank you for chatting with me! It was wonderful talking with you. Have a great day!"
        return farewell, self._synthesize(farewell, "happy"), "π Call ended. Conversation cleared!"

    def process_conversation(self, audio_input, user_id="default"):
        """Run one turn: ASR -> emotion -> LLM -> TTS, then record it in history.

        Returns (transcription or message, (sample_rate, audio) or None,
        formatted history or status message).
        """
        if not self.call_active:
            return "Please start a call first by clicking the 'Start Call' button", None, "No active call"
        if audio_input is None:
            return "Please record some audio", None, "No audio input"
        start_time = time.time()
        history = self.conversations.setdefault(user_id, [])
        try:
            # _transcribe raises on failure, so an error message is never
            # passed to the LLM as if it were user speech.
            transcription = self._transcribe(audio_input)
            if not transcription:
                return "No speech detected - please try again", None, self.format_conversation_history(user_id)
            emotion = self.recognize_emotion_from_audio(audio_input)
            response_text = self.generate_with_free_llm(transcription, emotion, history)
            response_audio = self._synthesize(response_text, emotion)
            processing_time = time.time() - start_time
            history.append({
                "timestamp": datetime.now().strftime("%H:%M:%S"),
                "user_input": transcription,
                "user_emotion": emotion,
                "ai_response": response_text,
                "processing_time": processing_time
            })
            # Cap memory: keep only the newest MAX_HISTORY exchanges.
            if len(history) > self.MAX_HISTORY:
                del history[:-self.MAX_HISTORY]
            return transcription, response_audio, self.format_conversation_history(user_id)
        except Exception as e:
            return f"Processing error: {str(e)}", None, "Error in processing"

    def format_conversation_history(self, user_id):
        """Render the last 10 exchanges as Markdown.

        Exchanges are numbered by their position in the full history, so a
        given exchange keeps the same number as new ones arrive.
        """
        entries = self.conversations.get(user_id)
        if not entries:
            return "No conversation history yet."
        recent = entries[-10:]
        first_number = len(entries) - len(recent) + 1
        history = []
        for i, entry in enumerate(recent, first_number):
            history.append(f"**Exchange {i}** ({entry['timestamp']})")
            history.append(f"π€ **You** ({entry['user_emotion']}): {entry['user_input']}")
            history.append(f"π€ **Maya**: {entry['ai_response']}")
            history.append(f"β±οΈ *{entry['processing_time']:.2f}s*")
            history.append("---")
        return "\n".join(history)
# Build the one shared assistant at import time. MayaAI() loads every model,
# so this blocks until they are ready; the Gradio callbacks below reuse it.
print("π Starting Maya AI with REAL Dia TTS...")
maya = MayaAI()
print("β Maya AI ready with ultra-realistic dialogue generation!")
# Gradio callbacks: thin adapters from UI events to the shared `maya` instance.
def start_call_handler():
    """Handle the "Start Call" button by delegating to ``maya.start_call``."""
    call_outputs = maya.start_call()
    return call_outputs
def end_call_handler():
    """Handle the "End Call" button by delegating to ``maya.end_call``."""
    call_outputs = maya.end_call()
    return call_outputs
def process_audio_handler(audio):
    """Run one conversation turn for the recorded ``audio`` (a file path from gr.Audio)."""
    turn_outputs = maya.process_conversation(audio)
    return turn_outputs
# Build the Gradio UI: call controls and microphone input on the left,
# transcription, spoken reply and running history on the right.
with gr.Blocks(title="Maya AI - Dia-Powered Sesame Killer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# π€ Maya AI - Dia-Powered Sesame Killer
*Ultra-realistic dialogue generation with Dia TTS - Natural breathing, laughter, and human-like responses*
**Features:** β Real Dia TTS β English-only ASR β Emotion Recognition β FREE LLM β Ultra-realistic Speech
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### π Call Controls")
            start_call_btn = gr.Button("π Start Call", variant="primary", size="lg")
            end_call_btn = gr.Button("π End Call", variant="stop", size="lg")
            gr.Markdown("### ποΈ Voice Input")
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Record your message in English")
            process_btn = gr.Button("π― Process Audio", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("### π¬ Ultra-Realistic Conversation")
            transcription_output = gr.Textbox(label="π What you said (English)", lines=2, interactive=False)
            audio_output = gr.Audio(label="π Maya's Ultra-Realistic Response (Dia TTS)", interactive=False, autoplay=True)
            conversation_display = gr.Textbox(
                label="π Live Conversation (FREE & Ultra-Realistic)",
                lines=15,
                interactive=False,
                show_copy_button=True,
            )

    # Every handler returns (text, audio, history/status) into the same three components.
    call_outputs = [transcription_output, audio_output, conversation_display]
    start_call_btn.click(fn=start_call_handler, outputs=call_outputs)
    end_call_btn.click(fn=end_call_handler, outputs=call_outputs)
    # A turn runs on an explicit button press and also automatically when recording stops.
    for trigger in (process_btn.click, audio_input.stop_recording):
        trigger(fn=process_audio_handler, inputs=[audio_input], outputs=call_outputs)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and report handler errors in the UI.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "show_error": True}
    demo.launch(**launch_options)