Spaces:
Build error
Build error
| # app.py | |
import datetime
import hashlib
import json
import os
import secrets
import tempfile
import time
import traceback
import uuid
from contextlib import closing

import gradio as gr
import mysql.connector
import numpy as np
import torch
import whisper
from mysql.connector import pooling
from pydub import AudioSegment
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
# Model selection (lightweight checkpoints suitable for Spaces)
ASR_MODEL = "base"  # Smaller Whisper model
NLU_MODEL = "facebook/blenderbot-400M-distill"  # Smaller conversation model

# Database configuration.
#
# Connection details are read from environment variables so that credentials
# never live in source control. Non-secret values keep their previous
# defaults; the password has no default and must be supplied via DB_PASSWORD
# (e.g. as a Space secret). The password that used to be hard-coded here was
# published with the source and should be rotated.
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "hopper.proxy.rlwy.net"),
    "port": int(os.environ.get("DB_PORT", "16751")),
    "user": os.environ.get("DB_USER", "root"),
    "password": os.environ.get("DB_PASSWORD", ""),
    "database": os.environ.get("DB_NAME", "railway"),
    "pool_name": "voicebot_pool",
    "pool_size": 5,
}
# Create the connection pool.
#
# `cnx_pool` is always bound (None when the pool could not be created) so
# later code never depends on a NameError being swallowed. `in_memory_db` is
# only defined when startup connectivity failed; the rest of the module checks
# for its presence to decide whether the in-memory fallback is active.
cnx_pool = None
try:
    print(f"Attempting to connect to MySQL at {DB_CONFIG['host']}:{DB_CONFIG['port']}...")
    cnx_pool = mysql.connector.pooling.MySQLConnectionPool(**DB_CONFIG)
    print("Database connection pool created successfully")
    # Test the connection by borrowing one, and always hand it back to the pool
    test_conn = cnx_pool.get_connection()
    try:
        if test_conn.is_connected():
            print(f"Successfully connected to {DB_CONFIG['database']} database")
    finally:
        test_conn.close()
except Exception as e:
    print(f"Error creating database pool: {e}")
    # Use in-memory dictionary as fallback
    print("Using in-memory storage as fallback")
    in_memory_db = {"clients": {}, "conversations": {}}
# Initialize models
print("Loading ASR model...")
asr_model = whisper.load_model(ASR_MODEL)
print("ASR model loaded")

