staraks committed on
Commit
2234d16
·
verified ·
1 Parent(s): 847997b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +259 -143
app.py CHANGED
@@ -1,10 +1,11 @@
1
  # app.py
2
  # Whisper Transcriber — Gradio 3.x compatible full file
3
- # Features added: chunk size control, experimental parallel chunk transcription (CPU-only),
4
- # streaming progress bar (no audio preview), memory corrections, ZIP extraction, theme toggle.
 
 
5
  #
6
  # Requirements: gradio (3.x), whisper, pydub, pyzipper, python-docx, ffmpeg installed.
7
- # Experimental parallel mode uses multiprocessing and loads the 'fast' model in each worker.
8
 
9
  import os
10
  import sys
@@ -19,6 +20,10 @@ from difflib import get_close_matches
19
  from uuid import uuid4
20
  from pathlib import Path
21
  from multiprocessing import get_context
 
 
 
 
22
  from typing import Tuple, List
23
 
24
  # Force unbuffered prints for logs
@@ -52,9 +57,12 @@ FFMPEG_CANDIDATES = [
52
  MODEL_CACHE = {}
53
  EXTRACT_MAP = {} # friendly_name -> absolute path
54
 
 
 
 
 
55
  # ---------- Worker-global for multiprocessing ----------
56
- # These are defined for worker processes (initialized via initializer)
57
- WORKER_MODEL = None # type: ignore
58
 
59
  def worker_init(model_name: str, device: str):
60
  """
@@ -503,7 +511,7 @@ def trim_audio_segment(src_path, start_sec, end_sec):
503
  pass
504
  raise
505
 
506
- # ---------- Core transcription (single file) ----------
507
  def transcribe_single_file(
508
  path,
509
  model_name="small",
@@ -515,7 +523,6 @@ def transcribe_single_file(
515
  refine_model=None,
516
  refine_threshold=-1.0,
517
  ):
518
- # non-streaming convenience helper used for batch mode
519
  logs = []
520
  try:
521
  if not path:
@@ -554,64 +561,185 @@ def transcribe_single_file(
554
  pass
555
  return text, srt_path, "\n".join(logs)
556
 
557
- # Two-pass path not used for streaming generator here
558
- return "", None, "Two-pass not used in this helper."
559
  except Exception as e:
560
  tb = traceback.format_exc()
561
  return "", None, f"Transcription error: {e}\n{tb}"
562
 
563
- # ---------- Batch transcribe (unchanged) ----------
564
- def batch_transcribe(friendly_selected, uploaded_files, model_name, device_name, merge_flag, enable_mem, generate_srt, use_two_pass=False, fast_model="small", refine_threshold=-1.0):
565
- logs = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  transcripts = []
 
567
  srt_files = []
568
- out_doc = None
569
  paths = []
 
570
  if friendly_selected:
571
  for key in friendly_selected:
572
  p = EXTRACT_MAP.get(key)
573
  if p:
574
  paths.append(p)
575
  else:
576
- logs.append(f"Warning: selected not found in extract map: {key}")
 
577
  if uploaded_files:
578
  if isinstance(uploaded_files, (list, tuple)):
579
  for f in uploaded_files:
580
  paths.append(str(f))
581
  else:
582
  paths.append(str(uploaded_files))
 
583
  if not paths:
584
- return "", "No files selected or uploaded.", None, None
 
 
585
 
586
  total = len(paths)
 
 
 
587
  for idx, p in enumerate(paths, start=1):
588
- logs.append(f"[{idx}/{total}] Processing: {p}")
589
- text, srt_path, lg = transcribe_single_file(
590
- p,
591
- model_name=model_name,
592
- device_choice=device_name,
593
- enable_memory=enable_mem,
594
- generate_srt=generate_srt,
595
- use_two_pass=use_two_pass,
596
- fast_model=fast_model,
597
- refine_model=model_name,
598
- refine_threshold=refine_threshold,
599
- )
600
- logs.append(lg)
601
- transcripts.append(f"FILE: {os.path.basename(p)}\n{text}\n")
602
- if srt_path:
603
- srt_files.append(srt_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
  combined = "\n\n".join(transcripts)
 
605
  if merge_flag:
606
  try:
607
- out_doc = save_as_word(combined)
608
- logs.append(f"Merged saved: {out_doc}")
609
  except Exception as e:
610
- logs.append(f"Merge failed: {e}")
611
- srt_return = srt_files[0] if srt_files else None
612
- return combined, "\n".join(logs), out_doc, srt_return
 
 
 
 
 
 
 
 
 
 
 
 
613
 
614
- # ---------- Build Gradio UI (3.x compatible) ----------
615
  print("DEBUG: building Gradio UI", flush=True)
616
  available_choices, default_choice = safe_model_choices(prefer_default="small")
617
 
@@ -643,7 +771,7 @@ body { background: var(--bg); color: var(--text); font-family: Inter, system-ui,
643
  """
644
 
645
  with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
646
- # Theme initializer + toggle injected via HTML (works across gradio versions)
647
  gr.HTML("""
648
  <script>
649
  (function() {
@@ -653,31 +781,22 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
653
  var chosen = null;
654
  if (saved === 'dark' || saved === 'light') {
655
  chosen = saved;
656
- } else if (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches) {
657
- chosen = 'dark';
658
  } else {
659
- chosen = 'light';
 
660
  }
661
  document.documentElement.setAttribute('data-theme', chosen);
662
- try {
663
- var style = document.createElement('style');
664
- style.innerHTML = `
665
- :root, [data-theme="dark"] { transition: background-color 260ms ease, color 260ms ease; }
666
- `;
667
- document.head.appendChild(style);
668
- } catch(e){}
669
  } catch (e) { console.warn('theme init failed', e); }
670
  })();
671
  </script>
672
  """)
673
 
674
- # Header
675
  with gr.Row():
676
  with gr.Column(scale=0):
677
  gr.HTML("<div style='width:50px;height:50px;border-radius:10px;background:linear-gradient(135deg,#4f46e5,#06b6d4);display:flex;align-items:center;justify-content:center;color:white;font-weight:700;font-size:20px;'>WT</div>")
678
  with gr.Column():
679
  gr.Markdown("<h3 style='margin:0'>Whisper Transcriber (Gradio 3.x)</h3>")
680
- gr.Markdown("<div class='small-note'>Chunked streaming, experimental CPU parallel, per-run ZIP extraction, memory corrections, SRT export, dark/light toggle</div>")
681
 
682
  with gr.Tabs():
683
  # Single audio
@@ -690,15 +809,12 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
690
  device_choice = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
691
  mem_toggle = gr.Checkbox(label="Enable memory corrections", value=False)
692
  srt_toggle = gr.Checkbox(label="Generate SRT", value=False)
693
- # chunk controls
694
- chunk_controls_row = gr.Row(visible=True)
695
  chunk_size_input = gr.Number(value=30, label="Chunk size (seconds)", precision=0)
696
  enable_chunking = gr.Checkbox(label="Enable chunking (recommended for long files)", value=True)
697
- # parallel experimental
698
  parallel_checkbox = gr.Checkbox(label="Enable experimental parallel chunk transcription (CPU only)", value=False)
699
  parallel_workers = gr.Slider(minimum=1, maximum=max(1, os.cpu_count() or 4), value=2, step=1, label="Parallel workers (processes)")
700
  use_two_pass_single = gr.Checkbox(label="Use two-pass speedup (fast then refine)", value=False)
701
- fast_model_choice = gr.Dropdown(choices=[c for c in ["tiny", "base", "small"] if c in AVAILABLE_MODEL_SET], value="small", label="Fast model (for two-pass / workers)")
702
  refine_threshold_single = gr.Number(value=-1.0, label="Refine threshold (avg_logprob)", precision=2)
703
  transcribe_btn = gr.Button("Transcribe", variant="primary")
704
  with gr.Column(scale=1):
@@ -708,7 +824,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
708
  srt_download = gr.File(label="SRT (if generated)")
709
  single_logs = gr.Textbox(label="Logs", lines=8, interactive=False)
710
 
711
- # streaming generator with optional multiprocessing
712
  def _single_generator(audio_file, model_name, device, mem_on, srt_on, chunk_size_sec, chunking_enabled, parallel_enabled, workers, use_two_pass_flag, fast_model, refine_thresh):
713
  yield 0, "", None, "Starting..."
714
  try:
@@ -722,7 +837,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
722
  wav = convert_to_wav_if_needed(path)
723
  yield 8, "", None, f"Converted to WAV: {os.path.basename(wav)}"
724
 
725
- # determine duration
726
  duration = None
727
  try:
728
  p = subprocess.run(["ffprobe","-v","error","-show_entries","format=duration","-of","default=noprint_wrappers=1:nokey=1", wav], capture_output=True, text=True, timeout=8)
@@ -737,7 +851,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
737
  except Exception:
738
  duration = None
739
 
740
- # build chunk ranges
741
  if chunking_enabled and (duration and duration > chunk_size_sec * 1.5):
742
  num_chunks = max(1, int((duration + chunk_size_sec - 1) // chunk_size_sec))
743
  chunk_ranges = []
@@ -752,28 +865,23 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
752
 
753
  yield 10, "", None, f"Preparing transcription ({len(chunk_ranges)} chunk(s))..."
754
 
755
- # Load model in main process (for serial or orchestration)
756
  model = get_whisper_model(model_name, device=None if device == "auto" else device)
757
  yield 15, "", None, f"Model loaded: {model_name}"
758
 
759
  overall_parts = []
760
  total_chunks = len(chunk_ranges)
761
 
762
- # Decide whether we can/should run parallel workers
763
  parallel_used = False
764
  if parallel_enabled and chunking_enabled and total_chunks > 1:
765
  if device != "cpu" and device != "auto":
766
- # Most likely GPU requested; parallel across multiple processes with GPU not recommended
767
  yield 15, "", None, "Parallel mode requested but device is not 'cpu'. Falling back to serial chunking."
768
  parallel_used = False
769
  else:
770
- # attempt to spawn a multiprocessing pool that initializes each worker with fast_model on CPU
771
  try:
772
  ctx = get_context("spawn")
773
  worker_count = max(1, int(workers))
774
  yield 18, "", None, f"Starting parallel pool with {worker_count} workers (fast_model={fast_model})..."
775
  pool = ctx.Pool(processes=worker_count, initializer=worker_init, initargs=(fast_model, "cpu"))
776
- # prepare chunk WAVs
777
  chunk_paths = []
778
  temp_chunk_files = []
779
  for (st, ed) in chunk_ranges:
@@ -783,11 +891,9 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
783
  cw = trim_audio_segment(wav, st, ed)
784
  chunk_paths.append(cw)
785
  temp_chunk_files.append(cw)
786
- # map transcribe jobs
787
  results = pool.map(worker_transcribe_chunk, chunk_paths)
788
  pool.close()
789
  pool.join()
790
- # process results in order
791
  for idx, (txt, err) in enumerate(results, start=1):
792
  if err:
793
  yield int(20 + idx * 70 / max(1, total_chunks)), "\n\n".join(overall_parts), None, f"Chunk {idx} worker error: {err}"
@@ -798,7 +904,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
798
  overall_parts.append(txt)
799
  prog = int(20 + idx * 70 / max(1, total_chunks))
800
  yield prog, "\n\n".join(overall_parts), None, f"Completed chunk {idx}/{total_chunks} (parallel)."
801
- # cleanup temp chunks (but not original wav)
802
  for tfile in temp_chunk_files:
803
  try:
804
  if os.path.exists(tfile):
@@ -811,7 +916,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
811
  parallel_used = False
812
 
813
  if not parallel_used:
814
- # serial chunk processing
815
  for idx, (st, ed) in enumerate(chunk_ranges, start=1):
816
  try:
817
  if ed is None:
@@ -823,9 +927,7 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
823
 
824
  yield int(15 + (idx - 1) * 70 / max(1, total_chunks)), "", None, f"Transcribing chunk {idx}/{total_chunks} ({note})..."
825
 
826
- # call model.transcribe on chunk
827
- whisper_opts = {}
828
- result = model.transcribe(chunk_wav, **whisper_opts)
829
  chunk_text = result.get("text", "").strip()
830
 
831
  if mem_on:
@@ -845,7 +947,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
845
  except Exception as e:
846
  yield int(15 + idx * 70 / max(1, total_chunks)), "\n\n".join(overall_parts), None, f"Chunk {idx} failed: {e}\n{traceback.format_exc()}"
847
 
848
- # final assembly
849
  final_text = "\n\n".join([p for p in overall_parts if p])
850
  if mem_on:
851
  try:
@@ -853,7 +954,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
853
  except Exception:
854
  pass
855
 
856
- # SRT generation best-effort (runs a full transcribe to get segments)
857
  srt_path = None
858
  if srt_on:
859
  try:
@@ -869,7 +969,6 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
869
 
870
  yield 98, final_text, srt_path, "Transcription complete."
871
 
872
- # cleanup tmp wav if created
873
  try:
874
  if os.path.exists(wav) and not path.lower().endswith(".wav"):
875
  os.unlink(wav)
@@ -887,7 +986,7 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
887
  outputs=[progress_num, transcript_out, srt_download, single_logs],
888
  )
889
 
890
- # Batch tab (unchanged UI and behavior)
891
  with gr.TabItem("Batch Transcribe"):
892
  with gr.Row():
893
  with gr.Column(scale=1):
@@ -913,6 +1012,8 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
913
  batch_logs = gr.Textbox(label="Logs", lines=10, interactive=False)
914
  batch_doc_download = gr.File(label="Merged DOCX (if created)")
915
  batch_srt_download = gr.File(label="First SRT (if any)")
 
 
916
 
917
  def _do_extract(zip_file, password):
918
  if not zip_file:
@@ -923,33 +1024,19 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
923
 
924
  batch_extract_btn.click(fn=_do_extract, inputs=[batch_zip, zip_password], outputs=[batch_select, batch_extract_logs])
925
 
926
- def _do_batch(friendly_selected, uploaded_files, model_name, device, merge_flag, mem_flag, srt_flag, use_two_pass_flag, fast_model, refine_thresh):
927
- combined, logs, out_doc, srt_path = batch_transcribe(
928
- friendly_selected,
929
- uploaded_files,
930
- model_name,
931
- device,
932
- merge_flag,
933
- mem_flag,
934
- srt_flag,
935
- use_two_pass=use_two_pass_flag,
936
- fast_model=fast_model,
937
- refine_threshold=refine_thresh,
938
- )
939
- return combined, logs, out_doc, srt_path
940
-
941
  batch_run_btn.click(
942
- fn=_do_batch,
943
  inputs=[batch_select, batch_files, batch_model, batch_device, batch_merge, batch_mem, batch_srt, batch_use_two_pass, batch_fast_model, batch_refine_threshold],
944
- outputs=[batch_trans_out, batch_logs, batch_doc_download, batch_srt_download],
945
  )
946
 
947
- # Memory tab (unchanged)
948
  with gr.TabItem("Memory"):
949
  with gr.Row():
950
  with gr.Column(scale=1):
951
  gr.Markdown("### Correction Memory")
952
- mem_upload = gr.File(label="Import memory file (JSON or text)", file_count="single", type="filepath")
 
953
  mem_import_btn = gr.Button("Import Memory")
954
  mem_text = gr.Textbox(label="Add word/phrase", placeholder="Type word or phrase")
955
  mem_add_btn = gr.Button("Add to Memory")
@@ -958,43 +1045,97 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
958
  mem_status = gr.Textbox(label="Memory status / preview", lines=12, interactive=False)
959
 
960
  def _import_mem(uploaded):
 
 
 
 
 
 
 
961
  if not uploaded:
962
- return "No file provided."
963
- path = uploaded.name if hasattr(uploaded, "name") else str(uploaded)
964
- try:
965
- with open(path, "r", encoding="utf-8") as fh:
966
- raw = fh.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
967
  parsed = None
968
  try:
969
  parsed = json.loads(raw)
970
  except Exception:
971
  parsed = None
 
972
  if isinstance(parsed, dict):
973
  with MEMORY_LOCK:
974
- for k, v in parsed.get("words", {}).items():
975
- memory["words"][k.lower()] = memory["words"].get(k.lower(), 0) + int(v)
976
- for k, v in parsed.get("phrases", {}).items():
977
- memory["phrases"][k] = memory["phrases"].get(k, 0) + int(v)
 
 
 
 
 
 
 
 
 
 
978
  save_memory(memory)
979
- return f"Imported JSON memory (words={len(parsed.get('words', {}))}, phrases={len(parsed.get('phrases', {}))})."
 
 
 
980
  lines = [l.strip() for l in raw.splitlines() if l.strip()]
981
- added = 0
982
  with MEMORY_LOCK:
983
  for line in lines:
984
  if "," in line:
985
- k, c = line.split(",", 1)
 
986
  try:
987
- cnt = int(c)
988
- except:
989
  cnt = 1
990
- memory["words"][k.lower()] = memory["words"].get(k.lower(), 0) + cnt
 
991
  else:
992
- memory["words"][line.lower()] = memory["words"].get(line.lower(), 0) + 1
993
- added += 1
 
 
 
 
 
 
994
  save_memory(memory)
995
- return f"Imported {added} entries."
996
- except Exception as e:
997
- return f"Import failed: {e}"
 
 
 
 
998
 
999
  def _add_mem(entry):
1000
  if not entry or not entry.strip():
@@ -1035,7 +1176,7 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
1035
  mem_clear_btn.click(fn=_clear_mem, inputs=[], outputs=[mem_status])
1036
  mem_view_btn.click(fn=_view_mem, inputs=[], outputs=[mem_status])
1037
 
1038
- # Settings tab (theme)
1039
  with gr.TabItem("Settings"):
1040
  with gr.Row():
1041
  with gr.Column():
@@ -1048,57 +1189,32 @@ with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
1048
  gr.HTML("""
1049
  <div style="display:flex;align-items:center;gap:12px;">
1050
  <button id="wt_theme_btn" style="display:flex;align-items:center;gap:8px;padding:8px 10px;border-radius:8px;border:1px solid rgba(0,0,0,0.06);background:var(--card);cursor:pointer;">
1051
- <span id="wt_theme_icon" style="display:inline-flex;width:18px;height:18px;align-items:center;justify-content:center;"></span>
1052
  <span id="wt_theme_label" style="font-weight:600;">Toggle Theme</span>
1053
  </button>
1054
- <div style="color:var(--muted);font-size:13px;">Theme preference saved in browser · <span id="wt_theme_hint">auto</span></div>
1055
  </div>
1056
  <script>
1057
  (function(){
1058
  try {
1059
  const root = document.documentElement;
1060
  const btn = document.getElementById('wt_theme_btn');
1061
- const icon = document.getElementById('wt_theme_icon');
1062
- const hint = document.getElementById('wt_theme_hint');
1063
-
1064
- function setIconFor(theme) {
1065
- if (!icon) return;
1066
- if (theme === 'dark') {
1067
- icon.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M21 12.79A9 9 0 1111.21 3 7 7 0 0021 12.79z" fill="currentColor"/></svg>';
1068
- } else {
1069
- icon.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M12 4V2M12 22v-2M4.2 4.2L2.8 2.8M21.2 21.2l-1.4-1.4M4 12H2m20 0h-2M4.2 19.8L2.8 21.2M21.2 2.8L19.8 4.2" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/><circle cx="12" cy="12" r="3" fill="currentColor"/></svg>';
1070
- }
1071
- }
1072
-
1073
  var saved = null;
1074
  try { saved = localStorage.getItem('wt_theme'); } catch(e){ saved = null; }
1075
  var effective = null;
1076
  if (saved === 'dark' || saved === 'light') {
1077
  effective = saved;
1078
- hint.textContent = 'saved';
1079
- } else if (window.matchMedia && window.matchMedia('(prefers-color-scheme: dark)').matches) {
1080
- effective = 'dark';
1081
- hint.textContent = 'OS-prefer';
1082
  } else {
1083
- effective = 'light';
1084
- hint.textContent = 'OS-prefer';
1085
  }
1086
  root.setAttribute('data-theme', effective);
1087
- setIconFor(effective);
1088
-
1089
  btn.addEventListener('click', function(){
1090
  try {
1091
  const cur = root.getAttribute('data-theme') === 'dark' ? 'light' : 'dark';
1092
  root.setAttribute('data-theme', cur);
1093
- try { localStorage.setItem('wt_theme', cur); hint.textContent = 'saved'; } catch(e){ hint.textContent = 'saved'; }
1094
- setIconFor(cur);
1095
- } catch(e){
1096
- console.error(e);
1097
- }
1098
  });
1099
- } catch(e){
1100
- console.warn('theme toggle init failed', e);
1101
- }
1102
  })();
1103
  </script>
1104
  """)
 
1
  # app.py
2
  # Whisper Transcriber — Gradio 3.x compatible full file
3
+ # Features: chunking, experimental parallel chunk transcription (CPU-only),
4
+ # memory corrections, two-pass refine, per-run ZIP extraction & selection,
5
+ # batch streaming with progress and per-file DOCX ZIP, in-memory buffer cache,
6
+ # dark/light theme toggle (dark default).
7
  #
8
  # Requirements: gradio (3.x), whisper, pydub, pyzipper, python-docx, ffmpeg installed.
 
9
 
10
  import os
11
  import sys
 
20
  from uuid import uuid4
21
  from pathlib import Path
22
  from multiprocessing import get_context
23
+ from collections import OrderedDict
24
+ import hashlib
25
+ import io
26
+ import zipfile
27
  from typing import Tuple, List
28
 
29
  # Force unbuffered prints for logs
 
57
  MODEL_CACHE = {}
58
  EXTRACT_MAP = {} # friendly_name -> absolute path
59
 
60
+ # Buffer cache configuration (in-memory LRU)
61
+ BUFFER_CACHE_MAX = 200 # tune to limit memory
62
+ BUFFER_CACHE = OrderedDict()
63
+
64
  # ---------- Worker-global for multiprocessing ----------
65
+ WORKER_MODEL = None # loaded in worker processes
 
66
 
67
  def worker_init(model_name: str, device: str):
68
  """
 
511
  pass
512
  raise
513
 
514
+ # ---------- Single-file transcription helper ----------
515
  def transcribe_single_file(
516
  path,
517
  model_name="small",
 
523
  refine_model=None,
524
  refine_threshold=-1.0,
525
  ):
 
526
  logs = []
527
  try:
528
  if not path:
 
561
  pass
562
  return text, srt_path, "\n".join(logs)
563
 
564
+ # If use_two_pass was requested, we keep a simple fallback (advanced two-pass handled elsewhere)
565
+ return "", None, "Two-pass not implemented in this helper."
566
  except Exception as e:
567
  tb = traceback.format_exc()
568
  return "", None, f"Transcription error: {e}\n{tb}"
569
 
570
# ---------- Buffer cache helpers ----------
def make_cache_key(file_path: str, model_name: str, device: str, mem_on: bool, two_pass: bool, fast_model: str, refine_threshold: float):
    """Build a deterministic cache key from file contents plus run settings.

    The file is hashed with SHA-256 (read in 8 KiB blocks) so the key tracks
    content rather than location; if the file cannot be read, the path string
    itself is used as a weaker fallback identifier.
    """
    digest = hashlib.sha256()
    try:
        with open(str(file_path), "rb") as fh:
            while True:
                block = fh.read(8192)
                if not block:
                    break
                digest.update(block)
        file_hash = digest.hexdigest()
    except Exception:
        # Unreadable/missing file: identify it by path only.
        file_hash = f"path:{str(file_path)}"
    return (
        f"{file_hash}|model={model_name}|dev={device}|mem={int(mem_on)}"
        f"|two={int(two_pass)}|fast={fast_model}|th={refine_threshold}"
    )
582
+
583
def cache_put(key: str, value: dict):
    """Insert or refresh *key* in the in-memory LRU cache.

    Re-insertion moves the key to the most-recently-used end; the cache is
    then trimmed back to BUFFER_CACHE_MAX entries (oldest evicted first).
    Caching is strictly best-effort: any failure is silently ignored so it
    can never break a transcription run.
    """
    try:
        BUFFER_CACHE.pop(key, None)  # drop stale position, if any
        BUFFER_CACHE[key] = value
        while len(BUFFER_CACHE) > BUFFER_CACHE_MAX:
            BUFFER_CACHE.popitem(last=False)  # evict least-recently-used
    except Exception:
        pass
592
+
593
def cache_get(key: str):
    """Look up *key* in the in-memory LRU cache.

    On a hit the entry is re-inserted to refresh its most-recently-used
    position. Returns the cached dict, or None on a miss or on any cache
    error (lookups are best-effort, like cache_put).
    """
    try:
        hit = BUFFER_CACHE.get(key)
        if hit:
            # Move the entry to the MRU end of the OrderedDict.
            BUFFER_CACHE.pop(key)
            BUFFER_CACHE[key] = hit
        return hit
    except Exception:
        return None
602
+
603
def make_docx_bytes(text: str):
    """Render *text* as a single-paragraph DOCX and return the raw bytes.

    Returns None if document creation or serialization fails (e.g. the
    python-docx Document class is unavailable).
    """
    try:
        doc = Document()
        doc.add_paragraph(text)
        buf = io.BytesIO()
        doc.save(buf)
        return buf.getvalue()
    except Exception:
        return None
613
+
614
def make_zip_bytes(files_dict: dict):
    """Pack a {filename: str | bytes | None} mapping into an in-memory ZIP.

    Text payloads are UTF-8 encoded; falsy payloads (None, empty) become
    zero-length archive members. Returns the ZIP archive as bytes.
    """
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for member_name, payload in files_dict.items():
            if isinstance(payload, str):
                payload = payload.encode("utf-8")
            archive.writestr(member_name, payload or b"")
    return buf.getvalue()
623
+
624
# ---------- Batch streaming transcription (new) ----------
def batch_transcribe_stream(friendly_selected, uploaded_files, model_name, device_name, merge_flag, enable_mem, generate_srt, use_two_pass=False, fast_model="small", refine_threshold=-1.0):
    """Transcribe a batch of files, streaming progress as a generator.

    Each yield is a 6-tuple consumed by the Gradio batch tab:
        (combined_transcript_text,
         joined_log_text,
         merged_docx_path_or_None,     # only set on the final yield
         first_srt_path_or_None,
         progress_percent_int,
         per_file_docx_zip_path_or_None)  # only set on the final yield

    Inputs come from two places: `friendly_selected` keys resolved through
    the module-level EXTRACT_MAP (ZIP-extraction results), plus directly
    `uploaded_files`. Results are cached per (file content, settings) key via
    cache_get/cache_put, so re-runs with identical inputs are served from
    memory. NOTE(review): transcribe_single_file / save_as_word are defined
    elsewhere in this file; their exact contracts are assumed here.
    """
    logs_lines = []
    transcripts = []
    per_file_docx = {}  # docx filename -> docx bytes, for the final ZIP
    srt_files = []
    paths = []

    # Resolve ZIP-extracted selections through the friendly-name map.
    if friendly_selected:
        for key in friendly_selected:
            p = EXTRACT_MAP.get(key)
            if p:
                paths.append(p)
            else:
                logs_lines.append(f"[WARN] Selected not found in extract map: {key}")

    # Direct uploads may be a single file object or a list of them.
    if uploaded_files:
        if isinstance(uploaded_files, (list, tuple)):
            for f in uploaded_files:
                paths.append(str(f))
        else:
            paths.append(str(uploaded_files))

    if not paths:
        # Nothing to do: emit one terminal yield so the UI still updates.
        logs_lines.append("No files selected or uploaded.")
        yield "", "\n".join(logs_lines), None, None, 100, None
        return

    total = len(paths)
    logs_lines.append(f"Batch: {total} file(s) to process.")
    yield "", "\n".join(logs_lines), None, None, 2, None

    for idx, p in enumerate(paths, start=1):
        try:
            logs_lines.append(f"[{idx}/{total}] Checking cache for: {os.path.basename(p)}")
            # Progress is mapped onto the 5..85 range across files.
            yield "\n\n".join(transcripts), "\n".join(logs_lines), None, None, int(5 + (idx - 1) * 80 / total), None

            cache_key = make_cache_key(p, model_name, device_name, enable_mem, use_two_pass, fast_model, refine_threshold)
            cached = cache_get(cache_key)
            if cached:
                # Cache hit: reuse text/DOCX/SRT bytes without re-transcribing.
                logs_lines.append(f"[{idx}/{total}] Cache hit: returning cached transcription.")
                text = cached.get("text", "")
                transcripts.append(f"FILE: {os.path.basename(p)}\n{text}\n")
                docx_b = cached.get("docx_bytes")
                if docx_b:
                    fname = f"{os.path.splitext(os.path.basename(p))[0]}.docx"
                    per_file_docx[fname] = docx_b
                srt_b = cached.get("srt_bytes")
                if srt_b:
                    # Re-materialize the SRT to disk so Gradio can serve it.
                    srt_fp = os.path.join(tempfile.gettempdir(), f"{os.path.splitext(os.path.basename(p))[0]}.srt")
                    with open(srt_fp, "wb") as fh:
                        fh.write(srt_b)
                    srt_files.append(srt_fp)
                yield "\n\n".join(transcripts), "\n".join(logs_lines), None, (srt_files[0] if srt_files else None), int(5 + idx * 80 / total), None
                continue

            logs_lines.append(f"[{idx}/{total}] Transcribing: {p}")
            yield "\n\n".join(transcripts), "\n".join(logs_lines), None, None, int(5 + (idx - 1) * 80 / total), None

            # Full transcription for this file (non-streaming helper).
            text, srt_path, lg = transcribe_single_file(
                p,
                model_name=model_name,
                device_choice=device_name,
                enable_memory=enable_mem,
                generate_srt=generate_srt,
                use_two_pass=use_two_pass,
                fast_model=fast_model,
                refine_model=model_name,
                refine_threshold=refine_threshold,
            )
            logs_lines.append(lg or "")
            if not text:
                logs_lines.append(f"[{idx}/{total}] No transcript returned or error for {p}.")
            transcripts.append(f"FILE: {os.path.basename(p)}\n{text}\n")

            # Per-file DOCX bytes feed the downloadable ZIP built at the end.
            docx_b = make_docx_bytes(text)
            if docx_b:
                fname = f"{os.path.splitext(os.path.basename(p))[0]}.docx"
                per_file_docx[fname] = docx_b

            srt_b = None
            if srt_path and os.path.exists(srt_path):
                with open(srt_path, "rb") as fh:
                    srt_b = fh.read()
                srt_files.append(srt_path)

            # Store everything for this (file, settings) combination.
            cache_put(cache_key, {"text": text, "docx_bytes": docx_b, "srt_bytes": srt_b})

            yield "\n\n".join(transcripts), "\n".join(logs_lines), None, (srt_files[0] if srt_files else None), int(5 + idx * 80 / total), None

        except Exception as e:
            # One failed file must not abort the whole batch: log and move on.
            logs_lines.append(f"[{idx}/{total}] Error processing {p}: {e}\n{traceback.format_exc()}")
            yield "\n\n".join(transcripts), "\n".join(logs_lines), None, None, int(5 + idx * 80 / total), None
            continue

    combined = "\n\n".join(transcripts)
    merged_docx_path = None
    if merge_flag:
        try:
            merged_docx_path = save_as_word(combined)
            logs_lines.append(f"Merged DOCX written: {merged_docx_path}")
        except Exception as e:
            logs_lines.append(f"Could not write merged DOCX: {e}")

    # Bundle all per-file DOCX bytes into one downloadable ZIP (best-effort).
    per_files_zip_path = None
    if per_file_docx:
        try:
            zip_bytes = make_zip_bytes(per_file_docx)
            zip_fp = os.path.join(tempfile.gettempdir(), f"batch_docx_{uuid4().hex}.zip")
            with open(zip_fp, "wb") as zf:
                zf.write(zip_bytes)
            per_files_zip_path = zip_fp
            logs_lines.append(f"Per-file DOCX ZIP created: {per_files_zip_path}")
        except Exception as e:
            logs_lines.append(f"Failed to create per-file DOCX ZIP: {e}")

    # Final yield carries the merged DOCX and the per-file ZIP, at 100%.
    yield combined, "\n".join(logs_lines), merged_docx_path, (srt_files[0] if srt_files else None), 100, per_files_zip_path
741
 
742
+ # ---------- Build Gradio UI ----------
743
  print("DEBUG: building Gradio UI", flush=True)
744
  available_choices, default_choice = safe_model_choices(prefer_default="small")
745
 
 
771
  """
772
 
773
  with gr.Blocks(title="Whisper Transcriber (3.x)", css=CSS) as demo:
774
+ # Theme initializer: default to dark unless user has saved a preference
775
  gr.HTML("""
776
  <script>
777
  (function() {
 
781
  var chosen = null;
782
  if (saved === 'dark' || saved === 'light') {
783
  chosen = saved;
 
 
784
  } else {
785
+ // default to dark unless user explicitly chose otherwise earlier
786
+ chosen = 'dark';
787
  }
788
  document.documentElement.setAttribute('data-theme', chosen);
 
 
 
 
 
 
 
789
  } catch (e) { console.warn('theme init failed', e); }
790
  })();
791
  </script>
792
  """)
793
 
 
794
  with gr.Row():
795
  with gr.Column(scale=0):
796
  gr.HTML("<div style='width:50px;height:50px;border-radius:10px;background:linear-gradient(135deg,#4f46e5,#06b6d4);display:flex;align-items:center;justify-content:center;color:white;font-weight:700;font-size:20px;'>WT</div>")
797
  with gr.Column():
798
  gr.Markdown("<h3 style='margin:0'>Whisper Transcriber (Gradio 3.x)</h3>")
799
+ gr.Markdown("<div class='small-note'>Chunked streaming, experimental CPU parallel, per-run ZIP extraction, memory corrections, SRT export, dark theme default</div>")
800
 
801
  with gr.Tabs():
802
  # Single audio
 
809
  device_choice = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
810
  mem_toggle = gr.Checkbox(label="Enable memory corrections", value=False)
811
  srt_toggle = gr.Checkbox(label="Generate SRT", value=False)
 
 
812
  chunk_size_input = gr.Number(value=30, label="Chunk size (seconds)", precision=0)
813
  enable_chunking = gr.Checkbox(label="Enable chunking (recommended for long files)", value=True)
 
814
  parallel_checkbox = gr.Checkbox(label="Enable experimental parallel chunk transcription (CPU only)", value=False)
815
  parallel_workers = gr.Slider(minimum=1, maximum=max(1, os.cpu_count() or 4), value=2, step=1, label="Parallel workers (processes)")
816
  use_two_pass_single = gr.Checkbox(label="Use two-pass speedup (fast then refine)", value=False)
817
+ fast_model_choice = gr.Dropdown(choices=[c for c in ["tiny", "base", "small"] if c in AVAILABLE_MODEL_SET], value="small", label="Fast model")
818
  refine_threshold_single = gr.Number(value=-1.0, label="Refine threshold (avg_logprob)", precision=2)
819
  transcribe_btn = gr.Button("Transcribe", variant="primary")
820
  with gr.Column(scale=1):
 
824
  srt_download = gr.File(label="SRT (if generated)")
825
  single_logs = gr.Textbox(label="Logs", lines=8, interactive=False)
826
 
 
827
  def _single_generator(audio_file, model_name, device, mem_on, srt_on, chunk_size_sec, chunking_enabled, parallel_enabled, workers, use_two_pass_flag, fast_model, refine_thresh):
828
  yield 0, "", None, "Starting..."
829
  try:
 
837
  wav = convert_to_wav_if_needed(path)
838
  yield 8, "", None, f"Converted to WAV: {os.path.basename(wav)}"
839
 
 
840
  duration = None
841
  try:
842
  p = subprocess.run(["ffprobe","-v","error","-show_entries","format=duration","-of","default=noprint_wrappers=1:nokey=1", wav], capture_output=True, text=True, timeout=8)
 
851
  except Exception:
852
  duration = None
853
 
 
854
  if chunking_enabled and (duration and duration > chunk_size_sec * 1.5):
855
  num_chunks = max(1, int((duration + chunk_size_sec - 1) // chunk_size_sec))
856
  chunk_ranges = []
 
865
 
866
  yield 10, "", None, f"Preparing transcription ({len(chunk_ranges)} chunk(s))..."
867
 
 
868
  model = get_whisper_model(model_name, device=None if device == "auto" else device)
869
  yield 15, "", None, f"Model loaded: {model_name}"
870
 
871
  overall_parts = []
872
  total_chunks = len(chunk_ranges)
873
 
 
874
  parallel_used = False
875
  if parallel_enabled and chunking_enabled and total_chunks > 1:
876
  if device != "cpu" and device != "auto":
 
877
  yield 15, "", None, "Parallel mode requested but device is not 'cpu'. Falling back to serial chunking."
878
  parallel_used = False
879
  else:
 
880
  try:
881
  ctx = get_context("spawn")
882
  worker_count = max(1, int(workers))
883
  yield 18, "", None, f"Starting parallel pool with {worker_count} workers (fast_model={fast_model})..."
884
  pool = ctx.Pool(processes=worker_count, initializer=worker_init, initargs=(fast_model, "cpu"))
 
885
  chunk_paths = []
886
  temp_chunk_files = []
887
  for (st, ed) in chunk_ranges:
 
891
  cw = trim_audio_segment(wav, st, ed)
892
  chunk_paths.append(cw)
893
  temp_chunk_files.append(cw)
 
894
  results = pool.map(worker_transcribe_chunk, chunk_paths)
895
  pool.close()
896
  pool.join()
 
897
  for idx, (txt, err) in enumerate(results, start=1):
898
  if err:
899
  yield int(20 + idx * 70 / max(1, total_chunks)), "\n\n".join(overall_parts), None, f"Chunk {idx} worker error: {err}"
 
904
  overall_parts.append(txt)
905
  prog = int(20 + idx * 70 / max(1, total_chunks))
906
  yield prog, "\n\n".join(overall_parts), None, f"Completed chunk {idx}/{total_chunks} (parallel)."
 
907
  for tfile in temp_chunk_files:
908
  try:
909
  if os.path.exists(tfile):
 
916
  parallel_used = False
917
 
918
  if not parallel_used:
 
919
  for idx, (st, ed) in enumerate(chunk_ranges, start=1):
920
  try:
921
  if ed is None:
 
927
 
928
  yield int(15 + (idx - 1) * 70 / max(1, total_chunks)), "", None, f"Transcribing chunk {idx}/{total_chunks} ({note})..."
929
 
930
+ result = model.transcribe(chunk_wav)
 
 
931
  chunk_text = result.get("text", "").strip()
932
 
933
  if mem_on:
 
947
  except Exception as e:
948
  yield int(15 + idx * 70 / max(1, total_chunks)), "\n\n".join(overall_parts), None, f"Chunk {idx} failed: {e}\n{traceback.format_exc()}"
949
 
 
950
  final_text = "\n\n".join([p for p in overall_parts if p])
951
  if mem_on:
952
  try:
 
954
  except Exception:
955
  pass
956
 
 
957
  srt_path = None
958
  if srt_on:
959
  try:
 
969
 
970
  yield 98, final_text, srt_path, "Transcription complete."
971
 
 
972
  try:
973
  if os.path.exists(wav) and not path.lower().endswith(".wav"):
974
  os.unlink(wav)
 
986
  outputs=[progress_num, transcript_out, srt_download, single_logs],
987
  )
988
 
989
+ # Batch tab
990
  with gr.TabItem("Batch Transcribe"):
991
  with gr.Row():
992
  with gr.Column(scale=1):
 
1012
  batch_logs = gr.Textbox(label="Logs", lines=10, interactive=False)
1013
  batch_doc_download = gr.File(label="Merged DOCX (if created)")
1014
  batch_srt_download = gr.File(label="First SRT (if any)")
1015
+ batch_progress = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Batch Progress (%)", interactive=False)
1016
+ batch_files_zip = gr.File(label="Download per-file DOCX ZIP (all files)", interactive=False)
1017
 
1018
  def _do_extract(zip_file, password):
1019
  if not zip_file:
 
1024
 
1025
  batch_extract_btn.click(fn=_do_extract, inputs=[batch_zip, zip_password], outputs=[batch_select, batch_extract_logs])
1026
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1027
  batch_run_btn.click(
1028
+ fn=batch_transcribe_stream,
1029
  inputs=[batch_select, batch_files, batch_model, batch_device, batch_merge, batch_mem, batch_srt, batch_use_two_pass, batch_fast_model, batch_refine_threshold],
1030
+ outputs=[batch_trans_out, batch_logs, batch_doc_download, batch_srt_download, batch_progress, batch_files_zip],
1031
  )
1032
 
1033
+ # Memory tab
1034
  with gr.TabItem("Memory"):
1035
  with gr.Row():
1036
  with gr.Column(scale=1):
1037
  gr.Markdown("### Correction Memory")
1038
+ # Allow multiple files upload for memory import
1039
+ mem_upload = gr.File(label="Import memory file(s) (JSON or text)", file_count="multiple", type="filepath")
1040
  mem_import_btn = gr.Button("Import Memory")
1041
  mem_text = gr.Textbox(label="Add word/phrase", placeholder="Type word or phrase")
1042
  mem_add_btn = gr.Button("Add to Memory")
 
1045
  mem_status = gr.Textbox(label="Memory status / preview", lines=12, interactive=False)
1046
 
1047
  def _import_mem(uploaded):
1048
+ """
1049
+ Accepts uploaded which may be:
1050
+ - None
1051
+ - a single file-like object / path
1052
+ - a list of file-like objects / paths
1053
+ Processes each file and merges into memory (words & phrases).
1054
+ """
1055
  if not uploaded:
1056
+ return "No file(s) provided."
1057
+
1058
+ files = []
1059
+ # Normalize uploaded to list of paths
1060
+ if isinstance(uploaded, (list, tuple)):
1061
+ for item in uploaded:
1062
+ if not item:
1063
+ continue
1064
+ path = item.name if hasattr(item, "name") else str(item)
1065
+ files.append(path)
1066
+ else:
1067
+ path = uploaded.name if hasattr(uploaded, "name") else str(uploaded)
1068
+ files.append(path)
1069
+
1070
+ total_added = 0
1071
+ merged_words = 0
1072
+ merged_phrases = 0
1073
+ errors = []
1074
+ for path in files:
1075
+ try:
1076
+ with open(path, "r", encoding="utf-8") as fh:
1077
+ raw = fh.read()
1078
+ except Exception as e:
1079
+ errors.append(f"Failed to read {path}: {e}")
1080
+ continue
1081
+
1082
  parsed = None
1083
  try:
1084
  parsed = json.loads(raw)
1085
  except Exception:
1086
  parsed = None
1087
+
1088
  if isinstance(parsed, dict):
1089
  with MEMORY_LOCK:
1090
+ pw = parsed.get("words", {})
1091
+ pp = parsed.get("phrases", {})
1092
+ for k, v in pw.items():
1093
+ try:
1094
+ memory["words"][k.lower()] = memory["words"].get(k.lower(), 0) + int(v)
1095
+ except Exception:
1096
+ memory["words"][k.lower()] = memory["words"].get(k.lower(), 0) + 1
1097
+ merged_words += 1
1098
+ for k, v in pp.items():
1099
+ try:
1100
+ memory["phrases"][k] = memory["phrases"].get(k, 0) + int(v)
1101
+ except Exception:
1102
+ memory["phrases"][k] = memory["phrases"].get(k, 0) + 1
1103
+ merged_phrases += 1
1104
  save_memory(memory)
1105
+ total_added += (len(pw) + len(pp))
1106
+ continue
1107
+
1108
+ # fallback to newline parsing
1109
  lines = [l.strip() for l in raw.splitlines() if l.strip()]
1110
+ added_here = 0
1111
  with MEMORY_LOCK:
1112
  for line in lines:
1113
  if "," in line:
1114
+ parts = [p.strip() for p in line.split(",", 1)]
1115
+ key = parts[0].lower()
1116
  try:
1117
+ cnt = int(parts[1])
1118
+ except Exception:
1119
  cnt = 1
1120
+ memory["words"][key] = memory["words"].get(key, 0) + cnt
1121
+ merged_words += 1
1122
  else:
1123
+ # short lines -> words, longer -> phrase
1124
+ if len(line.split()) <= 3:
1125
+ memory["words"][line.lower()] = memory["words"].get(line.lower(), 0) + 1
1126
+ merged_words += 1
1127
+ else:
1128
+ memory["phrases"][line] = memory["phrases"].get(line, 0) + 1
1129
+ merged_phrases += 1
1130
+ added_here += 1
1131
  save_memory(memory)
1132
+ total_added += added_here
1133
+
1134
+ msg_parts = [f"Imported {total_added} entries ({merged_words} words, {merged_phrases} phrases)."]
1135
+ if errors:
1136
+ msg_parts.append("Errors:")
1137
+ msg_parts.extend(errors)
1138
+ return "\n".join(msg_parts)
1139
 
1140
  def _add_mem(entry):
1141
  if not entry or not entry.strip():
 
1176
  mem_clear_btn.click(fn=_clear_mem, inputs=[], outputs=[mem_status])
1177
  mem_view_btn.click(fn=_view_mem, inputs=[], outputs=[mem_status])
1178
 
1179
+ # Settings tab
1180
  with gr.TabItem("Settings"):
1181
  with gr.Row():
1182
  with gr.Column():
 
1189
  gr.HTML("""
1190
  <div style="display:flex;align-items:center;gap:12px;">
1191
  <button id="wt_theme_btn" style="display:flex;align-items:center;gap:8px;padding:8px 10px;border-radius:8px;border:1px solid rgba(0,0,0,0.06);background:var(--card);cursor:pointer;">
 
1192
  <span id="wt_theme_label" style="font-weight:600;">Toggle Theme</span>
1193
  </button>
1194
+ <div style="color:var(--muted);font-size:13px;">Theme preference saved in browser</div>
1195
  </div>
1196
  <script>
1197
  (function(){
1198
  try {
1199
  const root = document.documentElement;
1200
  const btn = document.getElementById('wt_theme_btn');
 
 
 
 
 
 
 
 
 
 
 
 
1201
  var saved = null;
1202
  try { saved = localStorage.getItem('wt_theme'); } catch(e){ saved = null; }
1203
  var effective = null;
1204
  if (saved === 'dark' || saved === 'light') {
1205
  effective = saved;
 
 
 
 
1206
  } else {
1207
+ effective = 'dark';
 
1208
  }
1209
  root.setAttribute('data-theme', effective);
 
 
1210
  btn.addEventListener('click', function(){
1211
  try {
1212
  const cur = root.getAttribute('data-theme') === 'dark' ? 'light' : 'dark';
1213
  root.setAttribute('data-theme', cur);
1214
+ try { localStorage.setItem('wt_theme', cur); } catch(e){}
1215
+ } catch(e){ console.error(e); }
 
 
 
1216
  });
1217
+ } catch(e){}
 
 
1218
  })();
1219
  </script>
1220
  """)