iajitpanday committed on
Commit
e6d8c7b
·
verified ·
1 Parent(s): 96245ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +233 -48
app.py CHANGED
@@ -1,60 +1,245 @@
1
  # app.py
2
  import gradio as gr
3
- from twilio.rest import Client
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
- import whisper
6
- import pyttsx3
7
- import io
8
- import wave
9
-
10
- # Initialize models
11
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
12
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
13
- whisper_model = whisper.load_model("base")
 
 
 
14
 
15
- # Initialize Twilio client
16
- account_sid = os.environ["TWILIO_ACCOUNT_SID"]
17
- auth_token = os.environ["TWILIO_AUTH_TOKEN"]
18
- client = Client(account_sid, auth_token)
19
 
20
- def process_voice_call(audio_data):
21
- # 1. Speech to Text using Whisper
22
- audio = whisper.load_audio(audio_data)
23
- result = whisper_model.transcribe(audio)
24
- user_text = result["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # 2. Generate response using LLM
27
- new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token,
28
- return_tensors='pt')
29
- chat_history_ids = model.generate(
30
- new_user_input_ids,
31
- max_length=1000,
32
- num_beams=5,
33
- no_repeat_ngram_size=2,
34
- temperature=0.7,
35
- do_sample=True,
36
- top_k=50,
37
- top_p=0.95,
38
- pad_token_id=tokenizer.eos_token_id
39
- )
40
 
41
- response = tokenizer.decode(chat_history_ids[:, new_user_input_ids.shape[-1]:][0],
42
- skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # 3. Text to Speech
45
- tts_engine = pyttsx3.init()
46
- tts_engine.save_to_file(response, "response.wav")
47
- tts_engine.runAndWait()
 
 
 
 
 
 
 
 
 
48
 
49
- return response, "response.wav"
 
 
 
 
 
 
50
 
51
  # Create Gradio interface
52
- iface = gr.Interface(
53
- fn=process_voice_call,
54
- inputs=gr.Audio(type="filepath"),
55
- outputs=[gr.Textbox(), gr.Audio()],
56
- title="Voice AI Customer Support"
57
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
59
  if __name__ == "__main__":
60
- iface.launch()
 
 
1
  # app.py
2
  import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM,
8
+ AutoProcessor,
9
+ AutoModelForSpeechSeq2Seq,
10
+ pipeline
11
+ )
12
+ from TTS.api import TTS
13
+ import tempfile
14
+ import os
15
+ import json
16
+ import logging
17
 
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
 
21
 
22
+ class VoiceAIBot:
23
+ def __init__(self):
24
+ # Initialize models
25
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ logger.info(f"Using device: {self.device}")
27
+
28
+ # Speech Recognition Model (Whisper)
29
+ self.asr_model = pipeline(
30
+ "automatic-speech-recognition",
31
+ model="openai/whisper-base",
32
+ device=self.device
33
+ )
34
+
35
+ # Conversation Model (DialoGPT for customer support)
36
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
37
+ self.conversation_model = AutoModelForCausalLM.from_pretrained(
38
+ "microsoft/DialoGPT-medium"
39
+ ).to(self.device)
40
+
41
+ # Text-to-Speech Model
42
+ self.tts = TTS("tts_models/en/ljspeech/tacotron2-DDC")
43
+
44
+ # Customer support knowledge base
45
+ self.knowledge_base = {
46
+ "order status": "I can help you check your order status. Please provide your order number.",
47
+ "return policy": "Our return policy allows returns within 30 days of purchase. Items must be unused and in original packaging.",
48
+ "shipping": "Standard shipping takes 3-5 business days. Express shipping takes 1-2 business days.",
49
+ "payment": "We accept all major credit cards, PayPal, and Apple Pay.",
50
+ "business hours": "We're open Monday-Friday, 9 AM to 6 PM EST.",
51
+ "technical support": "I can help with basic technical issues. For complex problems, I'll connect you with our technical team.",
52
+ }
53
+
54
+ # Conversation history
55
+ self.conversation_history = []
56
+
57
+ def transcribe_audio(self, audio_file):
58
+ """Convert speech to text using Whisper"""
59
+ try:
60
+ result = self.asr_model(audio_file)
61
+ transcription = result["text"]
62
+ logger.info(f"Transcription: {transcription}")
63
+ return transcription
64
+ except Exception as e:
65
+ logger.error(f"Transcription error: {e}")
66
+ return "Sorry, I couldn't understand the audio."
67
+
68
+ def check_knowledge_base(self, user_input):
69
+ """Check if query matches knowledge base"""
70
+ user_input_lower = user_input.lower()
71
+ for keyword, response in self.knowledge_base.items():
72
+ if keyword in user_input_lower:
73
+ return response
74
+ return None
75
+
76
+ def generate_response(self, user_input):
77
+ """Generate AI response based on user input"""
78
+ # First check knowledge base
79
+ kb_response = self.check_knowledge_base(user_input)
80
+ if kb_response:
81
+ return kb_response
82
+
83
+ # If not found in knowledge base, use conversation model
84
+ try:
85
+ # Add current conversation to history
86
+ self.conversation_history.append(user_input)
87
+
88
+ # Prepare input for the model
89
+ input_text = "Customer: " + user_input + " Agent:"
90
+ input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.device)
91
+
92
+ # Generate response
93
+ with torch.no_grad():
94
+ output = self.conversation_model.generate(
95
+ input_ids,
96
+ max_length=150,
97
+ num_beams=5,
98
+ temperature=0.7,
99
+ do_sample=True,
100
+ top_k=50,
101
+ top_p=0.95,
102
+ pad_token_id=self.tokenizer.eos_token_id,
103
+ no_repeat_ngram_size=2
104
+ )
105
+
106
+ response = self.tokenizer.decode(output[0], skip_special_tokens=True)
107
+ # Extract only the agent's response
108
+ agent_response = response.split("Agent:")[-1].strip()
109
+
110
+ # Add to conversation history
111
+ self.conversation_history.append(agent_response)
112
+
113
+ # Keep conversation history manageable
114
+ if len(self.conversation_history) > 10:
115
+ self.conversation_history = self.conversation_history[-10:]
116
+
117
+ logger.info(f"Generated response: {agent_response}")
118
+ return agent_response
119
+
120
+ except Exception as e:
121
+ logger.error(f"Response generation error: {e}")
122
+ return "I'm sorry, I'm having trouble processing your request right now. Can you please try again?"
123
 
