iajitpanday committed on
Commit
b26f4d4
·
verified ·
1 Parent(s): 6bf4419

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +421 -0
app.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import uuid
4
+ import json
5
+ import time
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ import whisper
10
+ import mysql.connector
11
+ from mysql.connector import pooling
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+ from pydub import AudioSegment
14
+ import tempfile
15
+ import hashlib
16
+ import datetime
17
+ import secrets
18
+
19
# Model identifiers: lightweight checkpoints chosen for Spaces.
ASR_MODEL = "base"  # Smaller Whisper model
NLU_MODEL = "facebook/blenderbot-400M-distill"  # Smaller conversation model

# Database configuration. Connection settings come from the environment; the
# literal values are only fallbacks for when a variable is unset.
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "localhost"),
    "user": os.environ.get("DB_USER", "voicebot_user"),
    "password": os.environ.get("DB_PASSWORD", "password"),
    "database": os.environ.get("DB_NAME", "voicebot"),
    "pool_name": "voicebot_pool",
    "pool_size": 5,
}

# Create connection pool.
# NOTE: on failure `cnx_pool` is intentionally left undefined and
# `in_memory_db` is created instead. The data-access functions in this file
# detect fallback mode with `'in_memory_db' in globals()`, so `in_memory_db`
# must only ever be defined in the failure branch.
try:
    cnx_pool = mysql.connector.pooling.MySQLConnectionPool(**DB_CONFIG)
    print("Database connection pool created successfully")
except Exception as e:
    print(f"Error creating database pool: {e}")
    # Use in-memory dictionary as fallback
    print("Using in-memory storage as fallback")
    in_memory_db = {"clients": {}, "conversations": {}}

# Initialize models
print("Loading ASR model...")
asr_model = whisper.load_model(ASR_MODEL)
print("ASR model loaded")

print("Loading NLU model...")
# BlenderBot is an encoder-decoder (sequence-to-sequence) checkpoint. Loading
# it through AutoModelForCausalLM builds a decoder-only wrapper that never
# conditions on the user's utterance, so the seq2seq auto-class is used here.
# Imported locally so this block does not depend on the file's import header.
from transformers import AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained(NLU_MODEL)
nlu_model = AutoModelForSeq2SeqLM.from_pretrained(NLU_MODEL)
print("NLU model loaded")
52
+
53
+ # Database schema initialization
54
def initialize_database():
    """Create the ``clients`` and ``conversations`` tables if they are missing.

    Any failure (including the connection pool being unavailable) is printed
    and swallowed, so the function always returns ``None``.
    """
    clients_ddl = """
        CREATE TABLE IF NOT EXISTS clients (
            id INT AUTO_INCREMENT PRIMARY KEY,
            name VARCHAR(255) NOT NULL,
            email VARCHAR(255) NOT NULL UNIQUE,
            phone VARCHAR(50),
            api_key VARCHAR(64) NOT NULL UNIQUE,
            pbx_type ENUM('Asterisk', 'FreeSwitch', '3CX', 'Nextiva', 'Other'),
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """
    conversations_ddl = """
        CREATE TABLE IF NOT EXISTS conversations (
            id INT AUTO_INCREMENT PRIMARY KEY,
            client_id INT,
            caller_id VARCHAR(50),
            start_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            end_time TIMESTAMP NULL,
            transcript TEXT,
            FOREIGN KEY (client_id) REFERENCES clients(id)
        )
    """

    connection = None
    db_cursor = None
    try:
        connection = cnx_pool.get_connection()
        db_cursor = connection.cursor()

        # Create tables if they don't exist (clients first: conversations
        # has a foreign key onto it).
        for ddl in (clients_ddl, conversations_ddl):
            db_cursor.execute(ddl)

        connection.commit()
        print("Database initialized successfully")
    except Exception as e:
        print(f"Error initializing database: {e}")
    finally:
        # Release only what was actually acquired.
        if db_cursor is not None:
            db_cursor.close()
        if connection is not None:
            connection.close()
94
+
95
# Run the schema setup once when the module is loaded.
initialize_database()
97
+
98
+ # API Key Management
99
def generate_api_key():
    """Return a new API key as a 64-character hexadecimal string.

    The key is the SHA-256 digest of 32 bytes drawn from ``secrets``.
    """
    random_seed = secrets.token_bytes(32)
    digest = hashlib.sha256(random_seed)
    return digest.hexdigest()
