Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| import warnings | |
| from transformers import pipeline, AutoProcessor, AutoModel | |
| from dia.model import Dia | |
| import asyncio | |
| import time | |
| from collections import deque | |
| import json | |
# Silence every Python warning emitted by the model libraries so the logs
# only show this app's own messages.
warnings.filterwarnings("ignore")

# Model handles, populated lazily by load_models() and reused afterwards.
dia_model = asr_model = emotion_classifier = None

# Not referenced anywhere else in this file; live per-session histories are
# kept by the ConversationManager instance instead.
conversation_histories = dict()

# Capacity limits: exchanges remembered per session, and simultaneous
# requests the Gradio queue is configured to process.
MAX_HISTORY, MAX_CONCURRENT_USERS = 50, 20
class ConversationManager:
    """In-memory, per-session store of recent conversation exchanges.

    Each session id maps to a bounded ``deque``; once ``max_history``
    exchanges are stored for a session, appending a new one discards the
    oldest. Nothing is persisted: histories live only in this process.
    """

    def __init__(self, max_history=None):
        """Create an empty store.

        Args:
            max_history: Maximum exchanges kept per session. When omitted,
                the module-level ``MAX_HISTORY`` constant is used.
        """
        self.histories = {}
        self.max_history = MAX_HISTORY if max_history is None else max_history

    def _session_deque(self, session_id):
        """Return the session's deque, creating an empty one on first use."""
        # dict.setdefault does the lookup and insert in one call, so (under
        # CPython's GIL, with str keys) two concurrent first requests for the
        # same session can no longer each create a deque and have one
        # overwrite the other -- losing an exchange -- as the previous
        # check-then-assign could.
        return self.histories.setdefault(
            session_id, deque(maxlen=self.max_history)
        )

    def get_history(self, session_id):
        """Return a snapshot list of the session's exchanges, oldest first."""
        return list(self._session_deque(session_id))

    def add_exchange(self, session_id, user_input, ai_response,
                     user_emotion=None, ai_emotion=None):
        """Record one user/AI turn for ``session_id``.

        The stored dict has keys ``user``, ``ai``, ``user_emotion``,
        ``ai_emotion`` and ``timestamp`` (seconds since the epoch).
        """
        self._session_deque(session_id).append({
            "user": user_input,
            "ai": ai_response,
            "user_emotion": user_emotion,
            "ai_emotion": ai_emotion,
            "timestamp": time.time(),
        })

    def clear_history(self, session_id):
        """Forget everything stored for ``session_id`` (no-op if unknown)."""
        # pop with a default avoids the check-then-delete race on `del`.
        self.histories.pop(session_id, None)
# Single module-level conversation store; process_conversation and the
# Clear History handler both read and write histories through it.
conversation_manager = ConversationManager()
def load_models():
    """Load the TTS, ASR and emotion models once and cache them globally.

    Safe to call repeatedly: each model is loaded only while its module
    global is still ``None``. A loading failure is printed and re-raised so
    the caller (startup code or ``generate_speech``) sees the exception.
    """
    global dia_model, asr_model, emotion_classifier

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # Half precision is only dependable on GPU; many CPU kernels have no
    # float16 implementation, so fall back to float32 on CPU-only hardware.
    pipeline_dtype = torch.float16 if use_cuda else torch.float32
    dia_compute_dtype = "float16" if use_cuda else "float32"

    if dia_model is None:
        print("Loading Dia TTS model...")
        try:
            # Dia selects precision through `compute_dtype` (a string). It
            # does not take a `torch_dtype` keyword; passing one raised a
            # TypeError before any model was loaded.
            dia_model = Dia.from_pretrained(
                "nari-labs/Dia-1.6B",
                compute_dtype=dia_compute_dtype,
            )
            print("โ Dia model loaded successfully!")
        except Exception as e:
            print(f"โ Error loading Dia model: {e}")
            raise

    if asr_model is None:
        print("Loading ASR model...")
        try:
            # Whisper (small) speech-to-text pipeline.
            asr_model = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-small",
                torch_dtype=pipeline_dtype,
                device=device,
            )
            print("โ ASR model loaded successfully!")
        except Exception as e:
            print(f"โ Error loading ASR model: {e}")
            raise

    if emotion_classifier is None:
        print("Loading emotion classifier...")
        try:
            emotion_classifier = pipeline(
                "text-classification",
                model="j-hartmann/emotion-english-distilroberta-base",
                torch_dtype=pipeline_dtype,
                device=device,
            )
            print("โ Emotion classifier loaded successfully!")
        except Exception as e:
            print(f"โ Error loading emotion classifier: {e}")
            raise
