import gradio as gr
import torch
import librosa
from datasets import load_dataset
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, WhisperProcessor, WhisperForConditionalGeneration
import time
from datetime import datetime
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

class ConversationalAI:
    def __init__(self):
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        
        # Load Whisper ASR
        self.asr_processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
        self.asr_model = WhisperForConditionalGeneration.from_pretrained(
            "openai/whisper-base.en",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        ).to(self.device)
        
        # Load the DialoGPT conversational LLM
        self.llm_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        self.llm_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/DialoGPT-medium",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            pad_token_id=self.llm_tokenizer.eos_token_id
        ).to(self.device)
        
        # Load TTS model (kept in float32: SpeechT5's vocoder output and the
        # x-vector speaker embeddings below are float32)
        self.tts_model = pipeline(
            "text-to-speech",
            model="microsoft/speecht5_tts",
            device=self.device
        )
        
        # SpeechT5 cannot synthesize without a speaker embedding; load a
        # reference x-vector from CMU ARCTIC (index 7306 follows the model card)
        embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
        self.speaker_embedding = torch.tensor(
            embeddings_dataset[7306]["xvector"]
        ).unsqueeze(0).to(self.device)
        
        # Load audio emotion recognition model. Note: the original SpeechBrain
        # IEMOCAP checkpoint is not loadable through the transformers pipeline,
        # so a transformers-native wav2vec2 classifier trained on IEMOCAP
        # (SUPERB emotion recognition) is used here instead.
        self.emotion_model = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-er",
            device=self.device
        )
        
        # Conversation history
        self.conversations = {}
    
    def transcribe_audio(self, audio_path):
        """Transcribe audio using Whisper with proper device handling"""
        try:
            if audio_path is None:
                return "No audio provided"
            
            # Load and preprocess audio
            audio, sr = librosa.load(audio_path, sr=16000, mono=True)
            
            # Process with Whisper (whisper-base.en is English-only, so no
            # language argument is needed; the feature extractor does not
            # accept one)
            inputs = self.asr_processor(
                audio,
                sampling_rate=16000,
                return_tensors="pt"
            )
            # Match the model's device and dtype (float16 on GPU)
            input_features = inputs.input_features.to(
                self.device, dtype=self.asr_model.dtype
            )
            
            with torch.no_grad():
                predicted_ids = self.asr_model.generate(
                    input_features,
                    max_new_tokens=100,
                    do_sample=False
                )
            
            transcription = self.asr_processor.batch_decode(
                predicted_ids, 
                skip_special_tokens=True
            )[0]
            
            return transcription.strip()
            
        except Exception as e:
            return f"Transcription error: {str(e)}"
    
    def recognize_emotion(self, audio_path):
        """Recognize emotion from the audio signal"""
        try:
            if audio_path is None:
                return "neutral"
            
            result = self.emotion_model(audio_path)
            label = result[0]["label"].lower()
            # The SUPERB checkpoint emits abbreviated IEMOCAP labels; expand
            # them so the emotion prompt in generate_response reads naturally
            label_map = {"neu": "neutral", "hap": "happy", "ang": "angry", "sad": "sad"}
            return label_map.get(label, label)
        except Exception as e:
            print(f"Emotion recognition error: {e}")
            return "neutral"
    
    def generate_response(self, text, emotion, conversation_history):
        """Generate contextual response with proper device handling"""
        try:
            if text.startswith("Transcription error") or not text.strip():
                return "I'm sorry, I couldn't understand what you said. Could you please try again?"
            
            # Build context-aware prompt
            emotion_prompt = f"[User seems {emotion}] " if emotion != "neutral" else ""
            prompt = f"{emotion_prompt}User: {text}\nMaya:"
            
            # Tokenize with proper attention mask
            inputs = self.llm_tokenizer(
                prompt, 
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=80,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.llm_tokenizer.eos_token_id,
                    eos_token_id=self.llm_tokenizer.eos_token_id
                )
            
            # Decode response
            response = self.llm_tokenizer.decode(
                outputs[0][inputs.input_ids.shape[1]:], 
                skip_special_tokens=True
            ).strip()
            
            # Clean up response
            if not response:
                response = "I understand. Could you tell me more about that?"
            
            return response
            
        except Exception as e:
            print(f"Response generation error: {e}")
            return "I'm here to help. What would you like to talk about?"
    
    def synthesize_speech(self, text):
        """Generate speech using TTS"""
        try:
            if not text or len(text.strip()) == 0:
                return None
                
            # Clean text for TTS
            clean_text = text.replace("[", "").replace("]", "").strip()
            if len(clean_text) > 200:
                clean_text = clean_text[:200] + "..."
            
            # SpeechT5 requires the speaker embedding at generation time
            output = self.tts_model(
                clean_text,
                forward_params={"speaker_embeddings": self.speaker_embedding}
            )
            # Gradio's Audio component expects a (sample_rate, waveform) tuple
            return (output["sampling_rate"], output["audio"])
            
        except Exception as e:
            print(f"TTS error: {e}")
            return None
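    
    # Sketch: persisting a synthesized reply outside the UI, assuming the
    # (sample_rate, waveform) tuple returned above ("reply.wav" is a
    # hypothetical path):
    #   import soundfile as sf
    #   sr, waveform = ai_system.synthesize_speech("Hello there!")
    #   sf.write("reply.wav", waveform, sr)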
    
    def process_conversation(self, audio_input, user_id="default"):
        """Main conversation processing pipeline"""
        if audio_input is None:
            return "Please record some audio first", None, "No conversation yet"
        
        start_time = time.time()
        
        # Initialize user conversation if not exists
        if user_id not in self.conversations:
            self.conversations[user_id] = []
        
        try:
            # Step 1: Transcribe audio
            transcription = self.transcribe_audio(audio_input)
            
            # Step 2: Recognize emotion from audio
            emotion = self.recognize_emotion(audio_input)
            
            # Step 3: Generate response
            response_text = self.generate_response(
                transcription, emotion, self.conversations[user_id]
            )
            
            # Step 4: Synthesize speech
            response_audio = self.synthesize_speech(response_text)
            
            # Step 5: Update conversation history
            processing_time = time.time() - start_time
            conversation_entry = {
                "timestamp": datetime.now().strftime("%H:%M:%S"),
                "user_input": transcription,
                "user_emotion": emotion,
                "ai_response": response_text,
                "processing_time": processing_time
            }
            
            self.conversations[user_id].append(conversation_entry)
            
            # Keep only last 15 exchanges per user
            if len(self.conversations[user_id]) > 15:
                self.conversations[user_id] = self.conversations[user_id][-15:]
            
            # Format conversation history
            history = self.format_conversation_history(user_id)
            
            return transcription, response_audio, history
            
        except Exception as e:
            error_msg = f"Processing error: {str(e)}"
            return error_msg, None, "Error occurred during processing"
    
    def format_conversation_history(self, user_id):
        """Format conversation history for display"""
        if user_id not in self.conversations or not self.conversations[user_id]:
            return "No conversation history yet. Start by recording some audio!"
        
        history = []
        for i, entry in enumerate(self.conversations[user_id][-5:], 1):
            history.append(f"**Exchange {i}** ({entry['timestamp']})")
            history.append(f"🎀 **You** ({entry['user_emotion']}): {entry['user_input']}")
            history.append(f"πŸ€– **Maya**: {entry['ai_response']}")
            history.append(f"⏱️ *Response time: {entry['processing_time']:.2f}s*")
            history.append("---")
        
        return "\n".join(history)
    
    def clear_conversation(self, user_id="default"):
        """Clear conversation history"""
        if user_id in self.conversations:
            self.conversations[user_id] = []
        return "Conversation cleared! Ready for a fresh start."

# Initialize the AI system
print("Initializing Maya AI...")
ai_system = ConversationalAI()
print("Maya AI ready!")

# Gradio interface functions
def process_audio(audio):
    if audio is None:
        return "Please record some audio first", None, "Click the microphone button above to start recording"
    
    return ai_system.process_conversation(audio)

def clear_chat():
    message = ai_system.clear_conversation()
    return "", None, message

def greet():
    return "", None, "πŸ‘‹ Hi! I'm Maya, your AI conversation partner. Click the microphone button and start talking!"

# Create Gradio interface
with gr.Blocks(
    title="Maya AI - Conversational Assistant", 
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    .audio-container {
        min-height: 200px;
    }
    """
) as demo:
    
    gr.Markdown("""
    # 🎤 Maya AI - Your Conversational Partner
    *Advanced speech recognition with emotional understanding*
    
    **Instructions:** Click the microphone button, speak clearly, then click stop. Maya will respond with voice and text!
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### πŸŽ™οΈ Voice Input")
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record your message",
                elem_classes=["audio-container"]
            )
            
            with gr.Row():
                process_btn = gr.Button("💬 Process Audio", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ Clear Chat", variant="secondary")
            
        with gr.Column(scale=2):
            gr.Markdown("### πŸ“ Conversation")
            transcription_output = gr.Textbox(
                label="What you said",
                lines=2,
                interactive=False,
                placeholder="Your speech will appear here..."
            )
            
            audio_output = gr.Audio(
                label="πŸ”Š Maya's Response",
                interactive=False,
                autoplay=True
            )
            
            conversation_history = gr.Textbox(
                label="πŸ’­ Conversation History",
                lines=12,
                interactive=False,
                placeholder="Conversation history will appear here...",
                show_copy_button=True
            )
    
    # Event handlers
    process_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[transcription_output, audio_output, conversation_history]
    )
    
    clear_btn.click(
        fn=clear_chat,
        outputs=[transcription_output, audio_output, conversation_history]
    )
    
    # Auto-process when audio is uploaded/recorded
    audio_input.stop_recording(
        fn=process_audio,
        inputs=[audio_input],
        outputs=[transcription_output, audio_output, conversation_history]
    )
    
    # Initialize with greeting
    demo.load(
        fn=greet,
        outputs=[transcription_output, audio_output, conversation_history]
    )
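
# Optional: with concurrent users, calling demo.queue() before launch()
# queues GPU-bound requests instead of running them in parallel (a
# deployment judgment call, not required for local use):
#   demo.queue()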

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        quiet=True
    )