johnbridges committed on
Commit
c153cff
·
1 Parent(s): 35df04b

added new models and normalizations

Browse files
Files changed (2) hide show
  1. app.py +264 -27
  2. tts_processor.py +287 -7
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from flask import Flask, request, jsonify, send_from_directory, abort
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
3
  import librosa
4
  import torch
5
  import numpy as np
@@ -10,6 +11,7 @@ import sys
10
  import uuid
11
  import logging
12
  from flask_cors import CORS
 
13
  import threading
14
  import werkzeug
15
  import tempfile
@@ -26,13 +28,6 @@ import onnxruntime as ort
26
  # ---------------------------
27
  MAX_THREADS = 2 # <-- change this number to control all thread usage
28
 
29
- # ---------------------------
30
- # ---------------------------
31
- # STORAGE ROOT
32
- # ---------------------------
33
- SERVE_DIR = "/home/user/app/files"
34
- os.makedirs(SERVE_DIR, exist_ok=True)
35
-
36
  # Limit NumPy / BLAS / MKL threads
37
  os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
38
  os.environ["OPENBLAS_NUM_THREADS"] = str(MAX_THREADS)
@@ -66,7 +61,8 @@ model_path = 'kokoro_model'
66
  voice_name = 'am_adam' # Example voice: af (adjust as needed)
67
 
68
  # Directory to serve files from
69
- SERVE_DIR = os.environ.get("SERVE_DIR", "./files") # Default to './files' if not provided
 
70
 
71
  os.makedirs(SERVE_DIR, exist_ok=True)
72
  def validate_audio_file(file):
@@ -138,9 +134,18 @@ def is_cached(cached_file_path):
138
  file_cache[cached_file_path] = exists # Update the cache
139
  return exists
140
 
 
 
 
 
 
 
 
 
141
  # Initialize models
142
  def initialize_models():
143
- global sess, voice_style, processor, whisper_model
 
144
 
145
  try:
146
  # Download the ONNX model if not already downloaded
@@ -180,12 +185,64 @@ def initialize_models():
180
  voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
181
  logger.info(f"Voice style vector loaded successfully from {voice_style_path}")
182
 