def detect_emotion(text):
    """Classify the dominant emotion in ``text``.

    Returns:
        The emotion classifier's top label, lower-cased, or ``"neutral"``
        when the text is empty/blank, the classifier is not loaded, or
        classification raises.
    """
    try:
        # Nothing to classify: skip the model call instead of feeding it "".
        if not text or not text.strip():
            return "neutral"
        if emotion_classifier is None:
            return "neutral"
        # Truncate to the model's maximum input length so a long transcript
        # is classified rather than raising (and silently becoming "neutral").
        result = emotion_classifier(text, truncation=True)
        return result[0]["label"].lower() if result else "neutral"
    except Exception as e:
        print(f"Error in emotion detection: {e}")
        return "neutral"
def _to_whisper_waveform(audio, sample_rate):
    """Return ``audio`` as a mono float32 waveform resampled to 16 kHz."""
    if np.issubdtype(audio.dtype, np.signedinteger):
        # Integer PCM (Gradio's numpy microphone output is typically int16)
        # must be scaled into [-1, 1]; a bare astype(float32) would leave
        # samples at +/-32768 scale, far outside what Whisper's feature
        # extractor is designed for.
        scale = float(-np.iinfo(audio.dtype).min)
        audio = audio.astype(np.float32) / scale
    else:
        audio = audio.astype(np.float32)
    # Down-mix multi-channel input (assumed (samples, channels)) to mono.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sample_rate != 16000:
        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
    return audio


def transcribe_audio(audio_data):
    """Transcribe recorded speech and detect the emotion of the transcript.

    Args:
        audio_data: Gradio numpy audio ``(sample_rate, samples)``, a bare
            sample array (assumed to be 16 kHz already), or None.

    Returns:
        ``(text, emotion)``; ``("", "neutral")`` when there is no audio or
        transcription fails.
    """
    try:
        if audio_data is None:
            return "", "neutral"
        if isinstance(audio_data, tuple):
            sample_rate, audio = audio_data
        else:
            sample_rate, audio = 16000, audio_data
        audio = np.asarray(audio)
        # An empty recording has nothing to transcribe; skip the models.
        if audio.size == 0:
            return "", "neutral"
        waveform = _to_whisper_waveform(audio, sample_rate)
        # State the rate explicitly instead of relying on the pipeline's
        # assumption that a bare array is already at the model's rate.
        result = asr_model({"raw": waveform, "sampling_rate": 16000})
        text = result["text"].strip()
        return text, detect_emotion(text)
    except Exception as e:
        print(f"Error in transcription: {e}")
        return "", "neutral"
