import gradio as gr
import torch
import time
from datetime import datetime

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# Initialize models
class ConversationalAI:
    def __init__(self):
        # Speech-to-text: Parakeet ASR (with a Whisper fallback)
        self.asr_model = self.load_parakeet_asr()

        # LLM: Gemma 2 9B as a local stand-in for Gemini (no API access here)
        self.llm_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-9b-it",
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Text-to-speech: Dia (with a SpeechT5 fallback)
        self.tts_model = self.load_dia_tts()

        # Emotion recognition (optional; None if loading fails)
        self.emotion_model = self.load_ervq_emotion()

        # Per-user conversation history
        self.conversations = {}
    def load_parakeet_asr(self):
        try:
            from nemo.collections.asr.models import ASRModel
            return ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
        except Exception:
            # Fall back to Whisper if NeMo/Parakeet is unavailable
            return pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-large-v3",
                torch_dtype=torch.float16,
                device="cuda",
            )
    def load_dia_tts(self):
        try:
            # Dia 1.6B from Nari Labs, loaded via its own `dia` package
            from dia.model import Dia
            model = Dia.from_pretrained("nari-labs/Dia-1.6B")
            self.tts_is_dia = True
            return model
        except Exception:
            # Fall back to SpeechT5, which needs a speaker x-vector at inference
            from datasets import load_dataset
            embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            self.speaker_embedding = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)
            self.tts_is_dia = False
            return pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cuda")
    def load_ervq_emotion(self):
        # Emotion recognition; the SUPERB checkpoint loads through the
        # transformers pipeline (the SpeechBrain IEMOCAP checkpoint requires
        # SpeechBrain's own loader and would fail here)
        try:
            return pipeline(
                "audio-classification",
                model="superb/wav2vec2-base-superb-er",
                device="cuda",
            )
        except Exception:
            return None
    def transcribe_audio(self, audio_path):
        """Transcribe audio using Parakeet ASR (or the Whisper fallback)"""
        try:
            if hasattr(self.asr_model, "transcribe"):
                # Parakeet/NeMo path; newer NeMo releases return Hypothesis
                # objects rather than plain strings
                results = self.asr_model.transcribe([audio_path])
                if not results:
                    return ""
                first = results[0]
                return first.text if hasattr(first, "text") else first
            # Whisper pipeline fallback
            return self.asr_model(audio_path)["text"]
        except Exception as e:
            return f"Transcription error: {str(e)}"
    def recognize_emotion(self, audio_path):
        """Recognize emotion from audio"""
        if self.emotion_model is None:
            return "neutral"
        try:
            result = self.emotion_model(audio_path)
            return result[0]["label"].lower()
        except Exception:
            return "neutral"
    def generate_response(self, text, emotion, conversation_history):
        """Generate a contextual response with the LLM"""
        # Build a context-aware prompt from recent history and detected emotion
        context = f"Previous conversation: {conversation_history[-3:] if conversation_history else 'None'}"
        emotion_context = f"User emotion detected: {emotion}"

        prompt = f"""You are Maya, a naturally conversational AI assistant with emotional intelligence.
{context}
{emotion_context}

Respond naturally and in an emotionally appropriate way to: {text}

Keep responses conversational, empathetic, and under 100 words."""

        # device_map="auto" may shard the model, so move inputs to the model's device
        inputs = self.llm_tokenizer(prompt, return_tensors="pt").to(self.llm_model.device)
        with torch.no_grad():
            outputs = self.llm_model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.llm_tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, not the echoed prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return self.llm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    def synthesize_speech(self, text, emotion):
        """Generate emotional speech using Dia TTS (or the SpeechT5 fallback)"""
        try:
            if self.tts_is_dia:
                # Dia expects speaker tags; the detected emotion is prepended
                # as a parenthetical cue
                audio = self.tts_model.generate(f"[S1] ({emotion}) {text}")
                return (44100, audio)  # Dia outputs 44.1 kHz audio
            # SpeechT5 fallback, using the speaker x-vector loaded at startup
            result = self.tts_model(
                text, forward_params={"speaker_embeddings": self.speaker_embedding}
            )
            # Gradio expects a (sample_rate, waveform) tuple
            return (result["sampling_rate"], result["audio"].squeeze())
        except Exception:
            return None
    def process_conversation(self, audio_input, user_id="default"):
        """Main conversation processing pipeline"""
        if audio_input is None:
            return "Please provide audio input", None, "No conversation yet"

        start_time = time.time()

        # Initialize this user's conversation history if it doesn't exist
        if user_id not in self.conversations:
            self.conversations[user_id] = []

        # Step 1: Transcribe audio
        transcription = self.transcribe_audio(audio_input)

        # Step 2: Recognize emotion
        emotion = self.recognize_emotion(audio_input)

        # Step 3: Generate response
        response_text = self.generate_response(
            transcription, emotion, self.conversations[user_id]
        )

        # Step 4: Synthesize speech
        response_audio = self.synthesize_speech(response_text, emotion)

        # Step 5: Update conversation history
        conversation_entry = {
            "timestamp": datetime.now().isoformat(),
            "user_input": transcription,
            "user_emotion": emotion,
            "ai_response": response_text,
            "processing_time": time.time() - start_time,
        }
        self.conversations[user_id].append(conversation_entry)

        # Keep only the last 50 exchanges per user
        if len(self.conversations[user_id]) > 50:
            self.conversations[user_id] = self.conversations[user_id][-50:]

        # Format conversation history for display
        history = self.format_conversation_history(user_id)

        return transcription, response_audio, history
    def format_conversation_history(self, user_id):
        """Format conversation history for display"""
        if user_id not in self.conversations:
            return "No conversation history"

        history = []
        for entry in self.conversations[user_id][-10:]:  # Show last 10 exchanges
            history.append(f"🎤 You ({entry['user_emotion']}): {entry['user_input']}")
            history.append(f"🤖 Maya: {entry['ai_response']}")
            history.append(f"⏱️ Response time: {entry['processing_time']:.2f}s\n")

        return "\n".join(history)
    def clear_conversation(self, user_id="default"):
        """Clear conversation history"""
        if user_id in self.conversations:
            self.conversations[user_id] = []
        return "Conversation cleared!"
# Initialize the AI system
ai_system = ConversationalAI()

# Gradio interface callbacks
def process_audio(audio):
    transcription, response_audio, history = ai_system.process_conversation(audio)
    return transcription, response_audio, history

def clear_chat():
    ai_system.clear_conversation()
    # Empty the transcription box and show a confirmation in the history box
    return "", "Conversation cleared!"
# Create Gradio interface
with gr.Blocks(title="Maya AI - Advanced Conversational AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Maya AI - Your Emotional Conversational Partner")
    gr.Markdown("*Powered by Parakeet ASR, Gemma LLM, and Dia TTS with emotional intelligence*")

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="🎙️ Speak to Maya",
                interactive=True,
            )
            process_btn = gr.Button("💬 Process Conversation", variant="primary")
            clear_btn = gr.Button("🗑️ Clear Conversation", variant="secondary")

        with gr.Column(scale=2):
            transcription_output = gr.Textbox(
                label="📝 What you said",
                interactive=False,
                lines=3,
            )
            audio_output = gr.Audio(
                label="🔊 Maya's Response",
                interactive=False,
            )
            conversation_history = gr.Textbox(
                label="📜 Conversation History",
                interactive=False,
                lines=15,
                max_lines=20,
            )

    # Event handlers
    process_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[transcription_output, audio_output, conversation_history],
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[transcription_output, conversation_history],
    )

    # Auto-process when audio is recorded
    audio_input.change(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[transcription_output, audio_output, conversation_history],
    )
# Launch the app (share=True is not supported on Hugging Face Spaces, so it is omitted)
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
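
# Rough dependency sketch for this Space (an assumption, not a pinned list;
# the optional lines cover the primary ASR/TTS paths, the rest the fallbacks):
#   pip install gradio torch transformers accelerate datasets
#   pip install nemo_toolkit[asr]                            # optional: Parakeet ASR
#   pip install git+https://github.com/nari-labs/dia.git     # optional: Dia TTS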