183
- # Initialize Whisper model for S2T
184
- logger.info("Downloading and loading Whisper model...")
185
- processor = WhisperProcessor.from_pretrained("openai/whisper-base")
186
- whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
187
- whisper_model.config.forced_decoder_ids = None
188
- logger.info("Whisper model loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  except Exception as e:
191
  logger.error(f"Error initializing models: {str(e)}")
@@ -194,6 +251,150 @@ def initialize_models():
194
  # Initialize models
195
  initialize_models()
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  # Health check endpoint
198
  @app.route('/health', methods=['GET'])
199
  def health_check():
@@ -310,17 +511,54 @@ def transcribe_audio():
310
  logger.debug("Processing audio for transcription...")
311
  audio_array, sampling_rate = librosa.load(converted_audio_path, sr=16000)
312
 
313
- input_features = processor(
314
- audio_array,
315
- sampling_rate=sampling_rate,
316
- return_tensors="pt"
317
- ).input_features
318
-
319
- # Generate transcription
320
- logger.debug("Generating transcription...")
321
- predicted_ids = whisper_model.generate(input_features)
322
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
323
- logger.info(f"Transcription: {transcription}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
  return jsonify({"status": "success", "transcription": transcription})
326
  except Exception as e:
@@ -374,4 +612,3 @@ def internal_error(error):
374
 
375
  if __name__ == "__main__":
376
  app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)
377
-
 
1
  from flask import Flask, request, jsonify, send_from_directory, abort
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
+ from transformers import Wav2Vec2Processor, AutoTokenizer, AutoModelForTokenClassification
4
  import librosa
5
  import torch
6
  import numpy as np
 
11
  import uuid
12
  import logging
13
  from flask_cors import CORS
14
+ import re
15
  import threading
16
  import werkzeug
17
  import tempfile
 
28
  # ---------------------------
29
  MAX_THREADS = 2 # <-- change this number to control all thread usage
30
 
 
 
 
 
 
 
 
31
  # Limit NumPy / BLAS / MKL threads
32
  os.environ["OMP_NUM_THREADS"] = str(MAX_THREADS)
33
  os.environ["OPENBLAS_NUM_THREADS"] = str(MAX_THREADS)
 
61
  voice_name = 'am_adam' # Example voice: af (adjust as needed)
62
 
63
  # Directory to serve files from
64
+ default_serve_dir = os.path.join(os.path.expanduser("~"), "app", "files")
65
+ SERVE_DIR = os.environ.get("SERVE_DIR", default_serve_dir)
66
 
67
  os.makedirs(SERVE_DIR, exist_ok=True)
68
  def validate_audio_file(file):
 
134
  file_cache[cached_file_path] = exists # Update the cache
135
  return exists
136
 
137
# ---------------------------
# ASR / post-processing configuration (all overridable via environment)
# ---------------------------

def _env_flag(name, default="0"):
    """Return True when env var *name* holds a truthy flag value ("1"/"true"/"yes"/"on")."""
    return os.environ.get(name, default).lower() in {"1", "true", "yes", "on"}

# Legacy toggle kept for backward compatibility; only used to pick the default engine.
use_wav2vec2 = _env_flag("USE_WAV2VEC2", "")
# Engine selection: "wav2vec2_onnx" (ONNX CTC model) or "whisper_pt" (PyTorch Whisper).
ASR_ENGINE = os.environ.get("ASR_ENGINE", "wav2vec2_onnx" if use_wav2vec2 else "whisper_pt").lower()
# HF model id for the Wav2Vec2 processor (feature extraction + CTC labels).
ASR_MODEL_NAME = os.environ.get("ASR_MODEL_NAME", "facebook/wav2vec2-base-960h")
# HF repo holding a ready-made ONNX export of the ASR model.
ASR_ONNX_REPO = os.environ.get("ASR_ONNX_REPO", "onnx-community/wav2vec2-base-960h-ONNX")
# Optional transcription post-processing steps.
PUNCTUATE_TEXT = _env_flag("PUNCTUATE_TEXT")
TECH_NORMALIZE = _env_flag("TECH_NORMALIZE")
# Token-classification model used to restore punctuation.
PUNCTUATION_MODEL = os.environ.get("PUNCTUATION_MODEL", "kredor/punctuate-all")
145
  # Initialize models
146
  def initialize_models():
147
+ global sess, voice_style, processor, whisper_model, asr_session, asr_processor
148
+ global punctuation_model, punctuation_tokenizer
149
 
150
  try:
151
  # Download the ONNX model if not already downloaded
 
185
  voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
186
  logger.info(f"Voice style vector loaded successfully from {voice_style_path}")
187
 
188
+ # Initialize ASR engine
189
+ if ASR_ENGINE == "wav2vec2_onnx":
190
+ logger.info(f"Loading Wav2Vec2 ONNX ASR model ({ASR_MODEL_NAME})...")
191
+ # Load processor for feature extraction + CTC labels
192
+ asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_NAME)
193
+
194
+ # Try to locate/download ONNX model; if not present, download a ready-made ONNX repo.
195
+ default_onnx_path = f"asr_onnx/{ASR_MODEL_NAME.replace('/', '_')}.onnx"
196
+ asr_onnx_path_env = os.environ.get("ASR_ONNX_PATH", default_onnx_path)
197
+ if not os.path.exists(asr_onnx_path_env):
198
+ logger.info(f"ASR ONNX not found at {asr_onnx_path_env}. Attempting to download from {ASR_ONNX_REPO}...")
199
+ try:
200
+ cache_dir = os.environ.get("ASR_ONNX_CACHE_DIR", "asr_onnx_cache")
201
+ repo_dir = snapshot_download(ASR_ONNX_REPO, cache_dir=cache_dir)
202
+ # Look for common ONNX filenames
203
+ onnx_path = None
204
+ for root, _, files in os.walk(repo_dir):
205
+ for cand in ["model.onnx", "wav2vec2.onnx", "onnx/model.onnx"]:
206
+ if cand in files:
207
+ onnx_path = os.path.join(root, cand if cand != "onnx/model.onnx" else "model.onnx")
208
+ break
209
+ if onnx_path:
210
+ break
211
+ if not onnx_path:
212
+ # Fallback: pick first .onnx file found
213
+ for root, _, files in os.walk(repo_dir):
214
+ for f in files:
215
+ if f.endswith(".onnx"):
216
+ onnx_path = os.path.join(root, f)
217
+ break
218
+ if onnx_path:
219
+ break
220
+ if not onnx_path:
221
+ raise FileNotFoundError("No .onnx file found in downloaded repo")
222
+ os.makedirs(os.path.dirname(asr_onnx_path_env), exist_ok=True)
223
+ # Copy to stable location
224
+ import shutil
225
+ shutil.copyfile(onnx_path, asr_onnx_path_env)
226
+ logger.info(f"Downloaded ASR ONNX to {asr_onnx_path_env}")
227
+ except Exception as de:
228
+ logger.error(f"Failed to download ASR ONNX: {de}")
229
+ logger.warning("Falling back to Whisper PT engine.")
230
+ raise
231
+ asr_session = InferenceSession(asr_onnx_path_env, sess_options)
232
+ logger.info("Wav2Vec2 ONNX ASR model loaded")
233
+ else:
234
+ logger.info("ASR_ENGINE set to whisper_pt; loading Whisper model...")
235
+ processor = WhisperProcessor.from_pretrained("openai/whisper-base")
236
+ whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
237
+ whisper_model.config.forced_decoder_ids = None
238
+ logger.info("Whisper model loaded successfully")
239
+
240
+ if PUNCTUATE_TEXT:
241
+ logger.info(f"Loading punctuation model ({PUNCTUATION_MODEL})...")
242
+ punctuation_tokenizer = AutoTokenizer.from_pretrained(PUNCTUATION_MODEL)
243
+ punctuation_model = AutoModelForTokenClassification.from_pretrained(PUNCTUATION_MODEL)
244
+ punctuation_model.eval()
245
+ logger.info("Punctuation model loaded successfully")
246
 