print("Loading NLU model...")
tokenizer = AutoTokenizer.from_pretrained(NLU_MODEL)
# BlenderBot is an encoder-decoder checkpoint, so it is loaded with the
# seq2seq auto class. The causal-LM auto class would build a decoder-only
# wrapper that does not use the encoder.
nlu_model = AutoModelForSeq2SeqLM.from_pretrained(NLU_MODEL)
print("NLU model loaded")
# Database schema initialization
def initialize_database():
    """Create the ``clients`` and ``conversations`` tables if they are missing.

    Errors are printed rather than raised, so a database outage at startup
    does not stop the application from loading.
    """
    try:
        # `closing` guarantees the cursor is closed and the pooled connection
        # is handed back to the pool even when a statement fails.
        with closing(cnx_pool.get_connection()) as conn, closing(conn.cursor()) as cursor:
            # Create tables if they don't exist
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS clients (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    name VARCHAR(255) NOT NULL,
                    email VARCHAR(255) NOT NULL UNIQUE,
                    phone VARCHAR(50),
                    api_key VARCHAR(64) NOT NULL UNIQUE,
                    pbx_type ENUM('Asterisk', 'FreeSwitch', '3CX', 'Nextiva', 'Other'),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    client_id INT,
                    caller_id VARCHAR(50),
                    start_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    end_time TIMESTAMP NULL,
                    transcript TEXT,
                    FOREIGN KEY (client_id) REFERENCES clients(id)
                )
            """)
            conn.commit()
        print("Database initialized successfully")
    except Exception as e:
        print(f"Error initializing database: {e}")


# Initialize database on startup
initialize_database()
# API Key Management
def generate_api_key():
    """Return a new random API key as 64 lowercase hexadecimal characters.

    ``secrets.token_hex(32)`` draws 32 bytes from the OS CSPRNG and hex-encodes
    them. This gives the same length and alphabet as the previous
    SHA-256-of-random-bytes approach without the redundant hashing step, and
    fits the ``VARCHAR(64)`` ``api_key`` column.
    """
    return secrets.token_hex(32)
def create_client(name, email, phone, pbx_type):
    """Create a new client record and return its freshly generated API key.

    Returns:
        ``{"success": True, "api_key": <key>}`` when the client was stored
        (in MySQL, or in the in-memory fallback when that is active), otherwise
        ``{"success": False, "error": <message>}``.
    """
    api_key = generate_api_key()
    query = """
        INSERT INTO clients (name, email, phone, api_key, pbx_type)
        VALUES (%s, %s, %s, %s, %s)
    """
    try:
        # `closing` always closes the cursor and returns the pooled connection,
        # including when the connection has dropped mid-request.
        with closing(cnx_pool.get_connection()) as conn, closing(conn.cursor()) as cursor:
            cursor.execute(query, (name, email, phone, api_key, pbx_type))
            conn.commit()
        return {"success": True, "api_key": api_key}
    except Exception as e:
        print(f"Error creating client: {e}")
        # Fallback to in-memory storage (only defined when startup connectivity failed)
        if 'in_memory_db' in globals():
            client_id = str(uuid.uuid4())
            in_memory_db["clients"][client_id] = {
                "name": name,
                "email": email,
                "phone": phone,
                "api_key": api_key,
                "pbx_type": pbx_type,
                "created_at": datetime.datetime.now().isoformat(),
            }
            return {"success": True, "api_key": api_key}
        return {"success": False, "error": str(e)}
def validate_api_key(api_key):
    """Look up the client that owns ``api_key``.

    Returns:
        The client row as a dict, or ``None`` when the key is empty or unknown.
        When the database lookup raises and the in-memory fallback is active,
        the in-memory client records are searched instead.
    """
    if not api_key:
        return None
    try:
        # `closing` always closes the cursor and returns the pooled connection,
        # including when the connection has dropped mid-request.
        with closing(cnx_pool.get_connection()) as conn, \
                closing(conn.cursor(dictionary=True)) as cursor:
            cursor.execute("SELECT * FROM clients WHERE api_key = %s", (api_key,))
            return cursor.fetchone()
    except Exception as e:
        print(f"Error validating API key: {e}")
        # Fallback to in-memory storage (only defined when startup connectivity failed)
        if 'in_memory_db' in globals():
            for client in in_memory_db["clients"].values():
                if client["api_key"] == api_key:
                    return client
        return None
def _split_audio_input(audio, sample_rate):
    """Return ``(samples, sample_rate)`` for the input shapes transcribe_audio accepts."""
    if isinstance(audio, tuple) and len(audio) == 2:
        first, second = audio
        # Gradio's numpy audio is ``(sample_rate, data)``. The legacy
        # ``(data, sample_rate)`` order is still accepted: whichever element is
        # a scalar next to an array-like is taken as the rate.
        if np.ndim(first) == 0 and np.ndim(second) > 0:
            return second, first
        return first, second
    return audio, sample_rate


def _to_pcm16_mono(samples):
    """Convert a numeric sample array to mono signed 16-bit PCM."""
    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM: scale by the dtype's full range.
        # NOTE(review): assumes signed PCM (int16/int32) — confirm if unsigned input is possible.
        floats = samples.astype(np.float32) / float(np.iinfo(samples.dtype).max)
    else:
        floats = samples.astype(np.float32)
        peak = float(np.max(np.abs(floats)))
        if peak > 1.0:
            print(f"Normalizing audio values from max {peak} to [-1.0, 1.0] range")
            floats = floats / peak
    if floats.ndim > 1:
        # assumes a (frames, channels) layout — TODO confirm against the caller
        floats = floats.reshape(floats.shape[0], -1).mean(axis=1)
    return (np.clip(floats, -1.0, 1.0) * 32767.0).astype(np.int16)


def transcribe_audio(audio, sample_rate=None):
    """Transcribe audio using Whisper.

    Args:
        audio: Either a ``(sample_rate, samples)`` tuple, a
            ``(samples, sample_rate)`` tuple, or a bare list/array of samples.
        sample_rate: Rate used when ``audio`` does not carry one; 16000 is
            used when no rate is available at all.

    Returns:
        The transcribed text, a prompt to speak again when nothing was
        recognised, or a string starting with ``"Error"`` describing the
        failure. This function does not raise.
    """
    try:
        # Check if audio input is empty
        if audio is None:
            print("Error: Audio input is None")
            return "Error: No audio data received"
        print(f"Audio input type: {type(audio)}")

        audio_array, sample_rate = _split_audio_input(audio, sample_rate)
        if sample_rate is None:
            sample_rate = 16000  # Common default sample rate
            print(f"Using default sample rate: {sample_rate}")

        # Guard against invalid input
        if audio_array is None:
            print("Empty audio data received")
            return "Error: No audio data received"
        audio_array = np.asarray(audio_array)
        print(f"Audio array shape: {audio_array.shape}")
        print(f"Audio array dtype: {audio_array.dtype}")
        if audio_array.size == 0:
            print("Empty audio array received")
            return "Error: No audio data received"

        # pydub interprets the raw bytes as integer PCM, so the samples are
        # converted to mono int16 first instead of passing float bytes through.
        pcm = _to_pcm16_mono(audio_array)

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            filename = temp_file.name
        print(f"Created temp file: {filename}")

        # One try/finally around both steps so the temp file is removed even
        # when the conversion step fails.
        try:
            # Convert and save audio
            try:
                print(f"Creating AudioSegment with sample rate {sample_rate}")
                audio_segment = AudioSegment(
                    pcm.tobytes(),
                    frame_rate=int(sample_rate),
                    sample_width=pcm.dtype.itemsize,
                    channels=1,
                )
                print("AudioSegment created, exporting to WAV")
                audio_segment.export(filename, format="wav")
                print("WAV file created successfully")
            except Exception as audio_e:
                print(f"Error in audio conversion: {audio_e}")
                return f"Error in audio conversion: {str(audio_e)}"

            # Transcribe with Whisper
            try:
                print("Starting transcription with Whisper")
                result = asr_model.transcribe(filename)
                print("Transcription completed")
                transcribed_text = result["text"].strip()
                print(f"Transcribed text: {transcribed_text}")
                # Return a prompt if no text was transcribed
                if not transcribed_text:
                    return "I couldn't hear anything. Please try speaking again."
                return transcribed_text
            except Exception as whisper_e:
                print(f"Error in Whisper transcription: {whisper_e}")
                return f"Error in transcription: {str(whisper_e)}"
        finally:
            # Clean up
            try:
                os.unlink(filename)
                print(f"Deleted temp file: {filename}")
            except OSError as e:
                print(f"Warning: Could not delete temp file {filename}: {e}")
    except Exception as e:
        print(f"Error transcribing audio: {e}")
        traceback.print_exc()
        return f"Error processing audio: {str(e)}"
# Prefixes of every failure string transcribe_audio can return; such text is a
# diagnostic, not something the caller said, so it must not reach the model.
_TRANSCRIPTION_ERROR_PREFIXES = (
    "Error:",
    "Error in audio conversion:",
    "Error in transcription:",
    "Error processing audio:",
)


def generate_response(text):
    """Generate a reply to ``text`` using the NLU model.

    Args:
        text: The transcribed user utterance.

    Returns:
        The model's reply, a request to repeat when ``text`` is empty or is a
        transcription error message, or a generic apology when generation
        fails. This function does not raise.
    """
    try:
        if not text or text.startswith(_TRANSCRIPTION_ERROR_PREFIXES):
            return "I'm sorry, I couldn't understand what you said. Could you please try again?"
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        # Generate a response
        with torch.no_grad():
            outputs = nlu_model.generate(
                **inputs,  # passes the attention mask along with input_ids
                max_length=100,
                num_return_sequences=1,
                # temperature/top_k/top_p only take effect when sampling is on
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error generating response: {e}")
        traceback.print_exc()
        return "I'm sorry, I encountered an error processing your request."
def log_conversation(client_id, caller_id, transcript):
    """Log a conversation to the database.

    Args:
        client_id: Identifier of the owning client.
        caller_id: Caller identifier supplied with the request.
        transcript: JSON-serialisable transcript; stored as a JSON string.

    Returns:
        ``True`` when the row was committed to MySQL. ``False`` otherwise,
        including when the conversation was kept in the in-memory fallback.
    """
    query = """
        INSERT INTO conversations (client_id, caller_id, transcript)
        VALUES (%s, %s, %s)
    """
    try:
        # `closing` always closes the cursor and returns the pooled connection,
        # including when the connection has dropped mid-request.
        with closing(cnx_pool.get_connection()) as conn, closing(conn.cursor()) as cursor:
            cursor.execute(query, (client_id, caller_id, json.dumps(transcript)))
            conn.commit()
        return True
    except Exception as e:
        print(f"Error logging conversation: {e}")
        # Fallback to in-memory storage (only defined when startup connectivity failed)
        if 'in_memory_db' in globals():
            conv_id = str(uuid.uuid4())
            in_memory_db["conversations"][conv_id] = {
                "client_id": client_id,
                "caller_id": caller_id,
                "start_time": datetime.datetime.now().isoformat(),
                "transcript": transcript,
            }
        return False
def process_voice_interaction(audio, api_key, caller_id="unknown"):
    """Run one voice turn: authenticate, transcribe, reply, and log.

    Returns ``{"success": True, "transcription": ..., "response": ...}`` on
    success, or ``{"error": <message>}`` when the API key is invalid, no audio
    was supplied, or processing raised.
    """
    # Authenticate the caller first; nothing else runs for an unknown key.
    client = validate_api_key(api_key)
    if not client:
        return {"error": "Invalid API key"}

    # Reject requests that carry no audio payload.
    if audio is None:
        return {"error": "No audio data received"}

    try:
        print(f"Received audio data type: {type(audio)}")

        # Speech -> text
        heard = transcribe_audio(audio)
        print(f"Transcription: {heard}")

        # Text -> reply
        reply = generate_response(heard)

        # Record the exchange. Clients without an "id" entry are identified
        # by their API key instead.
        record = {
            "timestamp": time.time(),
            "caller_id": caller_id,
            "user_input": heard,
            "bot_response": reply,
        }
        owner = client.get("id", api_key)
        log_conversation(owner, caller_id, record)

        return {"success": True, "transcription": heard, "response": reply}
    except Exception as e:
        print(f"Error processing voice interaction: {e}")
        traceback.print_exc()
        return {"error": str(e)}
# Admin functions
def admin_create_client(name, email, phone, pbx_type):
    """Admin-dashboard wrapper around ``create_client``.

    Requires a non-empty name and email. Returns a success payload whose
    message contains the new API key, or ``{"error": <message>}``.
    """
    if not (name and email):
        return {"error": "Name and email are required"}

    outcome = create_client(name, email, phone, pbx_type)
    if not outcome["success"]:
        return {"error": outcome.get("error", "Unknown error")}

    message = f"Client created with API key: {outcome['api_key']}"
    return {"success": True, "message": message}
def admin_get_clients():
    """Admin interface to list all clients (API keys are never included).

    Returns:
        ``{"success": True, "clients": [...]}`` with ``created_at`` rendered
        as an ISO-8601 string, or ``{"error": <message>}`` when the query fails
        and no in-memory fallback is active.
    """
    try:
        # `closing` always closes the cursor and returns the pooled connection,
        # including when the connection has dropped mid-request.
        with closing(cnx_pool.get_connection()) as conn, \
                closing(conn.cursor(dictionary=True)) as cursor:
            cursor.execute("SELECT id, name, email, phone, pbx_type, created_at FROM clients")
            clients = cursor.fetchall()
        # Convert datetime objects to strings for JSON serialization
        for client in clients:
            if isinstance(client["created_at"], datetime.datetime):
                client["created_at"] = client["created_at"].isoformat()
        return {"success": True, "clients": clients}
    except Exception as e:
        print(f"Error getting clients: {e}")
        # Fallback to in-memory (only defined when startup connectivity failed).
        # The SQL path omits api_key, so the fallback strips it as well.
        if 'in_memory_db' in globals():
            clients = [
                {field: value for field, value in client.items() if field != "api_key"}
                for client in in_memory_db["clients"].values()
            ]
            return {"success": True, "clients": clients}
        return {"error": str(e)}
def admin_get_conversations():
    """Admin interface to list the 100 most recent conversations.

    Returns:
        ``{"success": True, "conversations": [...]}`` with timestamps rendered
        as ISO-8601 strings and transcripts decoded from JSON where possible,
        or ``{"error": <message>}`` when the query fails and no in-memory
        fallback is active.
    """
    query = """
        SELECT c.id, cl.name as client_name, c.caller_id, c.start_time, c.end_time, c.transcript
        FROM conversations c
        JOIN clients cl ON c.client_id = cl.id
        ORDER BY c.start_time DESC
        LIMIT 100
    """
    try:
        # `closing` always closes the cursor and returns the pooled connection,
        # including when the connection has dropped mid-request.
        with closing(cnx_pool.get_connection()) as conn, \
                closing(conn.cursor(dictionary=True)) as cursor:
            cursor.execute(query)
            conversations = cursor.fetchall()
        # Convert datetime objects and parse transcript JSON
        for conv in conversations:
            for column in ("start_time", "end_time"):
                if isinstance(conv[column], datetime.datetime):
                    conv[column] = conv[column].isoformat()
            if conv["transcript"]:
                try:
                    conv["transcript"] = json.loads(conv["transcript"])
                except (json.JSONDecodeError, TypeError):
                    # Leave a transcript that cannot be decoded as stored
                    # rather than failing the whole listing.
                    pass
        return {"success": True, "conversations": conversations}
    except Exception as e:
        print(f"Error getting conversations: {e}")
        # Fallback to in-memory (only defined when startup connectivity failed)
        if 'in_memory_db' in globals():
            return {"success": True, "conversations": list(in_memory_db["conversations"].values())}
        return {"error": str(e)}
# Debug function
def debug_audio(audio):
    """Describe the structure of an audio input for debugging.

    Args:
        audio: Typically a ``(sample_rate, samples)`` tuple; a
            ``(samples, sample_rate)`` tuple is also recognised.

    Returns:
        ``{"debug_info": {...}}`` with the input's type and length and, for
        two-element tuples, the sample rate plus the array's shape, dtype and
        value range. ``{"error": <message>}`` when no audio was provided or
        inspection raised.
    """
    try:
        if audio is None:
            return {"error": "No audio provided"}
        result = {
            "type": type(audio).__name__,
            "is_tuple": isinstance(audio, tuple),
            "length": len(audio) if hasattr(audio, "__len__") else "N/A",
        }
        if isinstance(audio, tuple) and len(audio) == 2:
            first, second = audio
            # Whichever element is a scalar next to an array-like is the rate.
            if np.ndim(first) == 0 and np.ndim(second) > 0:
                sample_rate, data = first, second
            else:
                data, sample_rate = first, second
            result["data_type"] = type(data).__name__
            # Plain int so the value is JSON-serialisable
            if isinstance(sample_rate, np.integer):
                sample_rate = int(sample_rate)
            result["sample_rate"] = sample_rate
            if hasattr(data, "shape"):
                result["shape"] = data.shape
                result["dtype"] = str(data.dtype)
                # min()/max() raise on an empty array
                if data.size:
                    result["min_val"] = float(data.min())
                    result["max_val"] = float(data.max())
        return {"debug_info": result}
    except Exception as e:
        traceback.print_exc()
        return {"error": str(e)}
def build_gradio_interface():
    """Build the tabbed Gradio app: admin dashboard, test client, and debug view.

    Returns the ``gr.TabbedInterface`` combining the three pages.
    """
    # NOTE(review): the nesting of components inside rows/tabs was
    # reconstructed from statement order — confirm against the deployed layout.

    # Callback: run one voice turn and re-render the conversation.
    # The callback names are kept as-is in case endpoint names derive from them.
    def process_and_update(audio, api_key, caller_id, conversation_history):
        if not api_key:
            return "**Error:** API key is required.", conversation_history
        if audio is None:
            return "*Conversation will appear here*", conversation_history

        # Process the audio
        result = process_voice_interaction(audio, api_key, caller_id)

        if "transcription" not in result or "response" not in result:
            # The interaction failed; show the reported error
            error_msg = result.get("error", "Unknown error")
            return f"**Error:** {error_msg}", conversation_history

        # Record the new turn, then format the whole history as markdown
        conversation_history.append(
            {"user": result["transcription"], "bot": result["response"]}
        )
        rendered_turns = "".join(
            f"**You:** {turn['user']}\n\n**Bot:** {turn['bot']}\n\n"
            for turn in conversation_history
        )
        return "## Conversation\n\n" + rendered_turns, conversation_history

    # Callback: reset both the display and the stored history.
    def clear_conversation():
        return "*Conversation will appear here*", []

    # --- Admin section ---
    with gr.Blocks() as admin_interface:
        gr.Markdown("# Voice Bot Admin Dashboard")

        with gr.Tab("Create Client"):
            with gr.Row():
                name_box = gr.Textbox(label="Client Name")
                email_box = gr.Textbox(label="Email")
            with gr.Row():
                phone_box = gr.Textbox(label="Phone Number")
                pbx_choice = gr.Dropdown(
                    label="PBX Type",
                    choices=["Asterisk", "FreeSwitch", "3CX", "Nextiva", "Other"],
                )
            create_button = gr.Button("Create Client")
            create_result = gr.JSON(label="Result")
            create_button.click(
                admin_create_client,
                inputs=[name_box, email_box, phone_box, pbx_choice],
                outputs=create_result,
            )

        with gr.Tab("View Clients"):
            clients_button = gr.Button("Refresh Client List")
            clients_view = gr.JSON(label="Clients")
            clients_button.click(admin_get_clients, inputs=[], outputs=clients_view)

        with gr.Tab("View Conversations"):
            convs_button = gr.Button("Refresh Conversations")
            convs_view = gr.JSON(label="Recent Conversations")
            convs_button.click(admin_get_conversations, inputs=[], outputs=convs_view)

    # --- Test interface for voice bot API ---
    with gr.Blocks() as test_interface:
        gr.Markdown("# Voice Bot Test Interface")
        with gr.Row():
            key_box = gr.Textbox(label="API Key")
            caller_box = gr.Textbox(label="Caller ID (optional)", value="test_caller")

        # Conversation history display
        history_view = gr.Markdown("*Conversation will appear here*")

        # Audio input delivered to the callback as numpy data
        mic_input = gr.Audio(label="Speak", type="numpy")

        # State holding the conversation history
        history_state = gr.State([])

        # Submit button for audio processing
        process_button = gr.Button("Process Audio")
        process_button.click(
            process_and_update,
            inputs=[mic_input, key_box, caller_box, history_state],
            outputs=[history_view, history_state],
        )

        # Clear conversation button
        reset_button = gr.Button("Clear Conversation")
        reset_button.click(
            clear_conversation,
            inputs=[],
            outputs=[history_view, history_state],
        )

    # --- Debug interface ---
    with gr.Blocks() as debug_interface:
        gr.Markdown("# Debug Interface")
        debug_audio_input = gr.Audio(label="Test Audio Input")
        debug_button = gr.Button("Debug Audio Format")
        debug_view = gr.JSON(label="Debug Info")
        debug_button.click(debug_audio, inputs=debug_audio_input, outputs=debug_view)

    # Combine the three pages into one tabbed interface
    pages = [admin_interface, test_interface, debug_interface]
    titles = ["Admin Dashboard", "Test Interface", "Debug"]
    return gr.TabbedInterface(pages, titles)


# Create and launch the interface
interface = build_gradio_interface()

# Launch for Hugging Face Spaces
interface.launch()