staraks committed
Commit 8174d48 · verified · 1 Parent(s): 77f67b0

Update app.py

Files changed (1)
  1. app.py +18 -124
app.py CHANGED
@@ -1,20 +1,6 @@
  # app.py
- # Whisper Transcriber — improved with:
- # - Memory ZIP preview + selective import
- # - faster-whisper auto-detect + multiprocessing batch transcription
- # - progress streaming generator + per-file transcript ZIP
- #
- # Requirements:
- # - gradio (3.x)
- # - python-docx
- # - pydub
- # - pyzipper
- # - ffmpeg installed on system
- # - whisper OR faster-whisper (faster-whisper recommended for CPU speed)
- #
- # Notes:
- # - Multiprocessing may load model per worker (memory heavy). Tune MAX_WORKERS.
- # - This file is designed to be dropped into your existing project and run.
+ # Whisper Transcriber — Parallel + Memory preview (fixed)
+ # (Same requirements as before: gradio 3.x, pydub, pyzipper, python-docx, ffmpeg, whisper or faster-whisper)

  import os
  import sys
@@ -33,36 +19,30 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
  import multiprocessing
  import time

- # Gradio & model libs
  try:
      import gradio as gr
  except Exception as e:
      print("FATAL: gradio import failed:", e)
      raise

- # Try to import faster_whisper first, fallback to openai/whisper
  USE_FASTER_WHISPER = False
  try:
      from faster_whisper import WhisperModel as FasterWhisperModel
-
      USE_FASTER_WHISPER = True
      print("INFO: faster-whisper available — will use it for faster CPU inference.")
  except Exception:
      try:
-         import whisper # fallback
+         import whisper
      except Exception:
          print("FATAL: Neither faster-whisper nor whisper available.")
          raise

- # Audio processing
  from pydub import AudioSegment
  import pyzipper
  from docx import Document

- # Force unbuffered prints for logs
  os.environ["PYTHONUNBUFFERED"] = "1"

- # ---------- Config ----------
  MEMORY_FILE = "memory.json"
  MEMORY_LOCK = threading.Lock()
  MIN_WAV_SIZE = 1024
@@ -73,16 +53,13 @@ FFMPEG_CANDIDATES = [
      ("pcm_s16le", 44100, 2),
      ("mulaw", 8000, 1),
  ]
- MODEL_CACHE = {} # name -> model instance (only for main process, workers load models separately)
- EXTRACT_MAP = {} # friendly_name -> absolute path (per-run)
+ MODEL_CACHE = {}
+ EXTRACT_MAP = {}
  DEFAULT_ZIP_PASS = "dietcoke1"

- # Multiprocessing tuning: set a sensible default
  CPU_COUNT = max(1, multiprocessing.cpu_count())
- MAX_WORKERS = min(4, CPU_COUNT) # adjust as needed; each worker loads a model (memory heavy)
+ MAX_WORKERS = min(4, CPU_COUNT)
-

- # ---------- Memory helpers ----------
  def load_memory():
      try:
          if os.path.exists(MEMORY_FILE):
@@ -103,7 +80,6 @@ def load_memory():
          pass
      return mem

-
  def save_memory(mem):
      with MEMORY_LOCK:
          try:
@@ -112,11 +88,8 @@ def save_memory(mem):
          except Exception:
              traceback.print_exc()

-
  memory = load_memory()

-
- # ---------- Postprocessing ----------
  MEDICAL_ABBREVIATIONS = {
      "pt": "patient",
      "dx": "diagnosis",
@@ -136,7 +109,6 @@ DRUG_NORMALIZATION = {
      "amoxicillin": "Amoxicillin",
  }

-
  def expand_abbreviations(text):
      tokens = re.split(r"(\s+)", text)
      out = []
@@ -152,13 +124,11 @@ def expand_abbreviations(text):
          out.append(t)
      return "".join(out)

-
  def normalize_drugs(text):
      for k, v in DRUG_NORMALIZATION.items():
          text = re.sub(rf"\b{k}\b", v, text, flags=re.IGNORECASE)
      return text

-
  def punctuation_and_capitalization(text):
      text = text.strip()
      if not text:
@@ -174,7 +144,6 @@ def punctuation_and_capitalization(text):
          out.append(p)
      return "".join(out)

-
  def postprocess_transcript(text):
      if not text:
          return text
@@ -184,13 +153,11 @@ def postprocess_transcript(text):
      t = punctuation_and_capitalization(t)
      return t

-
  def extract_words_and_phrases(text):
      words = re.findall(r"[A-Za-z0-9\-']+", text)
      sentences = [s.strip() for s in re.split(r"(?<=[.?!])\s+", text) if s.strip()]
      return [w for w in words if w.strip()], sentences

-
  def update_memory_with_transcript(transcript):
      global memory
      words, sentences = extract_words_and_phrases(transcript)
@@ -206,7 +173,6 @@ def update_memory_with_transcript(transcript):
      if changed:
          save_memory(memory)

-
  def memory_correct_text(text, min_ratio=0.85):
      if not text or (not memory.get("words") and not memory.get("phrases")):
          return text
@@ -240,8 +206,6 @@ def memory_correct_text(text, min_ratio=0.85):
          corrected = re.sub(re.escape(phrase), phrase, corrected, flags=re.IGNORECASE)
      return corrected

-
- # ---------- File & conversion utilities ----------
  def save_as_word(text, filename=None):
      if filename is None:
          filename = os.path.join(tempfile.gettempdir(), "merged_transcripts.docx")
@@ -250,7 +214,6 @@ def save_as_word(text, filename=None):
      doc.save(filename)
      return filename

-
  def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
      try:
          cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
@@ -277,7 +240,6 @@ def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
              pass
          return False, str(e)

-
  def convert_to_wav_if_needed(input_path):
      input_path = str(input_path)
      lower = input_path.lower()
@@ -305,7 +267,6 @@ def convert_to_wav_if_needed(input_path):
      except Exception:
          pass

-     # ffmpeg fallback
      diag_dir = tempfile.mkdtemp(prefix="dct_diag_")
      diag_log = os.path.join(diag_dir, "conversion_diagnostics.txt")
      diagnostics = []
@@ -361,12 +322,9 @@ def convert_to_wav_if_needed(input_path):

      raise Exception(f"Could not convert file to WAV. Diagnostics saved to: {diag_log}")

-
- # ---------- Model utils (main process; workers load locally inside worker fn) ----------
  def whisper_available_models():
      try:
          if USE_FASTER_WHISPER:
-             # faster-whisper doesn't provide available_models; trust common names
              return set(["tiny", "base", "small", "medium", "large", "large-v3"])
          else:
              models = whisper.available_models()
@@ -376,10 +334,8 @@ def whisper_available_models():
          pass
      return set(["tiny", "base", "small", "medium", "large", "large-v3"])

-
  AVAILABLE_MODEL_SET = whisper_available_models()

-
  def safe_model_choices(prefer_default="small"):
      base_choices = ["small", "medium", "large", "large-v3", "base", "tiny"]
      choices = [m for m in base_choices if m in AVAILABLE_MODEL_SET]
@@ -388,27 +344,17 @@ def safe_model_choices(prefer_default="small"):
      default = prefer_default if prefer_default in choices else choices[0]
      return choices, default

-
- # Worker transcribe function (runs in worker process)
  def _worker_transcribe(args):
-     """
-     This function is invoked inside a worker process.
-     args: (file_path, model_name, device_name, enable_memory, generate_srt, use_two_pass, fast_model, refine_threshold)
-     Returns: dict{ 'file': basename, 'text_path': path, 'srt_path': path or None, 'log': str }
-     """
      try:
          (file_path, model_name, device_name, enable_memory, generate_srt, use_two_pass, fast_model, refine_threshold) = args
          base = os.path.basename(file_path)
          log_lines = []
          device = None if device_name == "auto" else device_name

-         # Load model inside worker (no sharing)
          model = None
          use_fw = False
          try:
              if USE_FASTER_WHISPER:
-                 # faster-whisper uses WhisperModel(model_size_or_path, device=..., compute_type=...)
-                 # Use default compute_type; user can customize code if desired
                  model = FasterWhisperModel(model_name, device=device if device else "cpu")
                  use_fw = True
                  log_lines.append(f"Worker: faster-whisper loaded {model_name}")
@@ -419,7 +365,6 @@ def _worker_transcribe(args):
                  log_lines.append(f"Worker: whisper loaded {model_name}")
          except Exception as e:
              log_lines.append(f"Worker model load failed: {e}")
-             # attempt fallback to small
              try:
                  if USE_FASTER_WHISPER:
                      model = FasterWhisperModel("small", device=device if device else "cpu")
@@ -432,24 +377,19 @@ def _worker_transcribe(args):
              except Exception as e2:
                  return {"file": base, "text_path": None, "srt_path": None, "log": "Model load failed: " + str(e2)}

-         # Convert to WAV
          try:
              wav = convert_to_wav_if_needed(file_path)
              log_lines.append(f"Converted to WAV: {os.path.basename(wav)}")
          except Exception as e:
              return {"file": base, "text_path": None, "srt_path": None, "log": "Conversion failed: " + str(e)}

-         # Transcribe — two modes: faster-whisper usage differs
          try:
              if use_fw:
-                 # faster-whisper returns (segments, info) via transcribe(..., beam_size=..., vad_filter=False)
                  segments, info = model.transcribe(wav, beam_size=5)
                  text = "".join([seg.text for seg in segments]).strip()
-                 # segments: objects with start/end/text
                  if generate_srt:
                      srt_text = []
                      for i, seg in enumerate(segments, start=1):
-                         # seg.start, seg.end, seg.text
                          start = getattr(seg, "start", 0)
                          end = getattr(seg, "end", 0)
                          txt = getattr(seg, "text", "").strip()
@@ -467,14 +407,11 @@ def _worker_transcribe(args):
          except Exception as e:
              return {"file": base, "text_path": None, "srt_path": None, "log": "Transcription failed: " + str(e)}

-         # Apply memory correction if requested
          if enable_memory and text:
              text = memory_correct_text(text)

-         # Postprocess
          text = postprocess_transcript(text)

-         # Write text and srt to temp files
          txt_tmp = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
          txt_tmp.close()
          with open(txt_tmp.name, "w", encoding="utf-8") as fh:
@@ -488,7 +425,6 @@ def _worker_transcribe(args):
                  fh.write(srt_out)
              srt_path = srt_tmp.name

-         # Clean up WAV if created
          try:
              if wav and os.path.exists(wav) and not file_path.lower().endswith(".wav"):
                  os.unlink(wav)
@@ -500,8 +436,6 @@ def _worker_transcribe(args):
          tb = traceback.format_exc()
          return {"file": os.path.basename(file_path) if file_path else "unknown", "text_path": None, "srt_path": None, "log": f"Worker exception: {e}\n{tb}"}

-
- # small helpers used by worker
  def _fmt_time(t):
      h = int(t // 3600)
      m = int((t % 3600) // 60)
@@ -509,7 +443,6 @@ def _fmt_time(t):
      ms = int((t - int(t)) * 1000)
      return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

-
  def _segments_to_srt(segments):
      lines = []
      for i, seg in enumerate(segments, start=1):
@@ -522,8 +455,6 @@ def _segments_to_srt(segments):
          lines.append("")
      return "\n".join(lines)

-
- # ---------- ZIP extraction (friendly mapping) ----------
  def extract_zip_and_map(zip_path, zip_password=None):
      global EXTRACT_MAP
      EXTRACT_MAP = {}
@@ -582,8 +513,6 @@ def extract_zip_and_map(zip_path, zip_password=None):
              pass
          return [], f"Extraction failed: {e}"

-
- # ---------- Batch transcription with parallel workers & streaming progress ----------
  def batch_transcribe_parallel_generator(
      friendly_selected,
      uploaded_files,
@@ -597,15 +526,10 @@ def batch_transcribe_parallel_generator(
      refine_threshold=-1.0,
      zip_password=None,
  ):
-     """
-     Generator that yields (logs_text, combined_text, zip_path_or_None, percent_int)
-     It runs multiple workers in parallel and yields progress updates as files complete.
-     """
      logs = []
      transcripts = []
-     per_file_paths = [] # list of (basename, text_tmp, srt_tmp)
+     per_file_paths = []
      try:
-         # Build paths list
          paths = []
          if friendly_selected:
              for key in friendly_selected:
@@ -629,20 +553,16 @@ def batch_transcribe_parallel_generator(
          logs.append(f"Starting batch of {total} files with up to {MAX_WORKERS} workers.")
          yield "\n\n".join(logs), "", None, 2

-         # Prepare task args
          tasks = []
          for p in paths:
              tasks.append((p, model_name, device_name, enable_mem, generate_srt, use_two_pass, fast_model, refine_threshold))

-         # Run in ProcessPoolExecutor
          completed = 0
-         results = []
          with ProcessPoolExecutor(max_workers=min(MAX_WORKERS, total)) as exe:
              futs = {exe.submit(_worker_transcribe, t): t for t in tasks}
              for fut in as_completed(futs):
                  res = fut.result()
                  completed += 1
-                 # res has keys: file, text_path, srt_path, log
                  fname = res.get("file")
                  res_log = res.get("log", "")
                  logs.append(f"[{completed}/{total}] {fname}: {res_log}")
@@ -657,11 +577,9 @@ def batch_transcribe_parallel_generator(
                          txt_content = fh.read()
                      transcripts.append(f"FILE: {fname}\n{txt_content}\n")
                      per_file_paths.append((fname, txtp, srtp))
-                 # progress update
                  pct = int(5 + (completed / total) * 90)
                  yield "\n\n".join(logs), "\n\n".join(transcripts), None, pct

-         # combine and optionally merge into DOCX and per-file zip
          combined = "\n\n".join(transcripts)
          out_doc = None
          if merge_flag:
@@ -671,7 +589,6 @@ def batch_transcribe_parallel_generator(
              except Exception as e:
                  logs.append(f"Merge failed: {e}")

-         # Create ZIP with per-file transcripts (txt + srt if available)
          if per_file_paths:
              zip_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
              zip_tmp.close()
@@ -681,7 +598,6 @@ def batch_transcribe_parallel_generator(
                      try:
                          zf.write(txtp, arcname=arc_txt)
                      except Exception:
-                         # fallback: write name-safe
                          zf.write(txtp, arcname=os.path.basename(txtp))
                      if srtp and os.path.exists(srtp):
                          arc_srt = f"{fname}.srt"
@@ -698,9 +614,6 @@ def batch_transcribe_parallel_generator(
          logs.append(f"Batch error: {e}\n{tb}")
          yield "\n\n".join(logs), "\n\n".join(transcripts), None, 100

-
- # ---------- Robust multi-file memory importer with preview ----------
-
  def _read_file_text_try_encodings(path):
      encodings = ["utf-8", "utf-16", "latin-1"]
      for enc in encodings:
@@ -723,7 +636,6 @@ def _read_file_text_try_encodings(path):
      except Exception:
          return None, None

-
  def _process_single_memory_text(text):
      added = 0
      try:
@@ -770,9 +682,7 @@ def _process_single_memory_text(text):
              added += 1
      return added

-
  def preview_zip_members_for_memory(zip_path):
-     """Return list of text-like members and a log string"""
      members = []
      logs = []
      try:
@@ -782,11 +692,9 @@ def preview_zip_members_for_memory(zip_path):
                  continue
              name = info.filename
              _, ext = os.path.splitext(name)
-             # consider likely text files
              if ext.lower() in [".txt", ".json", ".csv", ".list", ".md"]:
                  members.append(name)
              else:
-                 # also include others but mark as maybe-binary
                  members.append(name)
          if not members:
              logs.append("No members found in ZIP.")
@@ -796,15 +704,10 @@ def preview_zip_members_for_memory(zip_path):
          logs.append(f"ZIP preview failed: {e}")
      return members, "\n".join(logs)

-
  def import_memory_files_multiple(uploaded_files, zip_members_to_import=None):
-     """
-     Accept list of file paths (or single), or ZIP + selected ZIP members list.
-     """
      if not uploaded_files:
          return "No files provided."

-     # normalize uploaded_files
      if isinstance(uploaded_files, (str, os.PathLike)):
          uploaded_files = [str(uploaded_files)]
      elif isinstance(uploaded_files, dict) and uploaded_files.get("name"):
@@ -830,7 +733,6 @@ def import_memory_files_multiple(uploaded_files, zip_members_to_import=None):
              messages.append(f"Missing: {fp}")
              continue
          if fp.lower().endswith(".zip"):
-             # if zip_members_to_import is provided, only import those
              try:
                  with zipfile.ZipFile(fp, "r") as zf:
                      for info in zf.infolist():
@@ -860,7 +762,6 @@ def import_memory_files_multiple(uploaded_files, zip_members_to_import=None):
              except zipfile.BadZipFile:
                  skipped.append(f"Bad zip: {fp}")
                  continue
-         # otherwise plain file
          text, used_enc = _read_file_text_try_encodings(fp)
          if text is None:
              skipped.append(fp)
@@ -881,8 +782,6 @@ def import_memory_files_multiple(uploaded_files, zip_members_to_import=None):
      summary.extend(skipped)
      return "\n".join(summary)

-
- # ---------- Build Gradio UI ----------
  print("DEBUG: building Gradio UI", flush=True)
  available_choices, default_choice = safe_model_choices(prefer_default="small")

@@ -914,7 +813,6 @@ body { background: var(--bg); color: var(--text); font-family: Inter, system-ui,
  """

  with gr.Blocks(title="Whisper Transcriber — Parallel + Memory preview", css=CSS) as demo:
-     # Theme init: dark by default
      gr.HTML("""
      <script>
      (function() {
@@ -934,7 +832,6 @@ with gr.Blocks(title="Whisper Transcriber — Parallel + Memory preview", css=CS
      gr.Markdown("<div class='small-note'>Preview ZIP members for memory import, parallel batch transcription, faster-whisper auto-detect, per-file transcript downloads</div>")

      with gr.Tabs():
-         # Single tab (keeps simple)
          with gr.TabItem("Single File"):
              with gr.Row():
                  with gr.Column(scale=1):
@@ -947,21 +844,25 @@ with gr.Blocks(title="Whisper Transcriber — Parallel + Memory preview", css=CS
                  with gr.Column(scale=1):
                      single_trans_out = gr.Textbox(label="Transcript", lines=14, interactive=False)
                      single_logs = gr.Textbox(label="Logs", lines=8, interactive=False)
+
              def _do_single(audio, model_name, device_name, mem_on, srt_on):
                  if not audio:
                      return "", "No audio supplied."
                  path = audio if isinstance(audio, str) else (audio.name if hasattr(audio, "name") else str(audio))
-                 txt, srtp, lg = _worker_transcribe((path, model_name, device_name, mem_on, srt_on, False, "small", -1.0))
-                 # read back text file if present
-                 if txt.get("text_path"):
-                     with open(txt["text_path"], "r", encoding="utf-8", errors="replace") as fh:
-                         content = fh.read()
+                 res = _worker_transcribe((path, model_name, device_name, mem_on, srt_on, False, "small", -1.0))
+                 if res.get("text_path"):
+                     try:
+                         with open(res["text_path"], "r", encoding="utf-8", errors="replace") as fh:
+                             content = fh.read()
+                     except Exception:
+                         content = ""
                  else:
                      content = ""
-                 return content, txt.get("log", lg)
+                 logs = res.get("log", "")
+                 return content, logs
+
              trans_single_btn.click(fn=_do_single, inputs=[single_audio, model_sel_single, device_single, mem_single, srt_single], outputs=[single_trans_out, single_logs])

-         # Batch tab
          with gr.TabItem("Batch Transcribe"):
              with gr.Row():
                  with gr.Column(scale=1):
@@ -998,9 +899,7 @@ with gr.Blocks(title="Whisper Transcriber — Parallel + Memory preview", css=CS

              batch_preview_btn.click(fn=_preview_zip, inputs=[batch_zip, batch_zip_pass], outputs=[batch_preview_out])

-             # Bind the generator for parallel batch with progress slider
              def _start_batch(friendly_selected, uploaded_files, zip_file, zip_pass, model_name, device_name, merge_flag, mem_flag, srt_flag, use_two_pass, fast_model, refine_thresh):
-                 # Normalize uploaded_files into list of paths (gradio provides list of dicts or strings)
                  up = uploaded_files
                  if isinstance(up, dict) and up.get("name"):
                      up = [up["name"]]
@@ -1022,10 +921,8 @@ with gr.Blocks(title="Whisper Transcriber — Parallel + Memory preview", css=CS
                  fn=_start_batch,
                  inputs=[batch_select, batch_files, batch_zip, batch_zip_pass, batch_model, batch_device, batch_merge, batch_mem, batch_srt, batch_use_two_pass, batch_fast_model, batch_refine_thresh],
                  outputs=[batch_logs_out, batch_combined_out, batch_zip_download, batch_progress],
-                 _js=None, # compatibility
              )

-         # Memory tab with preview-import flow
          with gr.TabItem("Memory"):
              with gr.Row():
                  with gr.Column(scale=1):
@@ -1049,7 +946,6 @@ with gr.Blocks(title="Whisper Transcriber — Parallel + Memory preview", css=CS
              def _preview_many_zip(uploaded):
                  if not uploaded:
                      return "No files."
-                 # find first zip among uploaded and preview it
                  if isinstance(uploaded, dict) and uploaded.get("name"):
                      uploaded = [uploaded["name"]]
                  members_total = []
@@ -1110,14 +1006,12 @@ with gr.Blocks(title="Whisper Transcriber — Parallel + Memory preview", css=CS
              mem_clear_btn.click(fn=_clear_mem, inputs=[], outputs=[mem_status])
              mem_view_btn.click(fn=_view_mem, inputs=[], outputs=[mem_status])

-         # Settings
          with gr.TabItem("Settings"):
              gr.Markdown("### Settings & tips")
              gr.Markdown(f"- Faster-whisper auto-detected: {USE_FASTER_WHISPER}")
              gr.Markdown(f"- Max workers for parallel transcribe: {MAX_WORKERS}")
              gr.Markdown("- If memory or RAM is limited, set MAX_WORKERS lower in code.")

- # ---------- Launch ----------
  if __name__ == "__main__":
      port = int(os.environ.get("PORT", 7860))
      print("DEBUG: launching on port", port)
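
For context on the main fix: _worker_transcribe returns a result dict with the keys "file", "text_path", "srt_path", and "log", but the old _do_single handler unpacked that dict as a three-tuple and read the transcript file without guarding against read errors. A minimal sketch of the corrected consumption pattern outside Gradio (the helper name and the sample payload here are illustrative, not part of the commit):

    # Sketch: consume a _worker_transcribe-style result dict.
    # The keys mirror those in the diff above; the payload below is made up.
    def read_worker_result(result):
        """Return (transcript_text, log) from a worker result dict."""
        content = ""
        text_path = result.get("text_path")
        if text_path:
            try:
                with open(text_path, "r", encoding="utf-8", errors="replace") as fh:
                    content = fh.read()
            except Exception:
                content = ""  # same fallback as the fixed handler: empty transcript on read failure
        return content, result.get("log", "")

    # A failed run carries no text_path; only the log survives.
    print(read_worker_result({"file": "a.mp3", "text_path": None, "srt_path": None, "log": "Conversion failed"}))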
 