247
  except Exception as e:
248
  logger.error(f"Error initializing models: {str(e)}")
 
251
  # Initialize models
252
  initialize_models()
253
 
254
def restore_punctuation(text, max_words=120):
    """
    Re-insert punctuation and sentence-initial capitalization into *text*.

    No-op when PUNCTUATE_TEXT is disabled or the global punctuation model is
    missing / failed to load. The input is lower-cased and split on
    whitespace, then scored in chunks of at most ``max_words`` words so each
    chunk stays within the token-classification model's sequence limit.
    """
    if not PUNCTUATE_TEXT:
        return text
    # Guard against initialize_models() having failed before loading the model.
    if "punctuation_model" not in globals() or punctuation_model is None:
        return text
    words = text.strip().lower().split()
    if not words:
        return text

    # NOTE(review): label names assumed to match the punctuation model's
    # id2label mapping (default kredor/punctuate-all) — confirm if
    # PUNCTUATION_MODEL is changed; unknown labels fall back to "" below.
    label_to_punct = {
        "O": "",
        "COMMA": ",",
        "PERIOD": ".",
        "QUESTION": "?",
        "EXCLAMATION": "!",
        "COLON": ":",
        "SEMICOLON": ";",
    }

    def process_chunk(chunk_words, capitalize_next):
        # Score one chunk; returns (decoded text, carry-over capitalization flag).
        inputs = punctuation_tokenizer(
            chunk_words,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
        )
        with torch.no_grad():
            logits = punctuation_model(**inputs).logits
        pred_ids = torch.argmax(logits, dim=-1)[0].tolist()
        word_ids = inputs.word_ids()
        last_word = -1
        word_end_labels = {}
        # Use the prediction of the FIRST sub-token of each word as that
        # word's punctuation label.
        for idx, word_id in enumerate(word_ids):
            if word_id is None:
                continue
            if word_id != last_word:
                last_word = word_id
                word_end_labels[word_id] = pred_ids[idx]

        decoded = []
        for i, word in enumerate(chunk_words):
            label_id = word_end_labels.get(i)
            label = punctuation_model.config.id2label.get(label_id, "O")
            punct = label_to_punct.get(label, "")
            if capitalize_next and word:
                word = word[0].upper() + word[1:]
                capitalize_next = False
            decoded.append(word + punct)
            # Sentence-ending punctuation capitalizes the following word.
            if punct in {".", "?", "!"}:
                capitalize_next = True
        return " ".join(decoded), capitalize_next

    out_parts = []
    capitalize_next = True  # the first word of the transcript starts a sentence
    for i in range(0, len(words), max_words):
        chunk = words[i:i + max_words]
        chunk_text, capitalize_next = process_chunk(chunk, capitalize_next)
        out_parts.append(chunk_text)
    return " ".join(out_parts).strip()
