# Maya-AI / app.py
# Source: Hugging Face Space "Devakumar868/Maya-AI", commit c78b630 (verified)
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import (
pipeline, AutoTokenizer, AutoModelForCausalLM,
WhisperProcessor, WhisperForConditionalGeneration
)
import soundfile as sf
import json
import time
from datetime import datetime
import os
import warnings
# Dia TTS is an optional dependency: attempt the import and record the
# outcome so the app can fall back to SpeechT5 later on.
try:
    from dia.model import Dia
except ImportError as err:
    print(f"⚠️ Dia import failed: {err}")
    DIA_AVAILABLE = False
else:
    DIA_AVAILABLE = True
    print("✅ Dia model imported successfully")

# Silence noisy library warnings for the remainder of the app.
warnings.filterwarnings("ignore")
class MayaAI:
    """End-to-end voice-chat pipeline.

    Stages: Whisper large-v3 ASR (forced English) -> wav2vec2 emotion
    recognition -> DialoGPT-Large response generation -> TTS via Dia
    (ultra-realistic dialogue) with a SpeechT5 fallback.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Initializing Maya AI on {self.device}")

        # --- ASR: Whisper large-v3, decoder forced to English transcription ---
        self.asr_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-large-v3",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        # Force English transcription for every subsequent generate() call.
        self.asr_model.config.forced_decoder_ids = self.asr_processor.get_decoder_prompt_ids(
            language="english",
            task="transcribe"
        )
        print("✅ Whisper ASR loaded with FORCED English")

        # --- LLM: DialoGPT-Large; a pad token must exist for attention masks ---
        self.llm_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
        if self.llm_tokenizer.pad_token is None:
            self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/DialoGPT-large",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto",
            pad_token_id=self.llm_tokenizer.eos_token_id
        )
        print("✅ DialoGPT-Large loaded with FIXED attention masks")

        # --- Emotion recognition directly on the raw audio file ---
        self.emotion_model = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-er",
            device=self.device
        )
        print("✅ Emotion recognition loaded")

        # --- TTS: prefer Dia when importable, otherwise SpeechT5 fallback ---
        if DIA_AVAILABLE:
            try:
                self.dia_model = Dia.from_pretrained(
                    "nari-labs/Dia-1.6B",
                    compute_dtype="float16" if self.device == "cuda" else "float32",
                    device=self.device
                )
                print("✅ Dia TTS loaded (Ultra-realistic dialogue generation)")
                self.use_dia = True
            except Exception as e:
                print(f"⚠️ Dia loading failed: {e}")
                self.use_dia = False
                self._load_fallback_tts()
        else:
            print("⚠️ Dia not available, using fallback TTS")
            self.use_dia = False
            self._load_fallback_tts()

        # Per-user conversation history and call state.
        self.conversations = {}
        self.call_active = False
        self.speaker_turn = 1  # Dia speaker tag ([S1]/[S2]) toggled each turn

    def _load_fallback_tts(self):
        """Load the SpeechT5 TTS stack (model, HiFi-GAN vocoder, speaker embedding)."""
        try:
            from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
            from datasets import load_dataset
            self.tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            self.tts_model = SpeechT5ForTextToSpeech.from_pretrained(
                "microsoft/speecht5_tts",
                torch_dtype=torch.float32
            ).to(self.device)
            self.vocoder = SpeechT5HifiGan.from_pretrained(
                "microsoft/speecht5_hifigan",
                torch_dtype=torch.float32
            ).to(self.device)
            # Fixed female x-vector (index 7306 of the CMU ARCTIC embeddings).
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            self.speaker_embeddings = torch.tensor(
                embeddings_dataset[7306]["xvector"],
                dtype=torch.float32
            ).unsqueeze(0).to(self.device)
            print("✅ SpeechT5 TTS loaded as fallback")
        except Exception as e:
            # Best-effort: the app can still run ASR/LLM without audio output.
            print(f"❌ Fallback TTS loading failed: {e}")

    def _output_sample_rate(self):
        """Native sample rate of the active TTS backend.

        Dia decodes through the Descript Audio Codec at 44.1 kHz; SpeechT5's
        HiFi-GAN vocoder synthesizes at 16 kHz.  (The previous hard-coded
        24000/22050 values caused pitch/speed-shifted playback.)
        """
        return 44100 if self.use_dia else 16000

    def transcribe_with_whisper(self, audio_path):
        """Transcribe an audio file to English text with Whisper.

        Returns the transcription string, or a human-readable error message
        (never raises) so the Gradio pipeline keeps running.
        """
        try:
            if audio_path is None:
                return "No audio provided"
            # Whisper expects 16 kHz mono input.
            audio, sr = librosa.load(audio_path, sr=16000, mono=True)
            # NOTE: `language` is not a valid feature-extractor kwarg; English
            # is already enforced via forced_decoder_ids set in __init__.
            inputs = self.asr_processor(
                audio,
                sampling_rate=16000,
                return_tensors="pt"
            ).to(self.device)
            # Cast features to the model's dtype (float16 on CUDA) to avoid a
            # float32-vs-float16 mismatch at inference time.
            input_features = inputs.input_features.to(self.asr_model.dtype)
            with torch.no_grad():
                predicted_ids = self.asr_model.generate(
                    input_features,
                    max_new_tokens=150,
                    do_sample=False,
                    forced_decoder_ids=self.asr_model.config.forced_decoder_ids
                )
            transcription = self.asr_processor.batch_decode(
                predicted_ids,
                skip_special_tokens=True
            )[0]
            return transcription.strip()
        except Exception as e:
            return f"Transcription error: {str(e)}"

    def recognize_emotion_from_audio(self, audio_path):
        """Classify the speaker's emotion from audio; returns a readable label.

        Falls back to "neutral" for missing audio or any classifier failure.
        """
        try:
            if audio_path is None:
                return "neutral"
            result = self.emotion_model(audio_path)
            emotion_label = result[0]["label"].lower()
            # Map the model's abbreviated labels to full emotion names.
            emotion_map = {
                "ang": "angry", "hap": "happy", "exc": "excited",
                "sad": "sad", "fru": "frustrated", "fea": "fearful",
                "sur": "surprised", "neu": "neutral", "dis": "disgusted"
            }
            return emotion_map.get(emotion_label, emotion_label)
        except Exception:
            # Deliberate best-effort: emotion is advisory, never fatal.
            return "neutral"

    def generate_with_free_llm(self, text, emotion, history):
        """Generate Maya's reply with DialoGPT, conditioned on emotion and history.

        Returns a canned empathetic line when generation fails or produces
        fewer than 5 characters.
        """
        # Defined BEFORE the try block so the except handler can always use it
        # (previously it lived inside `try`, risking NameError in the handler).
        emotion_prompts = {
            "angry": "I understand you're frustrated. Let me help calm this situation.",
            "sad": "I can hear the sadness in your voice. I'm here to support you.",
            "happy": "Your joy is infectious! I love your positive energy.",
            "excited": "Your enthusiasm is amazing! Tell me more!",
            "fearful": "I sense your concern. Let's work through this together.",
            "surprised": "That sounds unexpected! What happened?",
            "neutral": "I'm listening carefully. Please continue."
        }
        emotion_context = emotion_prompts.get(emotion, "I'm here to help.")
        try:
            # Include the last two exchanges as conversational context.
            context_text = ""
            if history:
                for entry in history[-2:]:
                    context_text += f"User: {entry.get('user_input', '')}\nMaya: {entry.get('ai_response', '')}\n"
            prompt = f"{context_text}User: {text}\nMaya:"
            inputs = self.llm_tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
                padding=True,
                add_special_tokens=True
            ).to(self.device)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=80,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.llm_tokenizer.pad_token_id,
                    eos_token_id=self.llm_tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.llm_tokenizer.decode(
                outputs[0][inputs.input_ids.shape[1]:],
                skip_special_tokens=True
            ).strip()
            # Too-short generations read as broken; use the canned line instead.
            if not response or len(response) < 5:
                return emotion_context
            return response
        except Exception:
            # Graceful degradation instead of surfacing a stack trace to the UI.
            return f"{emotion_prompts.get(emotion, 'I understand.')} Could you tell me more about that?"

    def synthesize_with_dia(self, text, emotion):
        """Synthesize speech for `text`, adding Dia speaker tags and emotion cues.

        Returns a waveform array (Dia: 44.1 kHz, fallback: 16 kHz) or None.
        """
        try:
            if not text or len(text.strip()) == 0:
                return None
            if self.use_dia:
                # Dia expects [S1]/[S2] speaker tags and supports parenthesized
                # non-verbal cues like (laughs)/(sighs).
                speaker_tag = f"[S{self.speaker_turn}]"
                if emotion == "happy":
                    emotional_text = f"{speaker_tag} {text} (laughs)"
                elif emotion == "sad":
                    emotional_text = f"{speaker_tag} {text} (sighs)"
                elif emotion == "excited":
                    emotional_text = f"{speaker_tag} {text}!"
                elif emotion == "angry":
                    emotional_text = f"{speaker_tag} {text} (frustrated tone)"
                elif emotion == "surprised":
                    emotional_text = f"{speaker_tag} {text} (gasps)"
                else:
                    emotional_text = f"{speaker_tag} {text}"
                output = self.dia_model.generate(
                    emotional_text,
                    use_torch_compile=True if self.device == "cuda" else False,
                    verbose=False
                )
                # Alternate the speaker tag for the next turn.
                self.speaker_turn = 2 if self.speaker_turn == 1 else 1
                return output
            else:
                return self._synthesize_with_fallback(text, emotion)
        except Exception as e:
            print(f"Dia TTS error: {e}")
            return self._synthesize_with_fallback(text, emotion)

    def _synthesize_with_fallback(self, text, emotion):
        """Synthesize speech with SpeechT5; returns float32 waveform or None."""
        try:
            # Strip Dia-style speaker tags and cap length for SpeechT5.
            clean_text = text.replace("[", "").replace("]", "").strip()
            if len(clean_text) > 200:
                clean_text = clean_text[:200] + "..."
            # Approximate emotional inflection through punctuation only.
            if emotion == "happy":
                clean_text = clean_text.replace(".", "!")
            elif emotion == "excited":
                clean_text = clean_text + "!"
            elif emotion == "sad":
                clean_text = clean_text.replace("!", ".")
            inputs = self.tts_processor(text=clean_text, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                speech = self.tts_model.generate_speech(
                    inputs["input_ids"],
                    self.speaker_embeddings,
                    vocoder=self.vocoder
                )
            if isinstance(speech, torch.Tensor):
                speech = speech.cpu().numpy().astype(np.float32)
            return speech
        except Exception as e:
            print(f"Fallback TTS error: {e}")
            return None

    def start_call(self):
        """Begin a call session.

        Returns (greeting_text, (sample_rate, waveform) | None, status_message).
        """
        self.call_active = True
        self.speaker_turn = 1  # restart Dia speaker alternation each call
        greeting = "Hello! I'm Maya, your AI conversation partner. I'm here to chat with you naturally and understand your emotions. How are you feeling today?"
        greeting_audio = self.synthesize_with_dia(greeting, "happy")
        sample_rate = self._output_sample_rate()
        return greeting, (sample_rate, greeting_audio) if greeting_audio is not None else None, "📞 Call started! Maya is greeting you with ultra-realistic speech..."

    def end_call(self, user_id="default"):
        """End the call and clear this user's conversation history.

        Returns (farewell_text, (sample_rate, waveform) | None, status_message).
        """
        self.call_active = False
        if user_id in self.conversations:
            self.conversations[user_id] = []
        farewell = "Thank you for chatting with me! It was wonderful talking with you. Have a great day!"
        farewell_audio = self.synthesize_with_dia(farewell, "happy")
        sample_rate = self._output_sample_rate()
        return farewell, (sample_rate, farewell_audio) if farewell_audio is not None else None, "📞 Call ended. Conversation cleared!"

    def process_conversation(self, audio_input, user_id="default"):
        """Run one utterance through ASR -> emotion -> LLM -> TTS.

        Returns (transcription, (sample_rate, waveform) | None, history_markdown);
        on any failure the first element carries the error message.
        """
        if not self.call_active:
            return "Please start a call first by clicking the 'Start Call' button", None, "No active call"
        if audio_input is None:
            return "Please record some audio", None, "No audio input"
        start_time = time.time()
        if user_id not in self.conversations:
            self.conversations[user_id] = []
        try:
            # Step 1: speech -> English text
            transcription = self.transcribe_with_whisper(audio_input)
            # Step 2: emotion from the same audio
            emotion = self.recognize_emotion_from_audio(audio_input)
            # Step 3: emotionally-conditioned reply
            response_text = self.generate_with_free_llm(
                transcription, emotion, self.conversations[user_id]
            )
            # Step 4: synthesize the reply
            response_audio = self.synthesize_with_dia(response_text, emotion)
            # Step 5: record the exchange
            processing_time = time.time() - start_time
            conversation_entry = {
                "timestamp": datetime.now().strftime("%H:%M:%S"),
                "user_input": transcription,
                "user_emotion": emotion,
                "ai_response": response_text,
                "processing_time": processing_time
            }
            self.conversations[user_id].append(conversation_entry)
            # Bound memory: keep at most the last 1000 exchanges per user.
            if len(self.conversations[user_id]) > 1000:
                self.conversations[user_id] = self.conversations[user_id][-1000:]
            history = self.format_conversation_history(user_id)
            sample_rate = self._output_sample_rate()
            return transcription, (sample_rate, response_audio) if response_audio is not None else None, history
        except Exception as e:
            return f"Processing error: {str(e)}", None, "Error in processing"

    def format_conversation_history(self, user_id):
        """Render the last 10 exchanges for `user_id` as display markdown."""
        if user_id not in self.conversations or not self.conversations[user_id]:
            return "No conversation history yet."
        history = []
        for i, entry in enumerate(self.conversations[user_id][-10:], 1):
            history.append(f"**Exchange {i}** ({entry['timestamp']})")
            history.append(f"🎤 **You** ({entry['user_emotion']}): {entry['user_input']}")
            history.append(f"🤖 **Maya**: {entry['ai_response']}")
            history.append(f"⏱️ *{entry['processing_time']:.2f}s*")
            history.append("---")
        return "\n".join(history)
# Initialize Maya AI
print("πŸš€ Starting Maya AI with REAL Dia TTS...")
maya = MayaAI()
print("βœ… Maya AI ready with ultra-realistic dialogue generation!")
# Gradio Interface Functions
def start_call_handler():
    """Gradio callback: open a call session on the shared Maya instance."""
    greeting, audio, status = maya.start_call()
    return greeting, audio, status
def end_call_handler():
    """Gradio callback: close the current call and clear its history."""
    farewell, audio, status = maya.end_call()
    return farewell, audio, status
def process_audio_handler(audio):
    """Gradio callback: run one recorded utterance through the full pipeline."""
    result = maya.process_conversation(audio)
    return result
# Create Gradio Interface[7]
# Layout: left column = call controls + microphone; right column = transcript,
# synthesized reply, and running conversation history.
with gr.Blocks(
    title="Maya AI - Dia-Powered Sesame Killer",
    theme=gr.themes.Soft()
) as demo:
    gr.Markdown("""
    # 🎤 Maya AI - Dia-Powered Sesame Killer
    *Ultra-realistic dialogue generation with Dia TTS - Natural breathing, laughter, and human-like responses*
    **Features:** ✅ Real Dia TTS ✅ English-only ASR ✅ Emotion Recognition ✅ FREE LLM ✅ Ultra-realistic Speech
    """)
    with gr.Row():
        # Left column: call controls and voice capture.
        with gr.Column(scale=1):
            gr.Markdown("### 📞 Call Controls")
            start_call_btn = gr.Button("📞 Start Call", variant="primary", size="lg")
            end_call_btn = gr.Button("📞 End Call", variant="stop", size="lg")
            gr.Markdown("### 🎙️ Voice Input")
            # type="filepath" means handlers receive a path to the recorded file.
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record your message in English"
            )
            process_btn = gr.Button("🎯 Process Audio", variant="primary")
        # Right column: transcript, audio reply, and history display.
        with gr.Column(scale=2):
            gr.Markdown("### 💬 Ultra-Realistic Conversation")
            transcription_output = gr.Textbox(
                label="📝 What you said (English)",
                lines=2,
                interactive=False
            )
            # autoplay so Maya's reply is heard as soon as it is ready.
            audio_output = gr.Audio(
                label="🔊 Maya's Ultra-Realistic Response (Dia TTS)",
                interactive=False,
                autoplay=True
            )
            conversation_display = gr.Textbox(
                label="💭 Live Conversation (FREE & Ultra-Realistic)",
                lines=15,
                interactive=False,
                show_copy_button=True
            )
    # Event Handlers
    # All three handlers write into the same trio of output components.
    start_call_btn.click(
        fn=start_call_handler,
        outputs=[transcription_output, audio_output, conversation_display]
    )
    end_call_btn.click(
        fn=end_call_handler,
        outputs=[transcription_output, audio_output, conversation_display]
    )
    process_btn.click(
        fn=process_audio_handler,
        inputs=[audio_input],
        outputs=[transcription_output, audio_output, conversation_display]
    )
    # Hands-free flow: also process automatically when recording stops.
    audio_input.stop_recording(
        fn=process_audio_handler,
        inputs=[audio_input],
        outputs=[transcription_output, audio_output, conversation_display]
    )
if __name__ == "__main__":
    # Bind on all interfaces; port 7860 is the Hugging Face Spaces default.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )