xdxb committed on
Commit
ce2e6c7
·
verified ·
1 Parent(s): cd0995e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +372 -0
app.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module provides a server that's compatible with OpenAI's TTS API format.
3
+ Main entry point for the Flask application.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import logging
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import Dict, Any, Tuple, Optional
12
+ from flask import Flask, request, jsonify, send_file, Response, send_from_directory
13
+ from flask_cors import CORS
14
+ from dotenv import load_dotenv
15
+ from celery.result import AsyncResult
16
+ from celery_worker import celery, process_tts_request
17
+ from werkzeug.exceptions import HTTPException
18
+ from werkzeug.utils import secure_filename
19
+ import os.path
20
+ import re
21
+
22
# Logging: INFO level, each record prefixed with time, logger name and level
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load environment variables before any setting below is read
load_dotenv()

# Runtime settings; every value can be overridden through the environment
HOST = os.getenv("HOST", "localhost")
PORT = int(os.getenv("PORT", "7000"))
# Verification stays on unless VERIFY_SSL is exactly "false" (case-insensitive)
VERIFY_SSL = os.getenv("VERIFY_SSL", "true").lower() != "false"
MAX_QUEUE_SIZE = int(os.getenv("MAX_QUEUE_SIZE", "100"))
37
+
38
+ # Security configuration
39
+ ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*").split(",")
40
+ MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max request size
41
+
42
# Flask application object, using the 'static' and 'templates' folders
app = Flask(__name__, static_folder='static', template_folder='templates')

# CORS rules per route prefix:
#   /v1/*  accepts POST (plus preflight) and may carry an Authorization header
#   /api/* is read-only (GET plus preflight)
_CORS_RESOURCES = {
    r"/v1/*": {
        "origins": ALLOWED_ORIGINS,
        "methods": ["POST", "OPTIONS"],
        "allow_headers": ["Content-Type", "Authorization"],
        "max_age": 3600,
    },
    r"/api/*": {
        "origins": ALLOWED_ORIGINS,
        "methods": ["GET", "OPTIONS"],
        "allow_headers": ["Content-Type"],
        "max_age": 3600,
    },
}
CORS(app, resources=_CORS_RESOURCES)

# Upper bound on the size of an incoming request body
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH

# Directory holding the per-voice sample audio files
VOICE_SAMPLES_DIR = Path('voices')
68
+
69
+ def _sanitize_input(text: str) -> str:
70
+ """Sanitize user input to prevent XSS and other attacks"""
71
+ # Remove any HTML tags
72
+ text = re.sub(r'<[^>]+>', '', text)
73
+ # Remove any script tags
74
+ text = re.sub(r'<script.*?</script>', '', text, flags=re.DOTALL)
75
+ # Remove any potentially dangerous characters
76
+ text = re.sub(r'[<>"\']', '', text)
77
+ return text.strip()
78
+
79
+ # Error handlers
80
@app.errorhandler(HTTPException)
def handle_http_error(error: HTTPException) -> Tuple[Dict[str, str], int]:
    """Log an HTTP exception and answer with its description as JSON.

    Returns:
        A ``(json_response, status_code)`` pair taken from the exception's
        own ``description`` and ``code``.
    """
    status = error.code
    message = error.description
    logger.warning(f"HTTP error: {status} - {message}")
    payload = jsonify({"error": message})
    return payload, status
85
+
86
@app.errorhandler(Exception)
def handle_generic_error(error: Exception) -> Tuple[Dict[str, str], int]:
    """Log an unexpected exception (with traceback) and answer with a 500.

    The exception text goes to the log only; the client receives a fixed
    generic message.
    """
    detail = str(error)
    logger.error(f"Unexpected error: {detail}", exc_info=True)
    body = {"error": "Internal Server Error"}
    return jsonify(body), 500