124
+ def text_to_speech(self, text):
125
+ """Convert text to speech"""
126
+ try:
127
+ # Create temporary file for audio output
128
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
129
+ self.tts.tts_to_file(text=text, file_path=tmp_file.name)
130
+ return tmp_file.name
131
+ except Exception as e:
132
+ logger.error(f"TTS error: {e}")
133
+ return None
 
 
 
 
134
 
135
+ def process_voice_input(self, audio_file):
136
+ """Process complete voice interaction"""
137
+ if audio_file is None:
138
+ return "Please provide an audio input.", None, self.format_conversation_history()
139
+
140
+ # 1. Transcribe speech to text
141
+ user_text = self.transcribe_audio(audio_file)
142
+
143
+ # 2. Generate AI response
144
+ ai_response = self.generate_response(user_text)
145
+
146
+ # 3. Convert response to speech
147
+ audio_response = self.text_to_speech(ai_response)
148
+
149
+ # 4. Return all outputs
150
+ return user_text, audio_response, self.format_conversation_history()
151
 
152
+ def format_conversation_history(self):
153
+ """Format conversation history for display"""
154
+ if not self.conversation_history:
155
+ return "No conversation history yet."
156
+
157
+ formatted = "Conversation History:\n\n"
158
+ for i in range(0, len(self.conversation_history), 2):
159
+ if i < len(self.conversation_history):
160
+ formatted += f"Customer: {self.conversation_history[i]}\n"
161
+ if i + 1 < len(self.conversation_history):
162
+ formatted += f"Agent: {self.conversation_history[i + 1]}\n\n"
163
+
164
+ return formatted
165
 
166
+ def clear_history(self):
167
+ """Clear conversation history"""
168
+ self.conversation_history = []
169
+ return "Conversation history cleared.", self.format_conversation_history()
170
+
171
+ # Initialize the bot
172
+ bot = VoiceAIBot()
173
 
174
  # Create Gradio interface
175
+ def create_interface():
176
+ with gr.Blocks(title="Voice AI Customer Support Bot") as demo:
177
+ gr.Markdown("# 🎤 Voice AI Customer Support Bot")
178
+ gr.Markdown("Upload audio or record your voice to interact with the AI customer support agent.")
179
+
180
+ with gr.Row():
181
+ with gr.Column(scale=1):
182
+ # Audio input
183
+ audio_input = gr.Audio(
184
+ sources=["microphone", "upload"],
185
+ type="filepath",
186
+ label="Speak your question"
187
+ )
188
+
189
+ # Process button
190
+ process_btn = gr.Button("Process Voice", variant="primary")
191
+
192
+ # Clear history button
193
+ clear_btn = gr.Button("Clear History", variant="secondary")
194
+
195
+ with gr.Column(scale=1):
196
+ # Transcribed text output
197
+ transcription_output = gr.Textbox(
198
+ label="What you said:",
199
+ interactive=False
200
+ )
201
+
202
+ # Audio response output
203
+ audio_output = gr.Audio(
204
+ label="AI Response (Audio)",
205
+ interactive=False
206
+ )
207
+
208
+ # Conversation history
209
+ with gr.Row():
210
+ conversation_history = gr.Textbox(
211
+ label="Conversation History",
212
+ lines=10,
213
+ interactive=False
214
+ )
215
+
216
+ # Event handlers
217
+ process_btn.click(
218
+ fn=bot.process_voice_input,
219
+ inputs=[audio_input],
220
+ outputs=[transcription_output, audio_output, conversation_history]
221
+ )
222
+
223
+ clear_btn.click(
224
+ fn=bot.clear_history,
225
+ inputs=[],
226
+ outputs=[transcription_output, conversation_history]
227
+ )
228
+
229
+ # Example usage
230
+ gr.Markdown("## Example Queries")
231
+ gr.Markdown("""
232
+ Try asking about:
233
+ - Order status
234
+ - Return policy
235
+ - Shipping information
236
+ - Business hours
237
+ - Technical support
238
+ """)
239
+
240
+ return demo
241
 
242
+ # Launch the interface
243
  if __name__ == "__main__":
244
+ demo = create_interface()
245
+ demo.launch(share=True)