313
+
314
def normalize_tech_text(text):
    """
    Normalize spoken "tech" tokens (dot/com/slash/etc.) into symbols.
    Intended for wav2vec2 output; Whisper already handles this better.

    BUGFIX: many of the original patterns were double-escaped inside raw
    strings (e.g. r"\\bcomma\\b"), which matches a literal backslash and
    therefore never fired on ordinary text. All patterns now use single
    escapes, matching the correctly written domain-suffix rules.
    """
    normalized = text

    # Common domain suffixes (the "come"/"comm" variants catch ASR mishearings)
    normalized = re.sub(r"\bdot com\b", ".com", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot come\b", ".com", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot comm\b", ".com", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot net\b", ".net", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot org\b", ".org", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot io\b", ".io", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot ai\b", ".ai", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot co\b", ".co", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot uk\b", ".uk", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot dev\b", ".dev", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bdot local\b", ".local", normalized, flags=re.IGNORECASE)
    # Stitch stray whitespace around the produced suffixes back together.
    normalized = re.sub(r"\.\s+(com|net|org|io|ai|co|uk|dev|local)\b", r".\1", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"(\w)\s+\.(com|net|org|io|ai|co|uk|dev|local)\b", r"\1.\2", normalized, flags=re.IGNORECASE)

    # Symbols between tokens (lookarounds keep the neighboring words intact)
    normalized = re.sub(r"(?<=\w)\s+dot\s+(?=\w)", ".", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"(?<=\w)\s+at\s+(?=\w)", "@", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"(?<=\w)\s+colon\s+(?=\w)", ":", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"(?<=\w)\s+dash\s+(?=\w)", "-", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"(?<=\w)\s+hyphen\s+(?=\w)", "-", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bhyphen\b", "-", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bunderscore\b", "_", normalized, flags=re.IGNORECASE)

    # Slashes ("bash" is kept as a deliberate mishearing of "backslash")
    normalized = re.sub(r"\bback\s+slash\b", r"\\", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bbackslash\b", r"\\", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bbash\b", r"\\", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bforward\s+slash\b", "/", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bslash\b", "/", normalized, flags=re.IGNORECASE)

    # Spoken punctuation tokens
    normalized = re.sub(r"\bcomma\b", ",", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bperiod\b", ".", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bquestion\s+mark\b", "?", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bexclamation\s+point\b", "!", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bexclamation\s+mark\b", "!", normalized, flags=re.IGNORECASE)
    normalized = re.sub(r"\bhash\b", "#", normalized, flags=re.IGNORECASE)

    # Collapse sequences of spoken digits into numbers (useful for IPs/ports).
    num_map = {
        "zero": "0",
        "oh": "0",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
    }
    parts = normalized.split()
    out = []
    buffer = []
    for token in parts:
        lower = token.lower()
        if lower in num_map:
            buffer.append(num_map[lower])
            continue
        if lower == ".":
            buffer.append(".")
            continue
        if lower == "dot":
            buffer.append(".")
            continue
        # Any other token flushes the pending digit run.
        if buffer:
            out.append("".join(buffer))
            buffer = []
        out.append(token)
    if buffer:
        out.append("".join(buffer))
    normalized = " ".join(out)

    return normalized
397
+
398
  # Health check endpoint
399
  @app.route('/health', methods=['GET'])
400
  def health_check():
 
511
  logger.debug("Processing audio for transcription...")
512
  audio_array, sampling_rate = librosa.load(converted_audio_path, sr=16000)
513
 
514
+ if ASR_ENGINE == "wav2vec2_onnx" and 'asr_session' in globals() and asr_session is not None:
515
+ # Prepare input for Wav2Vec2 ONNX: float32 PCM, shape (batch, samples)
516
+ inputs = asr_processor(audio_array, sampling_rate=16000, return_tensors="np")
517
+ # Some exports expect input as (batch, sequence); adjust key as needed
518
+ ort_inputs = {}
519
+ # Common input name variants
520
+ for name in ["input_values", "input_features", "inputs"]:
521
+ if name in [i.name for i in asr_session.get_inputs()]:
522
+ ort_inputs[name] = inputs["input_values"].astype(np.float32)
523
+ break
524
+ else:
525
+ # Fall back to first input name
526
+ first_name = asr_session.get_inputs()[0].name
527
+ ort_inputs[first_name] = inputs["input_values"].astype(np.float32)
528
+
529
+ logits = asr_session.run(None, ort_inputs)[0] # (batch, time, vocab)
530
+ # Greedy CTC decode
531
+ pred_ids = np.argmax(logits, axis=-1)
532
+ # Collapse repeats and remove CTC blank (id 0 for many models; rely on processor)
533
+ transcription = asr_processor.batch_decode(pred_ids)[0]
534
+ transcription = transcription.strip()
535
+ logger.info(f"Transcription (Wav2Vec2 ONNX): {transcription}")
536
+ else:
537
+ # Whisper fallback
538
+ input_features = processor(
539
+ audio_array,
540
+ sampling_rate=sampling_rate,
541
+ return_tensors="pt"
542
+ ).input_features
543
+
544
+ logger.debug("Generating transcription (Whisper)...")
545
+ predicted_ids = whisper_model.generate(input_features)
546
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
547
+ logger.info(f"Transcription (Whisper): {transcription}")
548
+
549
+ if PUNCTUATE_TEXT:
550
+ try:
551
+ transcription = restore_punctuation(transcription)
552
+ logger.info(f"Transcription (Punctuated): {transcription}")
553
+ except Exception as pe:
554
+ logger.warning(f"Punctuation restore failed: {pe}")
555
+
556
+ if TECH_NORMALIZE:
557
+ try:
558
+ transcription = normalize_tech_text(transcription)
559
+ logger.info(f"Transcription (Normalized): {transcription}")
560
+ except Exception as ne:
561
+ logger.warning(f"Tech normalization failed: {ne}")
562
 
563
  return jsonify({"status": "success", "transcription": transcription})
564
  except Exception as e:
 
612
 
613
  if __name__ == "__main__":