91
+
92
@app.after_request
def add_security_headers(response: Response) -> Response:
    """Attach a fixed set of security headers to the outgoing response.

    Args:
        response: The response about to be sent.

    Returns:
        The same response object with the headers set.
    """
    # Content-Security-Policy directives, joined with "; " below
    csp_directives = (
        "default-src 'self'",
        "style-src 'self' https://cdnjs.cloudflare.com https://fonts.googleapis.com",
        "script-src 'self' https://cdnjs.cloudflare.com",
        "font-src 'self' https://cdnjs.cloudflare.com https://fonts.gstatic.com",
        "img-src 'self' data:",
        "media-src 'self' blob:",
        "connect-src 'self'",
    )
    hardening = {
        'X-Content-Type-Options': 'nosniff',
        'X-Frame-Options': 'DENY',
        'X-XSS-Protection': '1; mode=block',
        'Strict-Transport-Security': 'max-age=31536000; includeSubDomains',
        'Content-Security-Policy': "; ".join(csp_directives),
    }
    for header_name, header_value in hardening.items():
        response.headers[header_name] = header_value
    return response
109
+
110
@app.route('/')
def index() -> Response:
    """Serve ``index.html`` from the ``static`` directory as the landing page."""
    landing_page = send_from_directory('static', 'index.html')
    return landing_page
114
+
115
@app.route('/static/<path:filename>')
def serve_static(filename: str) -> Response:
    """Serve a file from ``static``, forcing the MIME type for CSS and JS.

    Args:
        filename: Path of the requested file relative to ``static``.

    Returns:
        The file response. ``.css`` and ``.js`` files get an explicit
        mimetype; every other file is sent without one being specified here.
    """
    overrides = {}
    if filename.endswith('.css'):
        overrides['mimetype'] = 'text/css'
    elif filename.endswith('.js'):
        overrides['mimetype'] = 'application/javascript'
    return send_from_directory('static', filename, **overrides)
123
+
124
def validate_tts_request(body: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], Optional[str], Optional[int]]:
    """Validate a TTS request body and build the payload for the TTS task.

    Args:
        body: The decoded JSON request body. Anything that is not a dict is
            rejected.

    Returns:
        A ``(data, error, status_code)`` triple. On success ``data`` is a dict
        with ``input`` and ``voice`` (plus ``prompt`` when ``instructions``
        was supplied and ``format`` when ``response_format`` was supplied),
        and the other two items are ``None``. On failure ``data`` is ``None``
        and ``error`` / ``status_code`` describe the problem (always 400).
        This function does not raise.
    """
    # Accepted response formats, in the order they are listed in error messages
    supported_formats = ('mp3', 'opus', 'aac', 'flac', 'wav', 'pcm')

    try:
        # Reject non-object bodies up front instead of relying on an exception
        if not isinstance(body, dict):
            return None, "Invalid request format", 400

        # Validate required parameters
        if 'input' not in body or 'voice' not in body:
            return None, "Missing required parameters: input and voice", 400

        # Validate and sanitize input
        if not isinstance(body['input'], str):
            return None, "Input must be a string", 400
        sanitized_input = _sanitize_input(body['input'])
        if not sanitized_input:
            return None, "Input text cannot be empty", 400

        # Validate voice parameter
        voice = body['voice']
        if not isinstance(voice, str) or not voice:
            return None, "Invalid voice parameter", 400

        openai_fm_data = {
            'input': sanitized_input,
            'voice': voice
        }

        # Validate and sanitize instructions if provided
        if 'instructions' in body:
            if not isinstance(body['instructions'], str):
                return None, "Instructions must be a string", 400
            openai_fm_data['prompt'] = _sanitize_input(body['instructions'])

        # Validate response format (matched case-insensitively)
        if 'response_format' in body:
            if not isinstance(body['response_format'], str):
                return None, "Response format must be a string", 400
            requested_format = body['response_format'].lower()
            if requested_format not in supported_formats:
                return None, f"Unsupported response format: {requested_format}. Supported formats are: {', '.join(supported_formats)}", 400
            openai_fm_data['format'] = requested_format

        return openai_fm_data, None, None
    except Exception as e:
        # Safety net so callers always get a triple back, never an exception
        logger.error(f"Error validating request: {str(e)}")
        return None, "Invalid request format", 400
