# Maya-AI / app.py
# Hugging Face Space by Devakumar868 — commit 6e6580b ("Update app.py"), 18.7 kB.
# (Web-view chrome from the hosted "raw/history/blame" page converted to comments.)
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import (
pipeline, AutoTokenizer, AutoModelForCausalLM,
WhisperProcessor, WhisperForConditionalGeneration
)
import soundfile as sf
import json
import time
from datetime import datetime
import os
import warnings
from datasets import load_dataset
# Import Dia TTS model
from dia.model import Dia
warnings.filterwarnings("ignore")
class MayaAI:
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸš€ Initializing Maya AI on {self.device}")
# Load Whisper ASR with FORCED English
self.asr_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
self.asr_model = WhisperForConditionalGeneration.from_pretrained(
"openai/whisper-large-v3",
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
).to(self.device)
# FORCE English transcription
self.asr_model.config.forced_decoder_ids = self.asr_processor.get_decoder_prompt_ids(
language="english",
task="transcribe"
)
print("βœ… Whisper ASR loaded with FORCED English")
# Load FREE LLM with FIXED attention mask
self.llm_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
# FIX: Set pad_token to eos_token to avoid attention mask warnings
if self.llm_tokenizer.pad_token is None:
self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
self.llm_model = AutoModelForCausalLM.from_pretrained(
"microsoft/DialoGPT-large",
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
device_map="auto",
pad_token_id=self.llm_tokenizer.eos_token_id
)
print("βœ… DialoGPT-Large loaded with FIXED attention masks")
# Load Emotion Recognition
self.emotion_model = pipeline(
"audio-classification",
model="superb/wav2vec2-base-superb-er",
device=self.device
)
print("βœ… Emotion recognition loaded")
# Load Dia TTS Model (The REAL Dia from Nari Labs)
try:
self.dia_model = Dia.from_pretrained(
"nari-labs/Dia-1.6B",
compute_dtype="float16" if self.device == "cuda" else "float32"
)[11][13][15]
print("βœ… Dia TTS loaded successfully from Nari Labs")
self.use_dia = True
except Exception as e:
print(f"⚠️ Dia loading failed: {e}")
# Fallback to SpeechT5 with FIXED dtypes
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
self.tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
self.tts_model = SpeechT5ForTextToSpeech.from_pretrained(
"microsoft/speecht5_tts",
torch_dtype=torch.float32
).to(self.device)
self.vocoder = SpeechT5HifiGan.from_pretrained(
"microsoft/speecht5_hifigan",
torch_dtype=torch.float32
).to(self.device)
# Load female speaker embeddings
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
self.speaker_embeddings = torch.tensor(
embeddings_dataset[7306]["xvector"],
dtype=torch.float32
).unsqueeze(0).to(self.device)
print("βœ… SpeechT5 TTS loaded as fallback")
self.use_dia = False
# Conversation storage
self.conversations = {}
self.call_active = False
def transcribe_with_whisper(self, audio_path):
"""Transcribe using Whisper with FORCED English"""
try:
if audio_path is None:
return "No audio provided"
# Load and preprocess audio
audio, sr = librosa.load(audio_path, sr=16000, mono=True)
# Process with Whisper - FORCE English
inputs = self.asr_processor(
audio,
sampling_rate=16000,
return_tensors="pt",
language="english"
).to(self.device)
with torch.no_grad():
predicted_ids = self.asr_model.generate(
inputs.input_features,
max_new_tokens=150,
do_sample=False,
forced_decoder_ids=self.asr_model.config.forced_decoder_ids
)
transcription = self.asr_processor.batch_decode(
predicted_ids,
skip_special_tokens=True
)[0]
return transcription.strip()
except Exception as e:
return f"Transcription error: {str(e)}"
def recognize_emotion_from_audio(self, audio_path):
"""Recognize emotion using superb model"""
try:
if audio_path is None:
return "neutral"
result = self.emotion_model(audio_path)
emotion_label = result[0]["label"].lower()
# Map to human emotions
emotion_map = {
"ang": "angry", "hap": "happy", "exc": "excited",
"sad": "sad", "fru": "frustrated", "fea": "fearful",
"sur": "surprised", "neu": "neutral", "dis": "disgusted"
}
return emotion_map.get(emotion_label, emotion_label)
except:
return "neutral"
def generate_with_free_llm(self, text, emotion, history):
"""Generate response using FREE LLM with FIXED attention masks"""
try:
# Emotional context prompting
emotion_prompts = {
"angry": "I understand you're frustrated. Let me help calm this situation.",
"sad": "I can hear the sadness in your voice. I'm here to support you.",
"happy": "Your joy is infectious! I love your positive energy.",
"excited": "Your enthusiasm is amazing! Tell me more!",
"fearful": "I sense your concern. Let's work through this together.",
"surprised": "That sounds unexpected! What happened?",
"neutral": "I'm listening carefully. Please continue."
}
emotion_context = emotion_prompts.get(emotion, "I'm here to help.")
# Build conversation context
context_text = ""
if history:
for entry in history[-2:]:
context_text += f"User: {entry.get('user_input', '')}\nMaya: {entry.get('ai_response', '')}\n"
prompt = f"{context_text}User: {text}\nMaya:"
# Tokenize input with PROPER attention mask
inputs = self.llm_tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=1024,
padding=True,
add_special_tokens=True
).to(self.device)
# Generate response with PROPER attention mask
with torch.no_grad():
outputs = self.llm_model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask, # FIX: Explicit attention mask
max_new_tokens=80,
temperature=0.7,
do_sample=True,
pad_token_id=self.llm_tokenizer.pad_token_id,
eos_token_id=self.llm_tokenizer.eos_token_id
)
# Decode response
response = self.llm_tokenizer.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
).strip()
# Clean up response
if not response or len(response) < 5:
return emotion_context
return response
except Exception as e:
return f"{emotion_prompts.get(emotion, 'I understand.')} Could you tell me more about that?"
def synthesize_with_dia(self, text, emotion):
"""Generate natural emotional speech using Dia TTS"""[11][13][15]
try:
if not text or len(text.strip()) == 0:
return None
if self.use_dia:
# Use Dia TTS with proper speaker tags and emotional context
# Add emotional markers based on Dia's supported non-verbal tags
if emotion == "happy":
emotional_text = f"[S1] {text} (laughs)"[11][15]
elif emotion == "sad":
emotional_text = f"[S1] {text} (sighs)"[11][15]
elif emotion == "excited":
emotional_text = f"[S1] {text}!"
elif emotion == "angry":
emotional_text = f"[S1] {text} (clears throat)"[11][15]
elif emotion == "surprised":
emotional_text = f"[S1] {text} (gasps)"[11][15]
else:
emotional_text = f"[S1] {text}"[11][15]
# Add natural breathing for longer text (Dia feature)
if len(emotional_text.split()) > 15:
words = emotional_text.split()
mid_point = len(words) // 2
emotional_text = " ".join(words[:mid_point]) + " (inhales) " + " ".join(words[mid_point:])
# Generate using Dia model
output = self.dia_model.generate(
emotional_text,
use_torch_compile=True if self.device == "cuda" else False,
verbose=False
)[11][18]
return output
else:
# Use SpeechT5 fallback with emotional context
clean_text = text.replace("[", "").replace("]", "").strip()
if len(clean_text) > 200:
clean_text = clean_text[:200] + "..."
# Add emotional inflection through punctuation
if emotion == "happy":
clean_text = clean_text.replace(".", "!")
elif emotion == "excited":
clean_text = clean_text + "!"
elif emotion == "sad":
clean_text = clean_text.replace("!", ".")
inputs = self.tts_processor(text=clean_text, return_tensors="pt")
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
speech = self.tts_model.generate_speech(
inputs["input_ids"],
self.speaker_embeddings,
vocoder=self.vocoder
)
if isinstance(speech, torch.Tensor):
speech = speech.cpu().numpy().astype(np.float32)
return speech
except Exception as e:
print(f"TTS error: {e}")
return None
def start_call(self):
"""Start a new call session"""
self.call_active = True
greeting = "Hello! I'm Maya, your AI conversation partner. I'm here to chat with you naturally and understand your emotions. How are you feeling today?"
greeting_audio = self.synthesize_with_dia(greeting, "happy")
# Dia outputs at 44100 Hz sample rate
sample_rate = 44100 if self.use_dia else 22050
return greeting, (sample_rate, greeting_audio) if greeting_audio is not None else None, "πŸ“ž Call started! Maya is greeting you..."
def end_call(self, user_id="default"):
"""End call and clear conversation"""
self.call_active = False
if user_id in self.conversations:
self.conversations[user_id] = []
farewell = "Thank you for chatting with me! It was wonderful talking with you. Have a great day!"
farewell_audio = self.synthesize_with_dia(farewell, "happy")
sample_rate = 44100 if self.use_dia else 22050
return farewell, (sample_rate, farewell_audio) if farewell_audio is not None else None, "πŸ“ž Call ended. Conversation cleared!"
def process_conversation(self, audio_input, user_id="default"):
"""Main conversation processing pipeline"""
if not self.call_active:
return "Please start a call first by clicking the 'Start Call' button", None, "No active call"
if audio_input is None:
return "Please record some audio", None, "No audio input"
start_time = time.time()
if user_id not in self.conversations:
self.conversations[user_id] = []
try:
# Step 1: ASR with FORCED English
transcription = self.transcribe_with_whisper(audio_input)
# Step 2: Emotion recognition
emotion = self.recognize_emotion_from_audio(audio_input)
# Step 3: FREE LLM generation with FIXED attention masks
response_text = self.generate_with_free_llm(
transcription, emotion, self.conversations[user_id]
)
# Step 4: Dia TTS with natural emotional speech
response_audio = self.synthesize_with_dia(response_text, emotion)
# Step 5: Update conversation history
processing_time = time.time() - start_time
conversation_entry = {
"timestamp": datetime.now().strftime("%H:%M:%S"),
"user_input": transcription,
"user_emotion": emotion,
"ai_response": response_text,
"processing_time": processing_time
}
self.conversations[user_id].append(conversation_entry)
# Keep last 1000 exchanges as specified
if len(self.conversations[user_id]) > 1000:
self.conversations[user_id] = self.conversations[user_id][-1000:]
history = self.format_conversation_history(user_id)
sample_rate = 44100 if self.use_dia else 22050
return transcription, (sample_rate, response_audio) if response_audio is not None else None, history
except Exception as e:
return f"Processing error: {str(e)}", None, "Error in processing"
def format_conversation_history(self, user_id):
"""Format conversation history for display"""
if user_id not in self.conversations or not self.conversations[user_id]:
return "No conversation history yet."
history = []
for i, entry in enumerate(self.conversations[user_id][-10:], 1):
history.append(f"**Exchange {i}** ({entry['timestamp']})")
history.append(f"🎀 **You** ({entry['user_emotion']}): {entry['user_input']}")
history.append(f"πŸ€– **Maya**: {entry['ai_response']}")
history.append(f"⏱️ *{entry['processing_time']:.2f}s*")
history.append("---")
return "\n".join(history)
# Initialize Maya AI
print("πŸš€ Starting Maya AI with Dia TTS...")
maya = MayaAI()
print("βœ… Maya AI ready with natural emotional speech!")
# Gradio Interface Functions
def start_call_handler():
return maya.start_call()
def end_call_handler():
return maya.end_call()
def process_audio_handler(audio):
return maya.process_conversation(audio)
# Create Gradio Interface
with gr.Blocks(
title="Maya AI - Dia TTS Sesame Killer",
theme=gr.themes.Soft()
) as demo:
gr.Markdown("""
# 🎀 Maya AI - Dia TTS Sesame Killer
*Powered by Nari Labs Dia TTS: Ultra-realistic dialogue with natural breathing, laughter, and emotional speech*
**Features:** βœ… Dia Natural TTS βœ… English-only ASR βœ… Emotion Recognition βœ… FREE Models βœ… Human-like Speech with Non-verbals
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### πŸ“ž Call Controls")
start_call_btn = gr.Button("πŸ“ž Start Call", variant="primary", size="lg")
end_call_btn = gr.Button("πŸ“ž End Call", variant="stop", size="lg")
gr.Markdown("### πŸŽ™οΈ Voice Input")
audio_input = gr.Audio(
sources=["microphone"],
type="filepath",
label="Record your message in English"
)
process_btn = gr.Button("🎯 Process Audio", variant="primary")
with gr.Column(scale=2):
gr.Markdown("### πŸ’¬ Natural Dia Conversation")
transcription_output = gr.Textbox(
label="πŸ“ What you said (English)",
lines=2,
interactive=False
)
audio_output = gr.Audio(
label="πŸ”Š Maya's Dia Response (Natural with Breathing & Emotions)",
interactive=False,
autoplay=True
)
conversation_display = gr.Textbox(
label="πŸ’­ Live Conversation (FREE & Natural Dia TTS)",
lines=15,
interactive=False,
show_copy_button=True
)
# Event Handlers
start_call_btn.click(
fn=start_call_handler,
outputs=[transcription_output, audio_output, conversation_display]
)
end_call_btn.click(
fn=end_call_handler,
outputs=[transcription_output, audio_output, conversation_display]
)
process_btn.click(
fn=process_audio_handler,
inputs=[audio_input],
outputs=[transcription_output, audio_output, conversation_display]
)
audio_input.stop_recording(
fn=process_audio_handler,
inputs=[audio_input],
outputs=[transcription_output, audio_output, conversation_display]
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)