102
+
103
def create_client(name, email, phone, pbx_type):
    """Insert a new client row and return its freshly generated API key.

    Returns:
        ``{"success": True, "api_key": ...}`` when the client was stored in
        MySQL, or in the in-memory fallback when the database call failed
        and that fallback exists. Otherwise
        ``{"success": False, "error": <message>}``.
    """
    api_key = generate_api_key()
    insert_sql = """
        INSERT INTO clients (name, email, phone, api_key, pbx_type)
        VALUES (%s, %s, %s, %s, %s)
    """

    connection = None
    db_cursor = None
    try:
        connection = cnx_pool.get_connection()
        db_cursor = connection.cursor()
        db_cursor.execute(insert_sql, (name, email, phone, api_key, pbx_type))
        connection.commit()
        return {"success": True, "api_key": api_key}
    except Exception as e:
        print(f"Error creating client: {e}")
        # Without the in-memory store there is nowhere else to save the client.
        if "in_memory_db" not in globals():
            return {"success": False, "error": str(e)}

        # Fallback to in-memory storage
        fallback_id = str(uuid.uuid4())
        in_memory_db["clients"][fallback_id] = {
            "name": name,
            "email": email,
            "phone": phone,
            "api_key": api_key,
            "pbx_type": pbx_type,
            "created_at": datetime.datetime.now().isoformat(),
        }
        return {"success": True, "api_key": api_key}
    finally:
        if db_cursor is not None:
            db_cursor.close()
        if connection is not None:
            connection.close()
139
+
140
def validate_api_key(api_key):
    """Look up the client that owns *api_key*.

    Returns:
        The matching client record as a dict, or ``None`` when no client
        matches. If the database lookup fails, the in-memory store (when
        present) is searched instead.
    """
    connection = None
    db_cursor = None
    try:
        connection = cnx_pool.get_connection()
        db_cursor = connection.cursor(dictionary=True)
        db_cursor.execute("SELECT * FROM clients WHERE api_key = %s", (api_key,))
        return db_cursor.fetchone()
    except Exception as e:
        print(f"Error validating API key: {e}")
        if "in_memory_db" not in globals():
            return None
        # Fallback to in-memory storage: first record with a matching key.
        stored_clients = in_memory_db["clients"].values()
        return next(
            (record for record in stored_clients if record["api_key"] == api_key),
            None,
        )
    finally:
        if db_cursor is not None:
            db_cursor.close()
        if connection is not None:
            connection.close()
164
+
165
+ # Voice Processing Functions
166
def _to_int16_pcm(audio_array):
    """Return *audio_array* converted to 16-bit signed PCM samples."""
    samples = np.asarray(audio_array)
    if samples.dtype == np.int16:
        return samples
    if np.issubdtype(samples.dtype, np.floating):
        # Floating-point audio is assumed to be normalized to [-1.0, 1.0];
        # anything outside that range is clipped below.
        normalized = samples.astype(np.float64)
    else:
        # Other integer widths are rescaled by their own full-scale value.
        limits = np.iinfo(samples.dtype)
        normalized = samples.astype(np.float64) / max(abs(limits.min), limits.max)
    return (np.clip(normalized, -1.0, 1.0) * 32767.0).astype(np.int16)


def transcribe_audio(audio_array, sample_rate):
    """Transcribe audio using Whisper.

    Args:
        audio_array: NumPy array of samples, either 1-D (mono) or 2-D with
            shape ``(frames, channels)``. Integer and floating-point dtypes
            are accepted.
        sample_rate: Sampling rate of ``audio_array`` in Hz.

    Returns:
        The transcribed text, or ``""`` when ``audio_array`` is empty.

    The samples are written to a temporary WAV file because the ASR model is
    invoked with a file path. The file is removed even if export or
    transcription fails.
    """
    # pydub treats the raw buffer as integer PCM of `sample_width` bytes, so
    # float samples must be converted to real 16-bit PCM first. Passing
    # float32 bytes with sample_width=4 yields noise.
    pcm = _to_int16_pcm(audio_array)
    if pcm.size == 0:
        return ""
    channels = pcm.shape[1] if pcm.ndim == 2 else 1

    # Reserve a file name; the handle is closed so the exporter can reopen it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        filename = temp_file.name

    try:
        # Convert and save audio. A C-contiguous (frames, channels) array
        # serializes as interleaved frames, which is the PCM layout.
        audio_segment = AudioSegment(
            np.ascontiguousarray(pcm).tobytes(),
            frame_rate=sample_rate,
            sample_width=pcm.dtype.itemsize,
            channels=channels,
        )
        audio_segment.export(filename, format="wav")

        # Transcribe with Whisper
        result = asr_model.transcribe(filename)
    finally:
        # Clean up regardless of success so temp files do not accumulate.
        os.unlink(filename)

    return result["text"]