614
  app.run(host="0.0.0.0", port=7860, threaded=False, processes=1)
 
tts_processor.py CHANGED
@@ -16,6 +16,129 @@ alphabet_map = {
16
  "U": " You ", "V": " Vee ", "W": " Double You ", "X": " Ex ", "Y": " Why ", "Z": " Zed "
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Function to add ordinal suffix to a number
20
  def add_ordinal_suffix(day):
21
  """Adds ordinal suffix to a day (e.g., 13 -> 13th)."""
@@ -82,20 +205,26 @@ def replace_invalid_chars(string):
82
 
83
  # Replace numbers with their word equivalents
84
  def replace_numbers(string):
85
- ipv4_pattern = r'(\b\d{1,3}(\.\d{1,3}){3}\b)'
86
- ipv6_pattern = r'([0-9a-fA-F]{1,4}:){2,7}[0-9a-fA-F]{1,4}'
87
  range_pattern = r'\b\d+-\d+\b' # Detect ranges like 1-4
88
  date_pattern = r'\b\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2})?\b'
89
  alphanumeric_pattern = r'\b[A-Za-z]+\d+|\d+[A-Za-z]+\b'
90
 
91
- # Do not process IP addresses, date patterns, or alphanumerics
92
- if re.search(ipv4_pattern, string) or re.search(ipv6_pattern, string) or re.search(range_pattern, string) or re.search(date_pattern, string) or re.search(alphanumeric_pattern, string):
93
- return string
 
 
 
 
94
 
95
  # Convert standalone numbers and port numbers
96
  def convert_number(match):
97
  number = match.group()
98
- return num2words(int(number)) if number.isdigit() else number
 
 
99
 
100
  pattern = re.compile(r'\b\d+\b')
101
  return re.sub(pattern, convert_number, string)
@@ -133,11 +262,163 @@ def make_dots_tts_friendly(text):
133
 
134
  return text
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  # Main preprocessing pipeline
137
  def preprocess_all(string):
138
  string = normalize_dates(string)
139
  string = replace_invalid_chars(string)
140
  string = replace_numbers(string)
 
141
  string = replace_abbreviations(string)
142
  string = make_dots_tts_friendly(string)
143
  string = clean_whitespace(string)
@@ -160,4 +441,3 @@ if __name__ == "__main__":
160
  test_preprocessing(test_file)
161
  else:
162
  print("Please provide a file path as an argument.")
163
-
 
16
  "U": " You ", "V": " Vee ", "W": " Double You ", "X": " Ex ", "Y": " Why ", "Z": " Zed "
17
  }
18
 