def generate_emotional_response(user_text, user_emotion, conversation_history, session_id):
    """Choose a templated reply and a reply mood suited to the user's emotion.

    The reply is picked at random from canned templates selected by topic
    ("supernatural"/"magic") or by ``user_emotion``; a Dia non-verbal cue
    may be appended for the TTS step.

    Args:
        user_text: Transcribed user utterance.
        user_emotion: Label from ``detect_emotion``.
        conversation_history: Prior exchanges for the session. Accepted for
            interface compatibility; the template logic does not use it.
        session_id: Accepted for interface compatibility; unused.

    Returns:
        ``(response_text, ai_emotion)`` as plain strings. On any error a
        generic reply and ``"neutral"`` are returned instead of raising.
    """
    try:
        # Reply moods the AI may adopt for each detected user emotion.
        emotion_responses = {
            "joy": ["excited", "happy", "cheerful"],
            "sadness": ["empathetic", "gentle", "comforting"],
            "anger": ["calm", "understanding", "patient"],
            "fear": ["reassuring", "supportive", "confident"],
            "surprise": ["curious", "engaged", "interested"],
            "disgust": ["neutral", "diplomatic", "respectful"],
            "neutral": ["friendly", "conversational", "natural"]
        }
        ai_emotion = str(np.random.choice(
            emotion_responses.get(user_emotion, ["friendly"])
        ))

        lowered = user_text.lower()
        if "supernatural" in lowered or "magic" in lowered:
            response_templates = [
                "The mystical energies around us are quite fascinating, aren't they?",
                "I sense something extraordinary in your words...",
                "The supernatural realm holds many mysteries we're yet to understand.",
                "There's an otherworldly quality to our conversation that intrigues me."
            ]
        elif user_emotion == "sadness":
            response_templates = [
                "I understand how you're feeling, and I'm here to listen.",
                "Your emotions are valid, and it's okay to feel this way.",
                "Sometimes sharing our feelings can help lighten the burden."
            ]
        elif user_emotion == "joy":
            response_templates = [
                "Your happiness is contagious! I love your positive energy!",
                "It's wonderful to hear such joy in your voice!",
                "Your enthusiasm brightens up our conversation!"
            ]
        else:
            words = user_text.split()
            # Drop trailing punctuation so "...today." does not render as
            # "perspective on today..".
            topic = words[-1].rstrip(".,!?;:") if words else ""
            response_templates = [
                f"That's an interesting perspective on {topic or 'that'}.",
                "I find our conversation quite engaging and thought-provoking.",
                "Your thoughts resonate with me in unexpected ways."
            ]
        response = str(np.random.choice(response_templates))

        # Non-verbal cues appended for Dia. Only tags from Dia's documented
        # non-verbal list are used: the former "(excited)", "(softly)",
        # "(warmly)" and "(intrigued)" are not on that list and risk being
        # spoken literally or producing audio artifacts.
        emotion_cues = {
            "happy": "(laughs)",
            "gentle": "(sighs)",
        }
        cue = emotion_cues.get(ai_emotion)
        if cue:
            response += f" {cue}"
        return response, ai_emotion
    except Exception as e:
        print(f"Error generating response: {e}")
        return "I'm here to listen and understand you better.", "neutral"
def generate_speech(text, emotion="neutral", speaker="S1"):
    """Synthesize ``text`` with the Dia model and return Gradio-ready audio.

    Args:
        text: Utterance to speak; may end with a Dia non-verbal cue such as
            "(laughs)".
        emotion: Accepted for interface compatibility but not used -- the
            model is conditioned only through ``text`` itself.
        speaker: Dia speaker tag without brackets, e.g. "S1" or "S2".

    Returns:
        ``(44100, waveform)`` with a float32 numpy waveform, or ``None`` if
        generation raises or the model returns no audio.
    """
    try:
        if dia_model is None:
            load_models()

        # Release cached GPU blocks left over from earlier requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Dia expects speaker tags in the text, e.g. "[S1] Hello".
        formatted_text = f"[{speaker}] {text}"

        # Fixed seed so the same text yields a consistent voice.
        # NOTE: this reseeds torch's process-wide RNG on every call.
        torch.manual_seed(42)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(42)

        print(f"Generating speech: {formatted_text[:100]}...")

        with torch.no_grad():
            audio_output = dia_model.generate(
                formatted_text,
                use_torch_compile=False,  # Disabled for stability
                verbose=False
            )

        # Report an empty result explicitly instead of letting it surface
        # as an unrelated TypeError further down.
        if audio_output is None:
            print("Error in speech generation: model returned no audio")
            return None

        if isinstance(audio_output, torch.Tensor):
            audio_output = audio_output.detach().cpu().numpy()
        audio_output = np.asarray(audio_output, dtype=np.float32)

        # Only scale down clipped output; quieter audio is left untouched.
        if audio_output.size > 0:
            peak = float(np.max(np.abs(audio_output)))
            if peak > 1.0:
                audio_output = audio_output / peak * 0.95

        return (44100, audio_output)
    except Exception as e:
        print(f"Error in speech generation: {e}")
        return None
