import gradio as gr
import torch
import numpy as np
import librosa
import soundfile as sf
import threading
import time
import queue
import warnings
from typing import Optional, List, Dict, Tuple
from dataclasses import dataclass
from collections import deque
import psutil
import gc

# Models and pipelines
from dia.model import Dia
from transformers import pipeline
import webrtcvad

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


@dataclass
class ConversationTurn:
    """One user/AI exchange stored in a session's history."""
    user_audio: np.ndarray         # mono float32 waveform of the user's utterance
    user_text: str                 # ASR transcription of user_audio
    ai_response_text: str          # text content of the generated reply
    ai_response_audio: np.ndarray  # synthesized reply waveform
    timestamp: float               # time.time() when the turn completed
    emotion: str                   # emotion label detected from user_audio
    speaker_id: str                # session identifier the turn belongs to


class EmotionRecognizer:
    """Wraps a wav2vec2 audio-classification pipeline for speech emotion."""

    def __init__(self):
        self.emotion_pipeline = pipeline(
            "audio-classification",
            model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
            device=0 if torch.cuda.is_available() else -1,
        )

    def detect_emotion(self, audio: np.ndarray, sample_rate: int = 16000) -> str:
        """Return the top emotion label for ``audio``; "neutral" on any failure.

        Emotion is cosmetic downstream, so errors are deliberately swallowed
        rather than allowed to break the main pipeline.
        """
        try:
            result = self.emotion_pipeline({"array": audio, "sampling_rate": sample_rate})
            return result[0]["label"] if result else "neutral"
        except Exception:
            return "neutral"


class VADProcessor:
    """Voice-activity detection over 30 ms frames using WebRTC VAD.

    NOTE: webrtcvad only accepts 8/16/32/48 kHz PCM; callers must hand in
    audio already resampled to ``self.sample_rate`` (16 kHz).
    """

    def __init__(self, aggressiveness: int = 2):
        # aggressiveness 0..3: higher filters out more non-speech.
        self.vad = webrtcvad.Vad(aggressiveness)
        self.sample_rate = 16000
        self.frame_duration = 30  # ms; one of the frame sizes webrtcvad supports
        self.frame_size = int(self.sample_rate * self.frame_duration / 1000)

    def is_speech(self, audio: np.ndarray) -> bool:
        """True when more than 30% of the complete 30 ms frames are voiced.

        A clip shorter than one frame yields no frames and counts as silence.
        """
        audio_int16 = (audio * 32767).astype(np.int16)
        voiced_flags = []
        for offset in range(0, len(audio_int16) - self.frame_size, self.frame_size):
            frame = audio_int16[offset : offset + self.frame_size].tobytes()
            voiced_flags.append(self.vad.is_speech(frame, self.sample_rate))
        return sum(voiced_flags) > len(voiced_flags) * 0.3


class ConversationManager:
    """Thread-safe, bounded per-session conversation history."""

    def __init__(self, max_exchanges: int = 50):
        self.conversations: Dict[str, deque] = {}
        self.max_exchanges = max_exchanges
        # RLock: public methods may be called re-entrantly from the same thread.
        self.lock = threading.RLock()

    def add_turn(self, session_id: str, turn: ConversationTurn):
        """Append ``turn``, creating the session's bounded deque on first use."""
        with self.lock:
            if session_id not in self.conversations:
                self.conversations[session_id] = deque(maxlen=self.max_exchanges)
            self.conversations[session_id].append(turn)

    def get_context(self, session_id: str, last_n: int = 5) -> List[ConversationTurn]:
        """Return up to the last ``last_n`` turns (empty list for unknown sessions)."""
        with self.lock:
            return list(self.conversations.get(session_id, []))[-last_n:]

    def clear_session(self, session_id: str):
        """Drop all history for ``session_id`` (no-op when absent)."""
        with self.lock:
            if session_id in self.conversations:
                del self.conversations[session_id]