19
# NOTE: apply_replacements() applies these rules sequentially, so a
# more-specific pattern must come BEFORE any rule matching its prefix
# (e.g. "poe+" before "poe", "aes-gcm"/"aes-cbc" before "aes").
# Otherwise the shorter rule rewrites the token first and the specific
# rule can never fire — that ordering bug is fixed here.
TECH_ACRONYM_REPLACEMENTS = [
    (r"\bhttps\b", "H T T P S"),
    (r"\bhttp\b", "H T T P"),
    (r"\bssh\b", "S S H"),
    (r"\bdns\b", "D N S"),
    (r"\bntp\b", "N T P"),
    (r"\bsnmp\b", "S N M P"),
    (r"\btcp\b", "T C P"),
    (r"\budp\b", "U D P"),
    (r"\bicmp\b", "I C M P"),
    (r"\bip\b", "I P"),
    (r"\bipv4\b", "I P v four"),
    (r"\bipv6\b", "I P v six"),
    (r"\btls\b", "T L S"),
    (r"\bssl\b", "S S L"),
    (r"\brdp\b", "R D P"),
    (r"\bsql\b", "sequel"),
    (r"\bapi\b", "A P I"),
    (r"\buid\b", "U I D"),
    (r"\bgpu\b", "G P U"),
    (r"\bcpu\b", "C P U"),
    (r"\bram\b", "R A M"),
    (r"\bttl\b", "T T L"),
    (r"\brtt\b", "R T T"),
    (r"\bbgp\b", "B G P"),
    (r"\bospf\b", "O S P F"),
    (r"\bospfv2\b", "O S P F v two"),
    (r"\bospfv3\b", "O S P F v three"),
    (r"\bis-is\b", "I S I S"),
    (r"\brip\b", "R I P"),
    (r"\bdhcp\b", "D H C P"),
    (r"\barp\b", "A R P"),
    (r"\bndp\b", "N D P"),
    (r"\bnat\b", "N A T"),
    (r"\bpat\b", "P A T"),
    (r"\bgre\b", "G R E"),
    (r"\bvrrp\b", "V R R P"),
    (r"\bhsrp\b", "H S R P"),
    (r"\bglbp\b", "G L B P"),
    (r"\bstp\b", "S T P"),
    (r"\brstp\b", "R S T P"),
    (r"\bmstp\b", "M S T P"),
    (r"\blldp\b", "L L D P"),
    (r"\bcdp\b", "C D P"),
    (r"\bldap\b", "ell dap"),
    (r"\bsaml\b", "sam el"),
    (r"\boauth\b", "oh auth"),
    (r"\boidc\b", "O I D C"),
    (r"\bsso\b", "S S O"),
    (r"\bsmtp\b", "S M T P"),
    (r"\bimap\b", "I M A P"),
    (r"\bpop3\b", "P O P three"),
    (r"\bpop\b", "P O P"),
    (r"\bftp\b", "F T P"),
    (r"\bsftp\b", "S F T P"),
    (r"\bftps\b", "F T P S"),
    (r"\btftp\b", "T F T P"),
    (r"\bmqtt\b", "M Q T T"),
    (r"\bamqp\b", "A M Q P"),
    (r"\bcoap\b", "C O A P"),
    (r"\bquic\b", "Q U I C"),
    (r"\bgrpc\b", "gee R P C"),
    (r"\bsoap\b", "S O A P"),
    (r"\bjson\b", "jay son"),
    (r"\byaml\b", "yam el"),
    (r"\bxml\b", "ex em el"),
    (r"\bwebsocket\b", "web socket"),
    (r"\bwss\b", "W S S"),
    (r"\bws\b", "W S"),
    (r"\bicmpv6\b", "I C M P v six"),
    (r"\bntlm\b", "N T L M"),
    (r"\bpki\b", "P K I"),
    (r"\bcsr\b", "C S R"),
    (r"\bcrt\b", "C R T"),
    (r"\bca\b", "C A"),
    (r"\bwan\b", "W A N"),
    (r"\blan\b", "L A N"),
    (r"\bvlan\b", "V L A N"),
    (r"\bvxlan\b", "V X L A N"),
    (r"\bqos\b", "Q O S"),
    (r"\bmtu\b", "M T U"),
    (r"\bpoe\+", "P O E plus"),  # must precede the plain "poe" rule
    (r"\bpoe\b", "P O E"),
    (r"\bvrf\b", "V R F"),
    (r"\bacl\b", "A C L"),
    (r"\bnat64\b", "N A T sixty four"),
    (r"\bdsr\b", "D S R"),
    (r"\bsiem\b", "S I E M"),
    (r"\bids\b", "I D S"),
    (r"\bips\b", "I P S"),
    (r"\bedr\b", "E D R"),
    (r"\bxdr\b", "X D R"),
    (r"\bsoc\b", "S O C"),
    (r"\bmdr\b", "M D R"),
    (r"\bndr\b", "N D R"),
    (r"\bav\b", "A V"),
    (r"\bendpoint\b", "end point"),
    (r"\bsaas\b", "S A A S"),
    (r"\biaas\b", "I A A S"),
    (r"\bpaas\b", "P A A S"),
    (r"\bdlp\b", "D L P"),
    (r"\bmfa\b", "M F A"),
    (r"\b2fa\b", "two F A"),
    (r"\b3fa\b", "three F A"),
    (r"\bmd5\b", "M D five"),
    (r"\bsha1\b", "sha one"),
    (r"\bsha256\b", "sha two five six"),
    (r"\bsha512\b", "sha five one two"),
    (r"\baes-?gcm\b", "A E S G C M"),  # must precede the plain "aes" rule
    (r"\baes-?cbc\b", "A E S C B C"),  # must precede the plain "aes" rule
    (r"\baes\b", "A E S"),
    (r"\brsa\b", "R S A"),
    (r"\becdsa\b", "E C D S A"),
    (r"\bed25519\b", "E D two five five one nine"),
    (r"\bjwt\b", "J W T"),
    (r"\bsshd\b", "S S H D"),
    (r"\bntp\d?\b", "N T P"),
    (r"\bntp\s+server\b", "N T P server"),
    (r"\bntp\s+pool\b", "N T P pool"),
    (r"\bhttpd\b", "H T T P D"),
    (r"\bnginx\b", "engine x"),
]
141
+
142
  # Function to add ordinal suffix to a number
143
  def add_ordinal_suffix(day):
144
  """Adds ordinal suffix to a day (e.g., 13 -> 13th)."""
 
205
 
206
  # Replace numbers with their word equivalents