def _format_status(transcription_time, response_time, tts_time, total_time,
                   user_emotion, ai_emotion, history_len):
    """Build the multi-line latency/emotion report shown in the status box."""
    return f"""โ Processing Complete!
๐ Transcription: {transcription_time:.0f}ms
๐ง Response Generation: {response_time:.0f}ms
๐ต Speech Synthesis: {tts_time:.0f}ms
โฑ๏ธ Total Latency: {total_time:.0f}ms
๐ User Emotion: {user_emotion}
๐ค AI Emotion: {ai_emotion}
๐ฌ History: {history_len}/{MAX_HISTORY} exchanges"""


def process_conversation(audio_input, session_id, history):
    """Run one speech-to-speech turn: ASR -> reply -> TTS -> history update.

    Args:
        audio_input: Value of the microphone ``gr.Audio`` component, or None.
        session_id: Key for the server-side conversation memory.
        history: Current Chatbot value (list of ``[user, ai]`` pairs); may be
            None for an empty chat. It is copied, never mutated.

    Returns:
        ``(audio, status_text, chat_history, last_input_text)`` matching the
        four Gradio outputs wired to this handler.
    """
    start_time = time.time()
    # Work on a copy so the caller's list is not mutated, and tolerate a
    # None Chatbot value on first use.
    history = list(history) if history else []
    try:
        # Step 1: speech -> text (+ user emotion).
        transcription_start = time.time()
        user_text, user_emotion = transcribe_audio(audio_input)
        transcription_time = (time.time() - transcription_start) * 1000

        if not user_text:
            return None, "โ Could not transcribe audio", history, "Transcription failed"

        # Step 2: prior exchanges for this session (before this turn).
        conversation_history = conversation_manager.get_history(session_id)

        # Step 3: choose the reply text and mood.
        response_start = time.time()
        ai_response, ai_emotion = generate_emotional_response(
            user_text, user_emotion, conversation_history, session_id
        )
        response_time = (time.time() - response_start) * 1000

        # Step 4: text -> speech.
        tts_start = time.time()
        audio_output = generate_speech(ai_response, ai_emotion, "S2")
        tts_time = (time.time() - tts_start) * 1000

        # Step 5: record the turn server-side and in the chat display.
        conversation_manager.add_exchange(
            session_id, user_text, ai_response, user_emotion, ai_emotion
        )
        history.append([user_text, ai_response])

        total_time = (time.time() - start_time) * 1000
        # Count after adding this turn; the old code reported the count from
        # before the append and hard-coded the 50-exchange limit.
        stored = len(conversation_manager.get_history(session_id))
        status = _format_status(
            transcription_time, response_time, tts_time, total_time,
            user_emotion, ai_emotion, stored,
        )
        return audio_output, status, history, f"User: {user_text}"
    except Exception as e:
        error_msg = f"โ Error: {str(e)}"
        return None, error_msg, history, "Processing failed"
# Load every model eagerly at import time: load_models() re-raises loading
# errors, so a broken model stops the app at startup instead of failing on
# the first request.
load_models()


def _new_session_id():
    """Return a random default session id (evaluated once per page load)."""
    import uuid

    return f"user_{uuid.uuid4().hex[:8]}"