class SupernaturalAI:
    """Speech-in / speech-out assistant: ASR (Ultravox) -> prompt -> TTS (Dia)."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models_loaded = False
        self.conversation_manager = ConversationManager()
        self.processing_times = deque(maxlen=100)  # rolling latency samples
        self.emotion_recognizer = None
        self.vad_processor = VADProcessor()
        self.ultravox_model = None
        self.dia_model = None
        self._initialize_models()

    def _initialize_models(self):
        """Load ASR, TTS and emotion models; sets ``models_loaded`` on success.

        Failure is reported and recorded (``models_loaded = False``) instead of
        raising, so the UI can still come up and show "Models not ready".
        """
        try:
            self.ultravox_model = pipeline(
                'automatic-speech-recognition',
                model='fixie-ai/ultravox-v0_2',
                trust_remote_code=True,
                device=0 if torch.cuda.is_available() else -1,
                torch_dtype=torch.float16,
            )
            self.dia_model = Dia.from_pretrained(
                "nari-labs/Dia-1.6B", compute_dtype="float16"
            )
            self.emotion_recognizer = EmotionRecognizer()
            self.models_loaded = True
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Model load error: {e}")
            self.models_loaded = False

    def process_audio_input(self, audio_data: Tuple[int, np.ndarray], session_id: str):
        """Run the full pipeline for one utterance.

        Args:
            audio_data: ``(sample_rate, waveform)`` as produced by gr.Audio.
            session_id: key for this user's conversation history.

        Returns:
            ``((44100, waveform), status, transcript)`` on success, or
            ``(None, error_message, hint)`` on any failure.
        """
        if not self.models_loaded or audio_data is None:
            return None, "Models not ready", "Please wait"
        start = time.time()
        sample_rate, audio = audio_data
        # Down-mix to mono and peak-normalize to 0.95 full scale.
        if len(audio.shape) > 1:
            audio = np.mean(audio, axis=1)
        audio = audio.astype(np.float32)
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak * 0.95
        # BUG FIX: resample to 16 kHz *before* the VAD check — webrtcvad only
        # accepts 8/16/32/48 kHz and VADProcessor's frame size assumes 16 kHz.
        # Also use keyword rates: positional orig/target rates were removed in
        # librosa 0.10.
        if sample_rate != 16000:
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        if not self.vad_processor.is_speech(audio):
            return None, "No speech detected", "Speak clearly"
        try:
            result = self.ultravox_model({'array': audio, 'sampling_rate': sample_rate})
            user_text = result.get('text', '').strip()
            if not user_text:
                return None, "Could not understand", "Try again"
        except Exception as e:
            return None, f"ASR error: {e}", "Retry"
        emotion = self.emotion_recognizer.detect_emotion(audio, sample_rate)
        context = self.conversation_manager.get_context(session_id)
        prompt = self._build_prompt(user_text, emotion, context)
        try:
            with torch.no_grad():
                audio_out = self.dia_model.generate(prompt, use_torch_compile=False)
            audio_out = audio_out.cpu().numpy() if isinstance(audio_out, torch.Tensor) else audio_out
        except Exception as e:
            return None, f"TTS error: {e}", "Retry"
        # BUG FIX: the prompt uses '[A]' markers (see _build_prompt), never
        # '[S2]', so splitting on '[S2]' returned the entire prompt as the AI
        # text. Split on the marker that actually delimits the reply.
        # NOTE(review): Dia's own prompt convention is [S1]/[S2] speaker tags —
        # confirm whether _build_prompt should emit those instead of [U]/[A].
        ai_text = prompt.split('[A]')[-1].strip()
        turn = ConversationTurn(audio, user_text, ai_text, audio_out, time.time(), emotion, session_id)
        self.conversation_manager.add_turn(session_id, turn)
        elapsed = time.time() - start
        self.processing_times.append(elapsed)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        status = f"Processed in {elapsed:.2f}s | Emotion: {emotion}"
        # Dia output is returned at 44.1 kHz for playback.
        return (44100, audio_out), status, f"You: {user_text}\n\nAI: {ai_text}"

    def _build_prompt(self, text, emotion, context):
        """Compose the TTS prompt: recent context + user text + emotion modifier."""
        ctx = "".join(f"[U]{t.user_text}[A]{t.ai_response_text} " for t in context[-3:])
        mods = {"happy": "(cheerful)", "sad": "(sympathetic)", "angry": "(calming)",
                "fear": "(reassuring)", "surprise": "(excited)", "neutral": ""}
        return f"{ctx}[U]{text}[A]{mods.get(emotion, '')} As a supernatural AI, I sense your {emotion} energy. "

    def get_history(self, session_id: str) -> str:
        """Render up to the last 10 turns of a session as plain text for the UI."""
        ctx = self.conversation_manager.get_context(session_id, last_n=10)
        if not ctx:
            return "No history."
        out = ""
        for i, t in enumerate(ctx, 1):
            out += f"Turn {i} — You: {t.user_text} | AI: {t.ai_response_text} | Emotion: {t.emotion}\n\n"
        return out

    def clear_history(self, session_id: str) -> str:
        """Erase the session's history and return a confirmation message."""
        self.conversation_manager.clear_session(session_id)
        return "History cleared."
# Instantiate and launch Gradio app
ai = SupernaturalAI()

with gr.Blocks() as demo:
    audio_in = gr.Audio(source="microphone", type="numpy", label="Speak")
    audio_out = gr.Audio(label="AI Response")
    session = gr.Textbox(label="Session ID", interactive=True)
    status = gr.Textbox(label="Status")
    chat = gr.Markdown("## Conversation")

    btn = gr.Button("Send")
    # BUG FIX: process_audio_input returns exactly 3 values, but the outputs
    # list previously contained 4 components (it included `session`), which
    # makes Gradio raise when the event fires. Also pass the bound methods
    # directly instead of wrapping them in pass-through lambdas.
    btn.click(
        fn=ai.process_audio_input,
        inputs=[audio_in, session],
        outputs=[audio_out, status, chat],
    )

    hist_btn = gr.Button("History")
    hist_btn.click(fn=ai.get_history, inputs=session, outputs=chat)

    clr_btn = gr.Button("Clear")
    clr_btn.click(fn=ai.clear_history, inputs=session, outputs=chat)

demo.queue(concurrency_count=20, max_size=100)
demo.launch(server_name="0.0.0.0", server_port=7860, enable_queue=True)