johnbridges commited on
Commit
25d7670
·
1 Parent(s): 58a6828

init commit

Browse files
Files changed (6) hide show
  1. Dockerfile +38 -0
  2. app.py +377 -0
  3. commit +3 -0
  4. kokoro.py +165 -0
  5. requirements.txt +17 -0
  6. tts_processor.py +163 -0
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Install system dependencies
4
+ RUN apt-get update && apt-get install -y --no-install-recommends \
5
+ libsndfile1 \
6
+ espeak-ng \
7
+ ffmpeg \
8
+ git \
9
+ wget \
10
+ && rm -rf /var/lib/apt/lists/*
11
+ RUN useradd -m -u 1000 user
12
+
13
+ # Switch to the "user" user
14
+ USER user
15
+
16
+ # Set home to the user's home directory
17
+ ENV HOME=/home/user \
18
+ PATH=/home/user/.local/bin:$PATH
19
+
20
+ # Set the working directory to the user's home directory
21
+ WORKDIR $HOME/app
22
+
23
+ # Create the files directory
24
+ RUN mkdir -p $HOME/app/files
25
+
26
+ # Copy and install Python dependencies
27
+ COPY requirements.txt $HOME/app/
28
+ RUN pip install --no-cache-dir -r requirements.txt && pip install --upgrade pip
29
+
30
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
31
+ COPY --chown=user . $HOME/app
32
+
33
+
34
+ # Expose port
35
+ EXPOSE 7860
36
+
37
+ # Run the application
38
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_from_directory, abort
2
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
+ import librosa
4
+ import torch
5
+ import numpy as np
6
+ from onnxruntime import InferenceSession
7
+ import soundfile as sf
8
+ import os
9
+ import sys
10
+ import uuid
11
+ import logging
12
+ from flask_cors import CORS
13
+ import threading
14
+ import werkzeug
15
+ import tempfile
16
+ from huggingface_hub import snapshot_download
17
+ from tts_processor import preprocess_all
18
+ import hashlib
19
+ import os
20
+ import torch
21
+ import numpy as np
22
+ import onnxruntime as ort
23
+
24
+ # ---------------------------
25
+ # THREAD LIMIT CONFIG
26
+ # ---------------------------
27
+ MAX_THREADS = 2 # <-- change this number to control all thread usage
28
+
29
+ # ---------------------------
30
+ # ---------------------------
31
+ # STORAGE ROOT
32
+ # ---------------------------
33
+ SERVE_DIR = "/home/user/app/files"
34
+ os.makedirs(SERVE_DIR, exist_ok=True)
35
+
36
+ # Limit NumPy / BLAS / MKL threads
37
+ os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
38
+ os.environ["OPENBLAS_NUM_THREADS"] = str(MAX_THREADS)
39
+ os.environ["MKL_NUM_THREADS"] = str(MAX_THREADS)
40
+ os.environ["VECLIB_MAXIMUM_THREADS"] = str(MAX_THREADS)
41
+ os.environ["NUMEXPR_NUM_THREADS"] = str(MAX_THREADS)
42
+
43
+ # Torch thread limits
44
+ torch.set_num_threads(MAX_THREADS)
45
+ torch.set_num_interop_threads(1) # keep inter-op small to avoid overhead
46
+
47
+ # ONNXRuntime session options (use when creating the session)
48
+ sess_options = ort.SessionOptions()
49
+ sess_options.intra_op_num_threads = MAX_THREADS
50
+ sess_options.inter_op_num_threads = 1
51
+
52
+
53
+ # Configure logging
54
+ logging.basicConfig(level=logging.INFO)
55
+ logger = logging.getLogger(__name__)
56
+
57
+ app = Flask(__name__)
58
+ CORS(app, resources={r"/*": {"origins": "*"}})
59
+
60
+ # Global lock to ensure one method runs at a time
61
+ global_lock = threading.Lock()
62
+
63
+ # Repository ID and paths
64
+ kokoro_model_id = 'onnx-community/Kokoro-82M-v1.0-ONNX'
65
+ model_path = 'kokoro_model'
66
+ voice_name = 'am_adam' # Example voice: af (adjust as needed)
67
+
68
+ # Directory to serve files from
69
+ SERVE_DIR = os.environ.get("SERVE_DIR", "./files") # Default to './files' if not provided
70
+
71
+ os.makedirs(SERVE_DIR, exist_ok=True)
72
+ def validate_audio_file(file):
73
+ """Validates audio files including WebM/Opus format"""
74
+ if not isinstance(file, werkzeug.datastructures.FileStorage):
75
+ raise ValueError("Invalid file type")
76
+
77
+ # Supported MIME types (add WebM/Opus)
78
+ supported_types = [
79
+ "audio/wav",
80
+ "audio/x-wav",
81
+ "audio/mpeg",
82
+ "audio/mp3",
83
+ "audio/webm",
84
+ "audio/ogg" # For Opus in Ogg container
85
+ ]
86
+
87
+ # Check MIME type
88
+ if file.content_type not in supported_types:
89
+ raise ValueError(f"Unsupported file type. Must be one of: {', '.join(supported_types)}")
90
+
91
+ # Check file size
92
+ file.seek(0, os.SEEK_END)
93
+ file_size = file.tell()
94
+ file.seek(0) # Reset file pointer
95
+
96
+ max_size = 10 * 1024 * 1024 # 10 MB
97
+ if file_size > max_size:
98
+ raise ValueError(f"File is too large (max {max_size//(1024*1024)} MB)")
99
+
100
+ # Optional: Verify file header matches content_type
101
+ if not verify_audio_header(file):
102
+ raise ValueError("File header doesn't match declared content type")
103
+ def verify_audio_header(file):
104
+ """Quickly checks if file headers match the declared audio format"""
105
+ header = file.read(4)
106
+ file.seek(0) # Rewind after reading
107
+
108
+ if file.content_type in ["audio/webm", "audio/ogg"]:
109
+ # WebM starts with \x1aE\xdf\xa3, Ogg with OggS
110
+ return (
111
+ (file.content_type == "audio/webm" and header.startswith(b'\x1aE\xdf\xa3')) or
112
+ (file.content_type == "audio/ogg" and header.startswith(b'OggS'))
113
+ )
114
+ elif file.content_type in ["audio/wav", "audio/x-wav"]:
115
+ return header.startswith(b'RIFF')
116
+ elif file.content_type in ["audio/mpeg", "audio/mp3"]:
117
+ return header.startswith(b'\xff\xfb') # MP3 frame sync
118
+ return True # Skip verification for other types
119
+
120
+ def validate_text_input(text):
121
+ if not isinstance(text, str):
122
+ raise ValueError("Text input must be a string")
123
+ if len(text.strip()) == 0:
124
+ raise ValueError("Text input cannot be empty")
125
+ if len(text) > 1024: # Limit to 1024 characters
126
+ raise ValueError("Text input is too long (max 1024 characters)")
127
+
128
+ file_cache = {}
129
+
130
+ def is_cached(cached_file_path):
131
+ """
132
+ Check if a file exists in the cache.
133
+ If the file is not in the cache, perform a disk check and update the cache.
134
+ """
135
+ if cached_file_path in file_cache:
136
+ return file_cache[cached_file_path] # Return cached result
137
+ exists = os.path.exists(cached_file_path) # Perform disk check
138
+ file_cache[cached_file_path] = exists # Update the cache
139
+ return exists
140
+
141
+ # Initialize models
142
+ def initialize_models():
143
+ global sess, voice_style, processor, whisper_model
144
+
145
+ try:
146
+ # Download the ONNX model if not already downloaded
147
+ if not os.path.exists(model_path):
148
+ logger.info("Downloading and loading Kokoro model...")
149
+ kokoro_dir = snapshot_download(kokoro_model_id, cache_dir=model_path)
150
+ logger.info(f"Kokoro model directory: {kokoro_dir}")
151
+ else:
152
+ kokoro_dir = model_path
153
+ logger.info(f"Using cached Kokoro model directory: {kokoro_dir}")
154
+
155
+ # Validate ONNX file path
156
+ onnx_path = None
157
+ for root, _, files in os.walk(kokoro_dir):
158
+ if 'model.onnx' in files:
159
+ onnx_path = os.path.join(root, 'model.onnx')
160
+ break
161
+
162
+ if not onnx_path or not os.path.exists(onnx_path):
163
+ raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")
164
+
165
+ logger.info("Loading ONNX session...")
166
+ sess = InferenceSession(onnx_path, sess_options)
167
+ logger.info(f"ONNX session loaded successfully from {onnx_path}")
168
+
169
+ # Load the voice style vector
170
+ voice_style_path = None
171
+ for root, _, files in os.walk(kokoro_dir):
172
+ if f'{voice_name}.bin' in files:
173
+ voice_style_path = os.path.join(root, f'{voice_name}.bin')
174
+ break
175
+
176
+ if not voice_style_path or not os.path.exists(voice_style_path):
177
+ raise FileNotFoundError(f"Voice style file not found at {voice_style_path}")
178
+
179
+ logger.info("Loading voice style vector...")
180
+ voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
181
+ logger.info(f"Voice style vector loaded successfully from {voice_style_path}")
182
+
183
+ # Initialize Whisper model for S2T
184
+ logger.info("Downloading and loading Whisper model...")
185
+ processor = WhisperProcessor.from_pretrained("openai/whisper-base")
186
+ whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
187
+ whisper_model.config.forced_decoder_ids = None
188
+ logger.info("Whisper model loaded successfully")
189
+
190
+ except Exception as e:
191
+ logger.error(f"Error initializing models: {str(e)}")
192
+ raise
193
+
194
+ # Initialize models
195
+ initialize_models()
196
+
197
+ # Health check endpoint
198
+ @app.route('/health', methods=['GET'])
199
+ def health_check():
200
+ try:
201
+ return jsonify({"status": "healthy"}), 200
202
+ except Exception as e:
203
+ logger.error(f"Health check failed: {str(e)}")
204
+ return jsonify({"status": "unhealthy"}), 500
205
+
206
+ # Text-to-Speech (T2S) Endpoint
207
+ @app.route('/generate_audio', methods=['POST'])
208
+ def generate_audio():
209
+ """Text-to-Speech (T2S) Endpoint"""
210
+ with global_lock:
211
+ try:
212
+ logger.debug("Received request to /generate_audio")
213
+ data = request.json
214
+ text = data['text']
215
+
216
+ validate_text_input(text)
217
+
218
+ # Preprocess & stable hash
219
+ text = preprocess_all(text)
220
+ text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
221
+ filename = f"{text_hash}.wav"
222
+ cached_file_path = os.path.join(SERVE_DIR, filename)
223
+
224
+ # Cache hit
225
+ if is_cached(cached_file_path):
226
+ logger.info("Returning cached audio")
227
+ return jsonify({"status": "success", "filename": filename})
228
+
229
+ # Tokenize
230
+ from kokoro import phonemize, tokenize # lazy import is fine
231
+ tokens = tokenize(phonemize(text, 'a'))
232
+ if len(tokens) > 510:
233
+ logger.warning("Text too long; truncating to 510 tokens.")
234
+ tokens = tokens[:510]
235
+ tokens = [[0, *tokens, 0]]
236
+
237
+ # Style vector
238
+ ref_s = voice_style[len(tokens[0]) - 2] # (1,256)
239
+
240
+ # ONNX inference
241
+ audio = sess.run(None, dict(
242
+ input_ids=np.array(tokens, dtype=np.int64),
243
+ style=ref_s,
244
+ speed=np.ones(1, dtype=np.float32),
245
+ ))[0]
246
+
247
+ # Save
248
+ audio = np.squeeze(audio).astype(np.float32)
249
+ sf.write(cached_file_path, audio, 24000)
250
+
251
+ logger.info(f"Audio saved: {cached_file_path}")
252
+ return jsonify({"status": "success", "filename": filename})
253
+ except Exception as e:
254
+ logger.error(f"Error generating audio: {str(e)}")
255
+ return jsonify({"status": "error", "message": str(e)}), 500
256
+
257
+ # Speech-to-Text (S2T) Endpoint
258
+ # Add these imports at the top with the other imports
259
+ import subprocess
260
+ import tempfile
261
+ from pathlib import Path
262
+
263
+ # Then update the transcribe_audio function:
264
+ @app.route('/transcribe_audio', methods=['POST'])
265
+ def transcribe_audio():
266
+ """Speech-to-Text (S2T) Endpoint with automatic format conversion"""
267
+ with global_lock: # Acquire global lock to ensure only one instance runs
268
+ input_audio_path = None
269
+ converted_audio_path = None
270
+ try:
271
+ logger.debug("Received request to /transcribe_audio")
272
+ file = request.files['file']
273
+
274
+ # Create temporary files for both input and output
275
+ with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as input_temp:
276
+ input_audio_path = input_temp.name
277
+ file.save(input_audio_path)
278
+ logger.debug(f"Original audio file saved to {input_audio_path}")
279
+
280
+ # Create a temporary file for the converted WAV
281
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as output_temp:
282
+ converted_audio_path = output_temp.name
283
+
284
+ # Convert to WAV with ffmpeg (16kHz, mono)
285
+ logger.debug(f"Converting audio to 16kHz mono WAV format...")
286
+ conversion_command = [
287
+ 'ffmpeg',
288
+ '-y', # Force overwrite without prompting
289
+ '-i', input_audio_path,
290
+ '-acodec', 'pcm_s16le', # 16-bit PCM
291
+ '-ac', '1', # mono
292
+ '-ar', '16000', # 16kHz sample rate
293
+ '-af', 'highpass=f=80,lowpass=f=7500,afftdn=nr=10:nf=-25,loudnorm=I=-16:TP=-1.5:LRA=11', # Audio cleanup filters
294
+ converted_audio_path
295
+ ]
296
+ result = subprocess.run(
297
+ conversion_command,
298
+ stdout=subprocess.PIPE,
299
+ stderr=subprocess.PIPE,
300
+ text=True
301
+ )
302
+
303
+ if result.returncode != 0:
304
+ logger.error(f"FFmpeg conversion error: {result.stderr}")
305
+ raise Exception(f"Audio conversion failed: {result.stderr}")
306
+
307
+ logger.debug(f"Audio successfully converted to {converted_audio_path}")
308
+
309
+ # Load and process the converted audio
310
+ logger.debug("Processing audio for transcription...")
311
+ audio_array, sampling_rate = librosa.load(converted_audio_path, sr=16000)
312
+
313
+ input_features = processor(
314
+ audio_array,
315
+ sampling_rate=sampling_rate,
316
+ return_tensors="pt"
317
+ ).input_features
318
+
319
+ # Generate transcription
320
+ logger.debug("Generating transcription...")
321
+ predicted_ids = whisper_model.generate(input_features)
322
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
323
+ logger.info(f"Transcription: {transcription}")
324
+
325
+ return jsonify({"status": "success", "transcription": transcription})
326
+ except Exception as e:
327
+ logger.error(f"Error transcribing audio: {str(e)}")
328
+ return jsonify({"status": "error", "message": str(e)}), 500
329
+ finally:
330
+ # Clean up temporary files
331
+ for path in [input_audio_path, converted_audio_path]:
332
+ if path and os.path.exists(path):
333
+ try:
334
+ os.remove(path)
335
+ logger.debug(f"Temporary file {path} removed")
336
+ except Exception as e:
337
+ logger.warning(f"Failed to remove temporary file {path}: {e}")
338
+
339
+ @app.route('/files/<filename>', methods=['GET'])
340
+ def serve_wav_file(filename):
341
+ """
342
+ Serve a .wav file from the configured directory.
343
+ Only serves files ending with '.wav'.
344
+ """
345
+ # Ensure only .wav files are allowed
346
+ if not filename.lower().endswith('.wav'):
347
+ abort(400, "Only .wav files are allowed.")
348
+
349
+ # Check if the file exists in the directory
350
+ file_path = os.path.join(SERVE_DIR, filename)
351
+ logger.debug(f"Looking for file at: {file_path}")
352
+ if not os.path.isfile(file_path):
353
+ logger.error(f"File not found: {file_path}")
354
+ abort(404, "File not found.")
355
+
356
+ # Serve the file
357
+ return send_from_directory(SERVE_DIR, filename)
358
+
359
+ # Error handlers
360
+ @app.errorhandler(400)
361
+ def bad_request(error):
362
+ """Handle 400 errors."""
363
+ return {"error": "Bad Request", "message": str(error)}, 400
364
+
365
+ @app.errorhandler(404)
366
+ def not_found(error):
367
+ """Handle 404 errors."""
368
+ return {"error": "Not Found", "message": str(error)}, 404
369
+
370
+ @app.errorhandler(500)
371
+ def internal_error(error):
372
+ """Handle unexpected errors."""
373
+ return {"error": "Internal Server Error", "message": "An unexpected error occurred."}, 500
374
+
375
+ if __name__ == "__main__":
376
+ app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)
377
+
commit ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git add .
2
+ git commit -m "$*"
3
+ git push
kokoro.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import phonemizer
2
+ import re
3
+ import torch
4
+ import numpy as np
5
+
6
+ def split_num(num):
7
+ num = num.group()
8
+ if '.' in num:
9
+ return num
10
+ elif ':' in num:
11
+ h, m = [int(n) for n in num.split(':')]
12
+ if m == 0:
13
+ return f"{h} o'clock"
14
+ elif m < 10:
15
+ return f'{h} oh {m}'
16
+ return f'{h} {m}'
17
+ year = int(num[:4])
18
+ if year < 1100 or year % 1000 < 10:
19
+ return num
20
+ left, right = num[:2], int(num[2:4])
21
+ s = 's' if num.endswith('s') else ''
22
+ if 100 <= year % 1000 <= 999:
23
+ if right == 0:
24
+ return f'{left} hundred{s}'
25
+ elif right < 10:
26
+ return f'{left} oh {right}{s}'
27
+ return f'{left} {right}{s}'
28
+
29
+ def flip_money(m):
30
+ m = m.group()
31
+ bill = 'dollar' if m[0] == '$' else 'pound'
32
+ if m[-1].isalpha():
33
+ return f'{m[1:]} {bill}s'
34
+ elif '.' not in m:
35
+ s = '' if m[1:] == '1' else 's'
36
+ return f'{m[1:]} {bill}{s}'
37
+ b, c = m[1:].split('.')
38
+ s = '' if b == '1' else 's'
39
+ c = int(c.ljust(2, '0'))
40
+ coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
41
+ return f'{b} {bill}{s} and {c} {coins}'
42
+
43
+ def point_num(num):
44
+ a, b = num.group().split('.')
45
+ return ' point '.join([a, ' '.join(b)])
46
+
47
+ def normalize_text(text):
48
+ text = text.replace(chr(8216), "'").replace(chr(8217), "'")
49
+ text = text.replace('«', chr(8220)).replace('»', chr(8221))
50
+ text = text.replace(chr(8220), '"').replace(chr(8221), '"')
51
+ text = text.replace('(', '«').replace(')', '»')
52
+ for a, b in zip('、。!,:;?', ',.!,:;?'):
53
+ text = text.replace(a, b+' ')
54
+ text = re.sub(r'[^\S \n]', ' ', text)
55
+ text = re.sub(r' +', ' ', text)
56
+ text = re.sub(r'(?<=\n) +(?=\n)', '', text)
57
+ text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
58
+ text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
59
+ text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
60
+ text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
61
+ text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
62
+ text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
63
+ text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
64
+ text = re.sub(r'(?<=\d),(?=\d)', '', text)
65
+ text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
66
+ text = re.sub(r'\d*\.\d+', point_num, text)
67
+ text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
68
+ text = re.sub(r'(?<=\d)S', ' S', text)
69
+ text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
70
+ text = re.sub(r"(?<=X')S\b", 's', text)
71
+ text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
72
+ text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
73
+ return text.strip()
74
+
75
+ def get_vocab():
76
+ _pad = "$"
77
+ _punctuation = ';:,.!?¡¿—…"«»“” '
78
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
79
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
80
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
81
+ dicts = {}
82
+ for i in range(len((symbols))):
83
+ dicts[symbols[i]] = i
84
+ return dicts
85
+
86
+ VOCAB = get_vocab()
87
+ def tokenize(ps):
88
+ return [i for i in map(VOCAB.get, ps) if i is not None]
89
+
90
+ phonemizers = dict(
91
+ a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
92
+ b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
93
+ )
94
+ def phonemize(text, lang, norm=True):
95
+ if norm:
96
+ text = normalize_text(text)
97
+ ps = phonemizers[lang].phonemize([text])
98
+ ps = ps[0] if ps else ''
99
+ # https://en.wiktionary.org/wiki/kokoro#English
100
+ ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
101
+ ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
102
+ ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
103
+ ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
104
+ if lang == 'a':
105
+ ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
106
+ ps = ''.join(filter(lambda p: p in VOCAB, ps))
107
+ return ps.strip()
108
+
109
+ def length_to_mask(lengths):
110
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
111
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
112
+ return mask
113
+
114
+ @torch.no_grad()
115
+ def forward(model, tokens, ref_s, speed):
116
+ device = ref_s.device
117
+ tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
118
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
119
+ text_mask = length_to_mask(input_lengths).to(device)
120
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
121
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
122
+ s = ref_s[:, 128:]
123
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
124
+ x, _ = model.predictor.lstm(d)
125
+ duration = model.predictor.duration_proj(x)
126
+ duration = torch.sigmoid(duration).sum(axis=-1) / speed
127
+ pred_dur = torch.round(duration).clamp(min=1).long()
128
+ pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
129
+ c_frame = 0
130
+ for i in range(pred_aln_trg.size(0)):
131
+ pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
132
+ c_frame += pred_dur[0,i].item()
133
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
134
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
135
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
136
+ asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
137
+ return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
138
+
139
+ def generate(model, text, voicepack, lang='a', speed=1, ps=None):
140
+ ps = ps or phonemize(text, lang)
141
+ tokens = tokenize(ps)
142
+ if not tokens:
143
+ return None
144
+ elif len(tokens) > 510:
145
+ tokens = tokens[:510]
146
+ print('Truncated to 510 tokens')
147
+ ref_s = voicepack[len(tokens)]
148
+ out = forward(model, tokens, ref_s, speed)
149
+ ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
150
+ return out, ps
151
+
152
+ def generate_full(model, text, voicepack, lang='a', speed=1, ps=None):
153
+ ps = ps or phonemize(text, lang)
154
+ tokens = tokenize(ps)
155
+ if not tokens:
156
+ return None
157
+ outs = []
158
+ loop_count = len(tokens)//510 + (1 if len(tokens) % 510 != 0 else 0)
159
+ for i in range(loop_count):
160
+ ref_s = voicepack[len(tokens[i*510:(i+1)*510])]
161
+ out = forward(model, tokens[i*510:(i+1)*510], ref_s, speed)
162
+ outs.append(out)
163
+ outs = np.concatenate(outs)
164
+ ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
165
+ return outs, ps
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ transformers
4
+ librosa
5
+ numpy
6
+ soundfile
7
+ huggingface_hub
8
+ phonemizer
9
+ munch
10
+ werkzeug
11
+ num2words
12
+ dateparser
13
+ inflect
14
+ ftfy
15
+ sentencepiece
16
+ torch --index-url https://download.pytorch.org/whl/cpu
17
+ onnxruntime
tts_processor.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from dateutil.parser import parse
3
+ from num2words import num2words
4
+ import inflect
5
+ from ftfy import fix_text
6
+
7
+ # Initialize the inflect engine
8
+ inflect_engine = inflect.engine()
9
+
10
+ # Define alphabet pronunciation mapping
11
+ alphabet_map = {
12
+ "A": " Eh ", "B": " Bee ", "C": " See ", "D": " Dee ", "E": " Eee ",
13
+ "F": " Eff ", "G": " Jee ", "H": " Aitch ", "I": " Eye ", "J": " Jay ",
14
+ "K": " Kay ", "L": " El ", "M": " Emm ", "N": " Enn ", "O": " Ohh ",
15
+ "P": " Pee ", "Q": " Queue ", "R": " Are ", "S": " Ess ", "T": " Tee ",
16
+ "U": " You ", "V": " Vee ", "W": " Double You ", "X": " Ex ", "Y": " Why ", "Z": " Zed "
17
+ }
18
+
19
+ # Function to add ordinal suffix to a number
20
+ def add_ordinal_suffix(day):
21
+ """Adds ordinal suffix to a day (e.g., 13 -> 13th)."""
22
+ if 11 <= day <= 13: # Special case for 11th, 12th, 13th
23
+ return f"{day}th"
24
+ elif day % 10 == 1:
25
+ return f"{day}st"
26
+ elif day % 10 == 2:
27
+ return f"{day}nd"
28
+ elif day % 10 == 3:
29
+ return f"{day}rd"
30
+ else:
31
+ return f"{day}th"
32
+
33
+ # Function to format dates in a human-readable form
34
+ def format_date(parsed_date, include_time=True):
35
+ """Formats a parsed date into a human-readable string."""
36
+ if not parsed_date:
37
+ return None
38
+
39
+ # Convert the day into an ordinal (e.g., 13 -> 13th)
40
+ day = add_ordinal_suffix(parsed_date.day)
41
+
42
+ # Format the date in a TTS-friendly way
43
+ if include_time and parsed_date.hour != 0 and parsed_date.minute != 0:
44
+ return parsed_date.strftime(f"%B {day}, %Y at %-I:%M %p") # Unix
45
+ return parsed_date.strftime(f"%B {day}, %Y") # Only date
46
+
47
+ # Normalize dates in the text
48
+ def normalize_dates(text):
49
+ """
50
+ Finds and replaces date strings with a nicely formatted, TTS-friendly version.
51
+ """
52
+ def replace_date(match):
53
+ raw_date = match.group(0)
54
+ try:
55
+ parsed_date = parse(raw_date)
56
+ if parsed_date:
57
+ include_time = "T" in raw_date or " " in raw_date # Include time only if explicitly provided
58
+ return format_date(parsed_date, include_time)
59
+ except ValueError:
60
+ pass
61
+ return raw_date
62
+
63
+ # Match common date formats
64
+ date_pattern = r"\b(\d{4}-\d{2}-\d{2}(?:[ T]\d{2}:\d{2}:\d{2})?|\d{2}/\d{2}/\d{4}|\d{1,2} \w+ \d{4})\b"
65
+ return re.sub(date_pattern, replace_date, text)
66
+
67
+ # Replace invalid characters and clean text
68
+ def replace_invalid_chars(string):
69
+ string = fix_text(string)
70
+ replacements = {
71
+ "**": "",
72
+ '&#x27;': "'",
73
+ 'AI;': 'Artificial Intelligence!',
74
+ 'iddqd;': 'Immortality cheat code',
75
+ '😉;': 'wink wink!',
76
+ ':D': '*laughs* Ahahaha!',
77
+ ';D': '*laughs* Ahahaha!'
78
+ }
79
+ for old, new in replacements.items():
80
+ string = string.replace(old, new)
81
+ return string
82
+
83
+ # Replace numbers with their word equivalents
84
+ def replace_numbers(string):
85
+ ipv4_pattern = r'(\b\d{1,3}(\.\d{1,3}){3}\b)'
86
+ ipv6_pattern = r'([0-9a-fA-F]{1,4}:){2,7}[0-9a-fA-F]{1,4}'
87
+ range_pattern = r'\b\d+-\d+\b' # Detect ranges like 1-4
88
+ date_pattern = r'\b\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2})?\b'
89
+ alphanumeric_pattern = r'\b[A-Za-z]+\d+|\d+[A-Za-z]+\b'
90
+
91
+ # Do not process IP addresses, date patterns, or alphanumerics
92
+ if re.search(ipv4_pattern, string) or re.search(ipv6_pattern, string) or re.search(range_pattern, string) or re.search(date_pattern, string) or re.search(alphanumeric_pattern, string):
93
+ return string
94
+
95
+ # Convert standalone numbers and port numbers
96
+ def convert_number(match):
97
+ number = match.group()
98
+ return num2words(int(number)) if number.isdigit() else number
99
+
100
+ pattern = re.compile(r'\b\d+\b')
101
+ return re.sub(pattern, convert_number, string)
102
+
103
+ # Replace abbreviations with expanded form
104
+ def replace_abbreviations(string):
105
+ words = string.split()
106
+ for i, word in enumerate(words):
107
+ if word.isupper() and len(word) <= 4 and not any(char.isdigit() for char in word) and word not in ["ID", "AM", "PM"]:
108
+ words[i] = ''.join([alphabet_map.get(char, char) for char in word])
109
+ return ' '.join(words)
110
+
111
+ def clean_whitespace(string):
112
+ # Remove spaces before punctuation
113
+ string = re.sub(r'\s+([.,?!])', r'\1', string)
114
+ # Collapse multiple spaces into one, but don’t touch inside tokens like "test.com"
115
+ string = re.sub(r'\s{2,}', ' ', string)
116
+ return string.strip()
117
+
118
+ def make_dots_tts_friendly(text):
119
+ # Handle IP addresses (force "dot")
120
+ ipv4_pattern = r'\b\d{1,3}(\.\d{1,3}){3}\b'
121
+ text = re.sub(ipv4_pattern, lambda m: m.group(0).replace('.', ' dot '), text)
122
+
123
+ # Handle domain-like endings (force "dot")
124
+ domain_pattern = r'\b([\w-]+)\.(com|net|org|io|gov|edu|exe|dll|local)\b'
125
+ text = re.sub(domain_pattern, lambda m: m.group(0).replace('.', ' dot '), text)
126
+
127
+ # Handle decimals (use "point")
128
+ decimal_pattern = r'\b\d+\.\d+\b'
129
+ text = re.sub(decimal_pattern, lambda m: m.group(0).replace('.', ' point '), text)
130
+
131
+ # Handle leading dot words (.Net → dot Net)
132
+ text = re.sub(r'\.(?=\w)', 'dot ', text)
133
+
134
+ return text
135
+
136
+ # Main preprocessing pipeline
137
+ def preprocess_all(string):
138
+ string = normalize_dates(string)
139
+ string = replace_invalid_chars(string)
140
+ string = replace_numbers(string)
141
+ string = replace_abbreviations(string)
142
+ string = make_dots_tts_friendly(string)
143
+ string = clean_whitespace(string)
144
+ return string
145
+
146
+ # Expose a testing function for external use
147
+ def test_preprocessing(file_path):
148
+ with open(file_path, 'r') as file:
149
+ lines = file.readlines()
150
+ for line in lines:
151
+ original = line.strip()
152
+ processed = preprocess_all(original)
153
+ print(f"Original: {original}")
154
+ print(f"Processed: {processed}\n")
155
+
156
+ if __name__ == "__main__":
157
+ import sys
158
+ if len(sys.argv) > 1:
159
+ test_file = sys.argv[1]
160
+ test_preprocessing(test_file)
161
+ else:
162
+ print("Please provide a file path as an argument.")
163
+