207
def replace_numbers(string):
    """Spell out standalone integers as words, leaving structured tokens alone."""
    # Patterns whose digit runs must NOT be converted: IPs, numeric ranges,
    # ISO dates, and mixed alphanumeric identifiers like eth0.
    protected_patterns = (
        r'\b\d{1,3}(?:\.\d{1,3}){3}\b',                      # IPv4
        r'\b(?:[0-9a-fA-F]{1,4}:){2,7}[0-9a-fA-F]{1,4}\b',   # IPv6
        r'\b\d+-\d+\b',                                      # ranges like 1-4
        r'\b\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2})?\b',     # dates / datetimes
        r'\b[A-Za-z]+\d+|\d+[A-Za-z]+\b',                    # alphanumerics
    )
    protected_spans = [
        (m.start(), m.end())
        for pat in protected_patterns
        for m in re.finditer(pat, string)
    ]

    def within_protected(start, end):
        # A digit run is protected when fully contained in any protected span.
        return any(lo <= start and end <= hi for lo, hi in protected_spans)

    def spell_out(match):
        token = match.group()
        if within_protected(match.start(), match.end()):
            return token
        if not token.isdigit():
            return token
        # num2words emits hyphens ("twenty-one"); spaces read better for TTS.
        return num2words(int(token)).replace("-", " ")

    return re.sub(r'\b\d+\b', spell_out, string)
 
262
 
263
  return text
264
 
265
def apply_replacements(value, replacements):
    """Sequentially apply case-insensitive (pattern, replacement) regex rules to *value*."""
    result = value
    for rule in replacements:
        pat, rep = rule
        result = re.sub(pat, rep, result, flags=re.IGNORECASE)
    return result