# Create Gradio interface
with gr.Blocks(title="Supernatural AI Agent", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
<div style="text-align: center; padding: 20px; background: linear-gradient(45deg, #1a1a2e, #16213e); color: white; border-radius: 15px; margin-bottom: 20px;">
<h1>๐ฎ Supernatural Conversational AI Agent</h1>
<p style="font-size: 18px;">Human-like emotional intelligence with <500ms latency โข Speech-to-Speech AI</p>
<p style="font-size: 14px; opacity: 0.8;">Powered by Dia TTS โข Emotional Recognition โข 50 Exchange Memory</p>
</div>
""")

    with gr.Row():
        with gr.Column(scale=1):
            # Session management. The default used to be the constant
            # "user_001", so every visitor who left it unchanged shared one
            # server-side history (and one visitor's Clear History wiped it
            # for all). A callable default is re-evaluated on each page load.
            session_id = gr.Textbox(
                label="๐ Session ID",
                value=_new_session_id,
                info="Unique ID for conversation history"
            )

            # Audio input
            audio_input = gr.Audio(
                label="๐ค Speak to the AI",
                type="numpy",
                format="wav"
            )

            # Process button
            process_btn = gr.Button(
                "๐ฃ๏ธ Process Conversation",
                variant="primary",
                size="lg"
            )

            # Clear history button
            clear_btn = gr.Button(
                "๐๏ธ Clear History",
                variant="secondary"
            )

        with gr.Column(scale=2):
            # Chat history
            chatbot = gr.Chatbot(
                label="๐ฌ Conversation History",
                height=400,
                show_copy_button=True
            )

            # Audio output
            audio_output = gr.Audio(
                label="๐ AI Response",
                type="numpy",
                autoplay=True
            )

            # Status display
            status_display = gr.Textbox(
                label="๐ Processing Status",
                lines=8,
                interactive=False
            )

            # Last input display
            last_input = gr.Textbox(
                label="๐ Last Transcription",
                interactive=False
            )

    # Event handlers
    process_btn.click(
        fn=process_conversation,
        inputs=[audio_input, session_id, chatbot],
        outputs=[audio_output, status_display, chatbot, last_input],
        concurrency_limit=MAX_CONCURRENT_USERS
    )

    def clear_conversation_history(session_id_val):
        """Drop the server-side history and reset the chat display."""
        conversation_manager.clear_history(session_id_val)
        return [], "โ Conversation history cleared!"

    clear_btn.click(
        fn=clear_conversation_history,
        inputs=[session_id],
        outputs=[chatbot, status_display]
    )

    # Usage instructions
    gr.HTML("""
<div style="margin-top: 20px; padding: 15px; background: #f8f9fa; border-radius: 10px;">
<h3>๐ฏ Usage Instructions:</h3>
<ul>
<li><strong>Record Audio:</strong> Click the microphone and speak naturally</li>
<li><strong>Emotional AI:</strong> The AI detects and responds to your emotions</li>
<li><strong>Memory:</strong> Maintains up to 50 conversation exchanges</li>
<li><strong>Latency:</strong> Optimized for <500ms response time</li>
<li><strong>Concurrent Users:</strong> Supports up to 20 simultaneous users</li>
</ul>
<h3>๐ฎ Supernatural Features:</h3>
<p>Try mentioning supernatural, mystical, or magical topics for specialized responses!</p>
<h3>โก Performance Metrics:</h3>
<p><strong>Target Latency:</strong> <500ms | <strong>Memory:</strong> 50 exchanges | <strong>Concurrent Users:</strong> 20</p>
</div>
""")
# Enable Gradio's request queue: at most MAX_CONCURRENT_USERS events run at
# once by default, and the queue holds at most 100 waiting events.
demo.queue(
    max_size=100,
    default_concurrency_limit=MAX_CONCURRENT_USERS,
)

if __name__ == "__main__":
    # Listen on every interface at port 7860, without a public share link.
    demo.launch(share=False, server_port=7860, server_name="0.0.0.0")