192
+
193
def generate_response(text):
    """Generate a reply to *text* using the NLU model.

    Args:
        text: The user's utterance (for example, an ASR transcription).

    Returns:
        The decoded model output with special tokens removed.
    """
    # Truncate to the tokenizer's maximum length so an unusually long
    # transcription cannot exceed the model's input limit.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)

    # Generate a response
    with torch.no_grad():
        outputs = nlu_model.generate(
            inputs["input_ids"],
            # Pass the mask explicitly rather than letting generate() infer it.
            attention_mask=inputs.get("attention_mask"),
            max_length=100,
            num_return_sequences=1,
            # temperature / top_k / top_p are only honored when sampling is
            # enabled; without do_sample=True they are ignored.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
211
+
212
def log_conversation(client_id, caller_id, transcript):
    """Store one conversation record, with the transcript serialized as JSON.

    Returns:
        ``True`` when the row was committed to the database. ``False`` when
        the database write failed, including the case where the record was
        kept in the in-memory fallback instead.
    """
    insert_sql = """
        INSERT INTO conversations (client_id, caller_id, transcript)
        VALUES (%s, %s, %s)
    """

    connection = None
    db_cursor = None
    try:
        connection = cnx_pool.get_connection()
        db_cursor = connection.cursor()
        serialized = json.dumps(transcript)
        db_cursor.execute(insert_sql, (client_id, caller_id, serialized))
        connection.commit()
        return True
    except Exception as e:
        print(f"Error logging conversation: {e}")
        # Fallback to in-memory storage
        if "in_memory_db" in globals():
            fallback_id = str(uuid.uuid4())
            in_memory_db["conversations"][fallback_id] = {
                "client_id": client_id,
                "caller_id": caller_id,
                "start_time": datetime.datetime.now().isoformat(),
                "transcript": transcript,
            }
        return False
    finally:
        if db_cursor is not None:
            db_cursor.close()
        if connection is not None:
            connection.close()
243
+
244
+ # Voice Bot processing function
245
def process_voice_interaction(audio, api_key, caller_id="unknown"):
    """Process a voice interaction with the bot.

    Args:
        audio: A two-item sequence holding the sample rate and a NumPy sample
            array. Gradio's ``type="numpy"`` audio supplies
            ``(sample_rate, samples)``; ``(samples, sample_rate)`` is also
            accepted. ``None`` means nothing was recorded.
        api_key: Client API key; must match a known client.
        caller_id: Free-form caller identifier stored with the transcript.

    Returns:
        ``{"success": True, "transcription": ..., "response": ...}`` on
        success, otherwise ``{"error": <message>}``.
    """
    # Validate API key
    client = validate_api_key(api_key)
    if not client:
        return {"error": "Invalid API key"}

    # Gradio passes None when the user submits without recording anything.
    if audio is None:
        return {"error": "No audio provided"}

    # Process the audio
    try:
        # Identify the sample array by type instead of by position: Gradio
        # delivers (sample_rate, samples), which is the opposite of
        # transcribe_audio's (samples, sample_rate) parameter order.
        first, second = audio
        if isinstance(first, np.ndarray):
            samples, sample_rate = first, second
        else:
            sample_rate, samples = first, second

        transcription = transcribe_audio(samples, sample_rate)
        response_text = generate_response(transcription)

        # Log the conversation
        transcript = {
            "timestamp": time.time(),
            "caller_id": caller_id,
            "user_input": transcription,
            "bot_response": response_text,
        }

        # Clients held in the in-memory fallback have no "id" field, so use
        # .get() to avoid turning a successful interaction into a KeyError.
        log_conversation(client.get("id"), caller_id, transcript)

        return {
            "success": True,
            "transcription": transcription,
            "response": response_text,
        }
    except Exception as e:
        print(f"Error processing voice interaction: {e}")
        return {"error": str(e)}
275
+
276
+ # Admin functions
277
def admin_create_client(name, email, phone, pbx_type):
    """Admin entry point: validate the form fields, then create the client.

    Returns:
        ``{"success": True, "message": ...}`` containing the new API key, or
        ``{"error": ...}`` when a required field is missing or creation
        failed.
    """
    # Both name and email are mandatory.
    if not (name and email):
        return {"error": "Name and email are required"}

    outcome = create_client(name, email, phone, pbx_type)
    if not outcome["success"]:
        return {"error": outcome.get("error", "Unknown error")}

    return {
        "success": True,
        "message": f"Client created with API key: {outcome['api_key']}",
    }
287
+
288
def admin_get_clients():
    """Admin interface to list all clients.

    Returns:
        ``{"success": True, "clients": [...]}`` with one dict per client, or
        ``{"error": <message>}`` when the database query fails and no
        in-memory fallback exists. API keys are never included: the SQL
        query does not select them, and the fallback path strips them.
    """
    conn = None
    cursor = None
    try:
        conn = cnx_pool.get_connection()
        cursor = conn.cursor(dictionary=True)

        query = "SELECT id, name, email, phone, pbx_type, created_at FROM clients"
        cursor.execute(query)
        clients = cursor.fetchall()

        return {"success": True, "clients": clients}
    except Exception as e:
        print(f"Error getting clients: {e}")
        # Fallback to in-memory
        if 'in_memory_db' in globals():
            # In-memory records carry the API key; drop it so both storage
            # back-ends expose the same fields.
            clients = [
                {field: value for field, value in record.items() if field != "api_key"}
                for record in in_memory_db["clients"].values()
            ]
            return {"success": True, "clients": clients}
        return {"error": str(e)}
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
310
+
311
def admin_get_conversations():
    """Admin interface to list recent conversations.

    Returns:
        ``{"success": True, "conversations": [...]}`` with at most the 100
        most recently started conversations from the database (or every
        conversation held in the in-memory fallback), or
        ``{"error": <message>}`` when the query fails and no fallback exists.
    """
    conn = None
    cursor = None
    try:
        conn = cnx_pool.get_connection()
        cursor = conn.cursor(dictionary=True)

        query = """
            SELECT c.id, cl.name as client_name, c.caller_id, c.start_time, c.end_time, c.transcript
            FROM conversations c
            JOIN clients cl ON c.client_id = cl.id
            ORDER BY c.start_time DESC
            LIMIT 100
        """
        cursor.execute(query)
        conversations = cursor.fetchall()

        # Parse transcript JSON
        for conv in conversations:
            if conv["transcript"]:
                try:
                    conv["transcript"] = json.loads(conv["transcript"])
                except (TypeError, ValueError):
                    # Leave a transcript that is not valid JSON as stored.
                    # Only decoding errors are ignored, unlike a bare
                    # `except:` which would also hide KeyboardInterrupt.
                    pass

        return {"success": True, "conversations": conversations}
    except Exception as e:
        print(f"Error getting conversations: {e}")
        # Fallback to in-memory
        if 'in_memory_db' in globals():
            return {"success": True, "conversations": list(in_memory_db["conversations"].values())}
        return {"error": str(e)}
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
347
+
348
+ # Gradio Interface
349
def build_gradio_interface():
    """Assemble the Gradio UI: an admin dashboard plus a voice-bot test page.

    Returns:
        A ``gr.TabbedInterface`` that shows the two pages under the tab
        names "Admin Dashboard" and "Test Interface".
    """
    pbx_choices = ["Asterisk", "FreeSwitch", "3CX", "Nextiva", "Other"]

    # Admin section
    with gr.Blocks() as admin_interface:
        gr.Markdown("# Voice Bot Admin Dashboard")

        with gr.Tab("Create Client"):
            with gr.Row():
                name_box = gr.Textbox(label="Client Name")
                email_box = gr.Textbox(label="Email")
            with gr.Row():
                phone_box = gr.Textbox(label="Phone Number")
                pbx_dropdown = gr.Dropdown(label="PBX Type", choices=pbx_choices)
            submit_btn = gr.Button("Create Client")
            creation_result = gr.JSON(label="Result")

            submit_btn.click(
                fn=admin_create_client,
                inputs=[name_box, email_box, phone_box, pbx_dropdown],
                outputs=creation_result,
            )

        with gr.Tab("View Clients"):
            reload_clients_btn = gr.Button("Refresh Client List")
            client_list_view = gr.JSON(label="Clients")

            reload_clients_btn.click(
                fn=admin_get_clients,
                inputs=[],
                outputs=client_list_view,
            )

        with gr.Tab("View Conversations"):
            reload_convs_btn = gr.Button("Refresh Conversations")
            conversation_view = gr.JSON(label="Recent Conversations")

            reload_convs_btn.click(
                fn=admin_get_conversations,
                inputs=[],
                outputs=conversation_view,
            )

    # Test interface for voice bot API
    with gr.Blocks() as test_interface:
        gr.Markdown("# Voice Bot Test Interface")

        with gr.Row():
            key_box = gr.Textbox(label="API Key")
            caller_box = gr.Textbox(label="Caller ID (optional)", value="test_caller")

        mic_input = gr.Audio(label="Speak", type="numpy", source="microphone")
        run_btn = gr.Button("Process Audio")

        result_view = gr.JSON(label="Result")

        # Input order matches process_voice_interaction(audio, api_key, caller_id).
        run_btn.click(
            fn=process_voice_interaction,
            inputs=[mic_input, key_box, caller_box],
            outputs=result_view,
        )

    # Create a tabbed interface
    pages = [admin_interface, test_interface]
    tab_titles = ["Admin Dashboard", "Test Interface"]
    return gr.TabbedInterface(pages, tab_titles)
416
+
417
# Build the tabbed UI once, when the module is loaded.
interface = build_gradio_interface()

# Launch for Hugging Face Spaces (default launch settings).
interface.launch()