269
+
270
def tech_humanize(text):
    """
    Humanize technical tokens (URLs, emails, UUIDs, MACs, paths) for TTS.
    Keep outputs ASCII and TTS-friendly.

    Rules run in a fixed order: URLs/emails first (before acronym expansion
    mangles their protocol prefixes), then acronyms, then structured tokens
    (hex, CVEs, UUIDs, MACs, IPv6), then generic separators and units.
    """
    def normalize_url(match):
        # Expand a whole URL into spoken separators.
        url = match.group(0)
        url = url.replace("https://", "HTTPS://").replace("http://", "HTTP://")
        url = url.replace("://", " colon slash slash ")
        url = url.replace("/", " forward slash ")
        url = url.replace("?", " question mark ")
        url = url.replace("&", " and ")
        url = url.replace("=", " equals ")
        url = url.replace("#", " hash ")
        url = url.replace("_", " underscore ")
        url = url.replace("-", " dash ")
        url = url.replace(".", " dot ")
        return url

    def normalize_email(match):
        email = match.group(0)
        email = email.replace("@", " at ")
        email = email.replace(".", " dot ")
        email = email.replace("_", " underscore ")
        email = email.replace("-", " dash ")
        return email

    def normalize_uuid(match):
        # Spell each hyphen-separated group character by character.
        uuid_text = match.group(0)
        groups = uuid_text.split("-")
        spelled = [" ".join(list(group)) for group in groups]
        return " dash ".join(spelled)

    def normalize_mac(match):
        mac_text = match.group(0)
        groups = mac_text.split(":")
        spelled = [" ".join(list(group)) for group in groups]
        return " colon ".join(spelled)

    def normalize_ipv6(match):
        ipv6_text = match.group(0)
        groups = ipv6_text.split(":")
        spelled = [" ".join(list(group)) for group in groups if group]
        return " colon ".join(spelled)

    def normalize_ipv6_compact(match):
        # Handle the "::" zero-compression form separately.
        ipv6_text = match.group(0)
        left, _, right = ipv6_text.partition("::")
        left_groups = [g for g in left.split(":") if g]
        right_groups = [g for g in right.split(":") if g]
        left_spelled = [" ".join(list(group)) for group in left_groups]
        right_spelled = [" ".join(list(group)) for group in right_groups]
        middle = " double colon "
        left_part = " colon ".join(left_spelled)
        right_part = " colon ".join(right_spelled)
        if left_part and right_part:
            return f"{left_part}{middle}{right_part}"
        if left_part:
            return f"{left_part}{middle}"
        return f"{middle}{right_part}"

    def normalize_mac_dash(match):
        mac_text = match.group(0)
        groups = mac_text.split("-")
        spelled = [" ".join(list(group)) for group in groups]
        return " dash ".join(spelled)

    def normalize_hex(match):
        hex_text = match.group(1)
        return "hex " + " ".join(list(hex_text))

    def normalize_cve(match):
        year = match.group(1)
        ident = match.group(2)
        return f"C V E {year} dash {ident}"

    # URLs and emails (do this early before protocol expansions)
    text = re.sub(r"\bhttps?://[^\s]+", normalize_url, text, flags=re.IGNORECASE)
    text = re.sub(r"\b[\w.+-]+@[\w.-]+\.\w+\b", normalize_email, text)

    # Version tokens like TLS1.3 or HTTP/2
    text = re.sub(r"\b(tls|ssl)\s*(\d+(?:\.\d+)?)\b", lambda m: f"{m.group(1).upper()} {m.group(2).replace('.', ' point ')}", text, flags=re.IGNORECASE)
    text = re.sub(r"\bhttps?/(\d+(?:\.\d+)?)\b", lambda m: f"H T T P slash {m.group(1).replace('.', ' point ')}", text, flags=re.IGNORECASE)

    # Common protocol tokens (force letter-by-letter)
    text = apply_replacements(text, TECH_ACRONYM_REPLACEMENTS)

    # Hex values and CVEs
    text = re.sub(r"\b0x([0-9A-Fa-f]+)\b", normalize_hex, text)
    text = re.sub(r"\bCVE-(\d{4})-(\d{4,7})\b", normalize_cve, text)

    # Interfaces like eth0, wlan0, en0, lo0
    text = re.sub(r"\b(eth|wlan|en|lo)(\d+)\b", lambda m: f"{m.group(1)} {m.group(2)}", text, flags=re.IGNORECASE)

    # UUIDs, MACs, IPv6
    text = re.sub(r"\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b", normalize_uuid, text)
    text = re.sub(r"\b(?:[0-9A-Fa-f]{2}:){5}[0-9A-Fa-f]{2}\b", normalize_mac, text)
    text = re.sub(r"\b(?:[0-9A-Fa-f]{2}-){5}[0-9A-Fa-f]{2}\b", normalize_mac_dash, text)
    text = re.sub(r"\b(?:[0-9A-Fa-f]{1,4}:){2,7}[0-9A-Fa-f]{1,4}\b", normalize_ipv6, text)
    text = re.sub(r"\b[0-9A-Fa-f:]*::[0-9A-Fa-f:]*\b", normalize_ipv6_compact, text)

    # Acronym pairs like TCP/IP -> "TCP slash IP"
    text = re.sub(r"\b([A-Z]{2,})\s*/\s*([A-Z]{2,})\b", r"\1 slash \2", text)
    # Word/word patterns like this/that -> "this or that"
    text = re.sub(r"\b([A-Za-z]+)\s*/\s*([A-Za-z]+)\b", r"\1 or \2", text)

    # Common separators in paths/flags
    text = re.sub(r"(?<=\w)/(?!\s)", " forward slash ", text)
    text = re.sub(r"\\", " backslash ", text)
    text = re.sub(r"(?<=\w)-(?=\w)", " dash ", text)
    text = re.sub(r"(?<=\w)_(?=\w)", " underscore ", text)
    text = re.sub(r"(?<=\w):(?=\w)", " colon ", text)
    text = re.sub(r"--", " double dash ", text)
    text = re.sub(r"->", " arrow ", text)
    text = re.sub(r"=>", " arrow ", text)
    # BUGFIX: the original pattern r"\b(\d+)%\b" required a word character
    # right after '%' ('%' is a non-word char), so "50% done" or a trailing
    # "50%" never matched. No anchor after '%' is needed.
    text = re.sub(r"\b(\d+)%", r"\1 percent", text)

    # Versions like v1.2.3 -> v 1 point 2 point 3
    text = re.sub(r"\bv(\d+(?:\.\d+)+)\b", lambda m: "v " + m.group(1).replace(".", " point "), text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+\.\d+\.\d+)\b", lambda m: m.group(1).replace(".", " point "), text)

    # Units and rates
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*kbps\b", r"\1 kilobits per second", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*mbps\b", r"\1 megabits per second", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*gbps\b", r"\1 gigabits per second", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*tbps\b", r"\1 terabits per second", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*kb\b", r"\1 kilobytes", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*mb\b", r"\1 megabytes", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*gb\b", r"\1 gigabytes", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*tb\b", r"\1 terabytes", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*mhz\b", r"\1 mega hertz", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*ghz\b", r"\1 giga hertz", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*ms\b", r"\1 milliseconds", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*us\b", r"\1 microseconds", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*ns\b", r"\1 nanoseconds", text, flags=re.IGNORECASE)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*s\b", r"\1 seconds", text)
    text = re.sub(r"\b(\d+(?:\.\d+)?)\s*min\b", r"\1 minutes", text, flags=re.IGNORECASE)

    # Optional plural markers like domain(s) -> "domain or domains"
    text = re.sub(r"\b([A-Za-z]+)\(s\)(?!\w)", r"\1 or \1s", text)

    return text
415
+
416
  # Main preprocessing pipeline
417
  def preprocess_all(string):
418
  string = normalize_dates(string)
419
  string = replace_invalid_chars(string)
420
  string = replace_numbers(string)
421
+ string = tech_humanize(string)
422
  string = replace_abbreviations(string)
423
  string = make_dots_tts_friendly(string)
424
  string = clean_whitespace(string)
 
441
  test_preprocessing(test_file)
442
  else:
443
  print("Please provide a file path as an argument.")