172
+
173
def get_queue_details() -> Dict[str, Any]:
    """Count active, reserved and scheduled tasks as reported by Celery workers.

    Returns:
        A dict with the keys ``active``, ``reserved``, ``scheduled``,
        ``total_reported_by_workers`` (sum of the three) and ``error``
        (``None`` on success, otherwise a message). When inspection fails,
        all counts are reported as 0.
    """
    details = {
        'active': 0,
        'reserved': 0,
        'scheduled': 0,
        'total_reported_by_workers': 0,
        'error': None
    }
    try:
        inspector = celery.control.inspect(timeout=1.0)  # Bounded wait for worker replies
        # NOTE(review): this guard only fires if inspect() returns something
        # falsy; confirm whether that can actually happen.
        if not inspector:
            details['error'] = "Could not connect to Celery workers for inspection."
            return details

        # Query all three categories first, then count whatever came back
        reports = {
            'active': inspector.active(),
            'reserved': inspector.reserved(),
            'scheduled': inspector.scheduled(),
        }
        for state, per_worker in reports.items():
            if per_worker:
                details[state] = sum(len(tasks) for tasks in per_worker.values())

        details['total_reported_by_workers'] = details['active'] + details['reserved'] + details['scheduled']

    except Exception as e:
        logger.error(f"Error calculating queue details: {str(e)}")
        details['error'] = f"Failed to inspect Celery workers: {str(e)}"
        # Zero every count so a partial result is never reported
        details.update(
            active=0,
            reserved=0,
            scheduled=0,
            total_reported_by_workers=0,
        )

    return details
211
+
212
@app.route('/v1/audio/speech', methods=['POST'])
def openai_speech() -> Response:
    """Handle POST requests to /v1/audio/speech (OpenAI compatible API).

    Flow: check the queue size reported by the workers, parse and validate
    the JSON body, submit the TTS task to Celery, and wait up to 30 seconds
    for the result.

    Returns:
        The audio bytes with the MIME type matching the requested format on
        success; otherwise a JSON error with status 400 (bad body or failed
        validation), 429 (queue full), 408 (task timed out), the status
        reported by the task, or 500 (unexpected failure).
    """
    # AsyncResult.get() signals a timeout with Celery's own exception class,
    # which the builtin TimeoutError does not cover.
    from celery.exceptions import TimeoutError as CeleryTimeoutError

    content_types = {
        'mp3': 'audio/mpeg',
        'opus': 'audio/opus',
        'aac': 'audio/aac',
        'flac': 'audio/flac',
        'wav': 'audio/wav',
        'pcm': 'audio/pcm'
    }

    try:
        # Check queue size from Celery worker reports
        queue_details = get_queue_details()
        current_total = queue_details['total_reported_by_workers']

        if queue_details['error']:
            # Inspection failed: the size is unknown, so let the request through
            logger.warning(f"Could not determine queue size due to inspection error: {queue_details['error']}. Allowing request.")
        elif current_total >= MAX_QUEUE_SIZE:
            logger.warning(f"Queue is full based on worker reports. Current total: {current_total}, Max size: {MAX_QUEUE_SIZE}")
            return jsonify({
                "error": "Queue is full. Please try again later.",
                "queue_details": queue_details,  # Provide detailed counts
                "max_queue_size_limit": MAX_QUEUE_SIZE
            }), 429

        # silent=True yields None for a missing, malformed or non-JSON body
        # instead of raising, so the client gets a 400 rather than a 500.
        body = request.get_json(silent=True)
        if body is None:
            logger.error("Invalid JSON in request body")
            return jsonify({"error": "Invalid JSON in request body"}), 400

        openai_fm_data, error, status_code = validate_tts_request(body)
        if error:
            return jsonify({"error": error}), status_code

        # The validated payload carries the normalized format (if one was
        # requested); fall back to MP3 otherwise.
        validated_content_type = content_types.get(openai_fm_data.get('format'), "audio/mpeg")

        # Create task data
        task_data = {
            'data': openai_fm_data,
            'timestamp': datetime.now().isoformat()
        }

        # Submit task to Celery
        task = process_tts_request.delay(task_data)

        # Wait for result with timeout
        try:
            audio_data, error, status_code = task.get(timeout=30)
        except (CeleryTimeoutError, TimeoutError):
            logger.error(f"Task timeout: {task.id}")
            return jsonify({
                "error": "Request timed out. Please try again.",
                "task_id": task.id
            }), 408

        if error:
            logger.error(f"Task error: {error}")
            return jsonify({"error": error}), status_code

        return Response(audio_data, mimetype=validated_content_type)

    except HTTPException:
        # Let HTTP errors (e.g. an oversized body) keep their own status code
        raise
    except Exception as e:
        logger.error(f"Unexpected error in speech endpoint: {str(e)}", exc_info=True)
        return jsonify({"error": "Internal Server Error"}), 500
287
+
288
@app.route('/api/queue-size', methods=['GET'])
def queue_size() -> Response:
    """Report the worker-side queue counts for GET /api/queue-size.

    Returns:
        JSON with the active/reserved/scheduled counts, their total, the
        configured limit and an ``error`` field. The status is 200 when
        inspection succeeded and 500 when it reported an error or when this
        handler itself failed.
    """
    try:
        snapshot = get_queue_details()
        inspection_error = snapshot['error']

        payload = {
            "active_tasks": snapshot['active'],
            "reserved_tasks": snapshot['reserved'],
            "scheduled_tasks": snapshot['scheduled'],
            "total_reported_by_workers": snapshot['total_reported_by_workers'],
            "max_queue_size_limit": MAX_QUEUE_SIZE,
            "error": inspection_error
        }

        # An inspection error still returns the payload, but as a 500
        if inspection_error:
            return jsonify(payload), 500
        return jsonify(payload), 200

    except Exception as e:
        # Failure inside this handler, as opposed to an inspection error
        logger.error(f"Error in /api/queue-size endpoint: {str(e)}")
        fallback = {
            "active_tasks": 0,
            "reserved_tasks": 0,
            "scheduled_tasks": 0,
            "total_reported_by_workers": 0,
            "max_queue_size_limit": MAX_QUEUE_SIZE,
            "error": "Failed to process queue status request"
        }
        return jsonify(fallback), 500
319
+
320
@app.route('/api/voice-sample/<voice>', methods=['GET'])
def voice_sample(voice: str) -> Response:
    """Serve the sample MP3 for a voice.

    Args:
        voice: Voice name from the URL. It must already be a safe file name;
            anything that ``secure_filename`` would alter is rejected.

    Returns:
        The file ``<voice>_sample.mp3`` from the voice samples directory as
        ``audio/mpeg``, or a JSON error with status 400 (invalid name or
        path), 404 (no such sample file) or 500 (unexpected failure).
    """
    try:
        if not voice:
            return jsonify({
                "error": "Voice parameter is required"
            }), 400

        # Secure the voice parameter and prevent path traversal
        secure_voice = secure_filename(voice)
        if not secure_voice or secure_voice != voice:
            logger.warning(f"Invalid voice parameter: {voice}")
            return jsonify({
                "error": "Invalid voice parameter"
            }), 400

        # Normalize and validate the path
        base_path = os.path.abspath(VOICE_SAMPLES_DIR)
        sample_name = f"{secure_voice}_sample.mp3"
        sample_path = os.path.normpath(os.path.join(base_path, sample_name))

        # Compare path components rather than string prefixes: a plain
        # startswith() would also accept a sibling directory whose name
        # merely begins with the base directory's name.
        if os.path.commonpath([base_path, sample_path]) != base_path:
            logger.warning(f"Path traversal attempt: {sample_path}")
            return jsonify({
                "error": "Invalid path"
            }), 400

        # Require a regular file; a directory with that name is not a sample
        if not os.path.isfile(sample_path):
            logger.warning(f"Sample not found for voice: {voice}")
            return jsonify({
                "error": f"Sample not found for voice: {voice}"
            }), 404

        return send_file(
            sample_path,
            mimetype="audio/mpeg",
            as_attachment=False,
            download_name=sample_name
        )

    except Exception as e:
        logger.error(f"Error serving voice sample: {str(e)}")
        return jsonify({
            "error": "Internal Server Error"
        }), 500
366
+
367
@app.route('/api/version', methods=['GET'])
def get_version() -> Response:
    """Report the API version string as JSON for GET /api/version."""
    version_info = {"version": "v2.0.0-alpha_x"}
    return jsonify(version_info)