# staraks's picture
# Update app.py
# 9d6c396 verified
# app.py
# Whisper Transcriber — Gradio 3.x compatible complete file with UI improvements:
# - small buttons, advanced toggle, download selected extracted files,
# - auto-merge per-file transcripts, auto cleanup of temp files after N minutes
# Requirements: gradio (3.x), pydub, pyzipper, python-docx, ffmpeg, whisper or faster-whisper
import os
import sys
import json
import shutil
import tempfile
import subprocess
import traceback
import threading
import re
import zipfile
from difflib import get_close_matches
from uuid import uuid4
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing
import time
# Force unbuffered prints
os.environ["PYTHONUNBUFFERED"] = "1"
try:
import gradio as gr
except Exception as e:
print("FATAL: gradio import failed:", e)
raise
# try faster-whisper first for CPU speedups
USE_FASTER_WHISPER = False
try:
from faster_whisper import WhisperModel as FasterWhisperModel
USE_FASTER_WHISPER = True
print("INFO: faster-whisper detected.")
except Exception:
try:
import whisper
except Exception:
print("FATAL: Neither faster-whisper nor whisper available. Install whisper or faster-whisper.")
raise
from pydub import AudioSegment
import pyzipper
from docx import Document
# ---------- Config ----------
# JSON file (relative to the working directory) holding learned words/phrases.
MEMORY_FILE = "memory.json"
# Guards mutations of the in-memory dictionary and writes to MEMORY_FILE.
# NOTE: threading.Lock is not reentrant; save_memory() acquires it itself.
MEMORY_LOCK = threading.Lock()
# A converted WAV must be larger than this many bytes to count as a success.
MIN_WAV_SIZE = 1024
# (format, sample_rate_hz, channels) combinations handed to ffmpeg, tried in
# order, when pydub cannot decode an input file.
FFMPEG_CANDIDATES = [
("s16le", 16000, 1),
("s16le", 44100, 2),
("pcm_s16le", 16000, 1),
("pcm_s16le", 44100, 2),
("mulaw", 8000, 1),
]
# NOTE(review): not referenced by any code visible in this section; presumably
# a model-name -> loaded-model cache — confirm before relying on it.
MODEL_CACHE = {}
EXTRACT_MAP = {} # friendly_name -> path
LAST_EXTRACT_DIR = None # path to last extraction folder (for download)
LAST_EXTRACT_LIST = [] # friendly names for last extraction (for select all)
# Pre-filled value of the ZIP password textbox in the UI.
# NOTE(review): a hard-coded password in source is a security smell; consider
# reading it from the environment or defaulting to an empty string.
DEFAULT_ZIP_PASS = "dietcoke1"
# NEW: last batch transcripts (set by batch generator). Each item: (friendly_name, txt_path, srt_path)
LAST_BATCH_TRANSCRIPTS = []
CPU_COUNT = max(1, multiprocessing.cpu_count())
# Upper bound on parallel transcription processes (never more than 4).
MAX_WORKERS = min(4, CPU_COUNT) # tune for your environment
# Auto-cleanup configuration (minutes); can be changed in settings UI
AUTO_CLEANUP_MINUTES = 30
# Temp registry for cleanup: entries are tuples (path, created_timestamp)
# where created_timestamp comes from time.time().
_TEMP_REGISTRY_LOCK = threading.Lock()
_TEMP_REGISTRY = []
def register_temp_path(p):
    """Remember *p* (stringified, with the current time) for later cleanup.

    Registration is best-effort: any error raised while recording the path
    is swallowed and the function returns None either way.
    """
    try:
        _TEMP_REGISTRY_LOCK.acquire()
        try:
            _TEMP_REGISTRY.append((str(p), time.time()))
        finally:
            _TEMP_REGISTRY_LOCK.release()
    except Exception:
        # Cleanup bookkeeping must never break the caller.
        return None
    return None
def cleanup_temp_worker(interval_seconds=60):
    """Loop forever, deleting registered temp paths older than AUTO_CLEANUP_MINUTES.

    Every *interval_seconds* the registry is split into expired and live
    entries while holding the registry lock; expired paths are then removed
    from disk outside the lock. All errors are ignored so the loop never dies.
    """
    while True:
        try:
            threshold = time.time() - (AUTO_CLEANUP_MINUTES * 60)
            with _TEMP_REGISTRY_LOCK:
                expired = [path for path, created in _TEMP_REGISTRY if created < threshold]
                _TEMP_REGISTRY[:] = [
                    (path, created)
                    for path, created in _TEMP_REGISTRY
                    if not created < threshold
                ]
            for path in expired:
                try:
                    if os.path.isdir(path):
                        shutil.rmtree(path)
                    elif os.path.exists(path):
                        os.unlink(path)
                except Exception:
                    # A path that cannot be deleted is simply forgotten.
                    pass
        except Exception:
            pass
        time.sleep(interval_seconds)
# Start the cleanup worker as a daemon thread so it never keeps the
# interpreter alive at exit. This runs as a module-level statement, i.e.
# whenever this module is imported or executed.
_cleanup_thread = threading.Thread(target=cleanup_temp_worker, daemon=True)
_cleanup_thread.start()
# ---------- Memory & postprocessing ----------
def load_memory():
    """Load the correction memory from MEMORY_FILE, creating it when unusable.

    Returns the parsed dict (guaranteed to contain "words" and "phrases"
    keys) when the file exists and holds a JSON object. In every other case
    (missing file, unreadable file, non-dict root) a fresh empty memory is
    written to MEMORY_FILE on a best-effort basis and returned.
    """
    try:
        if os.path.exists(MEMORY_FILE):
            with open(MEMORY_FILE, "r", encoding="utf-8") as handle:
                loaded = json.load(handle)
            if not isinstance(loaded, dict):
                raise ValueError("memory.json root not dict")
            for section in ("words", "phrases"):
                loaded.setdefault(section, {})
            return loaded
    except Exception:
        # Fall through and start from an empty memory.
        pass

    fresh = {"words": {}, "phrases": {}}
    try:
        with open(MEMORY_FILE, "w", encoding="utf-8") as handle:
            json.dump(fresh, handle, ensure_ascii=False, indent=2)
    except Exception:
        pass
    return fresh
def save_memory(mem):
    """Write *mem* to MEMORY_FILE as pretty-printed UTF-8 JSON.

    The write happens while holding MEMORY_LOCK. Failures are printed via
    traceback.print_exc() and otherwise ignored.
    """
    with MEMORY_LOCK:
        try:
            with open(MEMORY_FILE, "w", encoding="utf-8") as out:
                json.dump(mem, out, ensure_ascii=False, indent=2)
        except Exception:
            traceback.print_exc()
# Learned vocabulary, loaded once when the module is executed. Shape:
# {"words": {lowercased_word: count}, "phrases": {sentence: count}}.
memory = load_memory()
# Lower-case clinical shorthand -> expansion; consumed by expand_abbreviations().
MEDICAL_ABBREVIATIONS = {
"pt": "patient",
"dx": "diagnosis",
"hx": "history",
"sx": "symptoms",
"c/o": "complains of",
"bp": "blood pressure",
"hr": "heart rate",
"o2": "oxygen",
"r/o": "rule out",
"adm": "admit",
"disch": "discharge",
}
# Lower-case drug name -> canonical capitalization; consumed by normalize_drugs().
DRUG_NORMALIZATION = {
"metformin": "Metformin",
"aspirin": "Aspirin",
"amoxicillin": "Amoxicillin",
}
def expand_abbreviations(text, abbreviations=None):
    """Expand known shorthand tokens in *text*, keeping surrounding punctuation.

    Args:
        text: Input string. It is split on whitespace runs, which are
            preserved verbatim in the output.
        abbreviations: Optional mapping of lower-case shorthand to its
            expansion. Defaults to the module-level MEDICAL_ABBREVIATIONS.

    Returns:
        The text with every token whose punctuation-stripped, lower-cased
        form is a key of the table replaced by its expansion. Any ".,;:"
        characters before or after the token are kept in place.
    """
    table = MEDICAL_ABBREVIATIONS if abbreviations is None else abbreviations
    punctuation = ".,;:"
    pieces = []
    for token in re.split(r"(\s+)", text):
        core = token.strip(punctuation)
        expansion = table.get(core.lower()) if core else None
        if expansion is None:
            pieces.append(token)
            continue
        # Re-attach whatever punctuation framed the shorthand; the previous
        # implementation silently dropped leading punctuation.
        lead_len = len(token) - len(token.lstrip(punctuation))
        leading = token[:lead_len]
        trailing = token[lead_len + len(core):]
        pieces.append(leading + expansion + trailing)
    return "".join(pieces)
def normalize_drugs(text, mapping=None):
    """Rewrite drug names in *text* to their canonical spelling.

    Args:
        text: Input string.
        mapping: Optional mapping of drug name -> canonical form. Defaults to
            the module-level DRUG_NORMALIZATION.

    Returns:
        The text with each whole-word, case-insensitive occurrence of a key
        replaced by its canonical form.
    """
    table = DRUG_NORMALIZATION if mapping is None else mapping
    for name, canonical in table.items():
        # Escape the key so characters such as "." are matched literally, and
        # use a callable so the replacement is never parsed as a template.
        pattern = rf"\b{re.escape(name)}\b"
        text = re.sub(pattern, lambda _m, rep=canonical: rep, text, flags=re.IGNORECASE)
    return text
def punctuation_and_capitalization(text):
    """Ensure terminal punctuation and upper-case the first letter of each sentence.

    The text is stripped; if it does not already end with ".", "?" or "!" a
    "." is appended. Sentences are delimited by one of those marks followed
    by whitespace. Only the first character of each sentence is changed, so
    capitalization inside the sentence (drug names, acronyms) is preserved.

    Returns:
        The adjusted text, or the stripped input unchanged when it is empty.
    """
    text = text.strip()
    if not text:
        return text
    if not re.search(r"[.?!]\s*$", text):
        text = text.rstrip() + "."
    separator = re.compile(r"[.?!]\s+")
    rebuilt = []
    for part in re.split(r"([.?!]\s+)", text):
        if part and not separator.match(part):
            # str.capitalize() would lowercase the remainder of the sentence.
            rebuilt.append(part[0].upper() + part[1:])
        else:
            rebuilt.append(part)
    return "".join(rebuilt)
def postprocess_transcript(text):
    """Clean a raw transcript for display.

    Falsy input is returned untouched. Otherwise whitespace runs collapse to
    single spaces, then abbreviation expansion, drug-name normalization and
    punctuation/capitalization are applied in that order.
    """
    if not text:
        return text
    result = re.sub(r"\s+", " ", text).strip()
    pipeline = (expand_abbreviations, normalize_drugs, punctuation_and_capitalization)
    for stage in pipeline:
        result = stage(result)
    return result
def extract_words_and_phrases(text):
    """Split *text* into word tokens and sentences.

    Returns:
        (words, sentences): words are runs of letters, digits, hyphens and
        apostrophes; sentences are the stripped, non-empty pieces obtained by
        splitting after ".", "?" or "!" followed by whitespace.
    """
    tokens = [tok for tok in re.findall(r"[A-Za-z0-9\-']+", text) if tok.strip()]
    sentences = []
    for chunk in re.split(r"(?<=[.?!])\s+", text):
        trimmed = chunk.strip()
        if trimmed:
            sentences.append(trimmed)
    return tokens, sentences
def update_memory_with_transcript(transcript):
    """Add the words and sentences of *transcript* to the correction memory.

    Word counts are keyed by the lower-cased word, phrase counts by the
    sentence as written. The memory is saved to disk only when at least one
    entry was counted.
    """
    global memory
    words, sentences = extract_words_and_phrases(transcript)
    with MEMORY_LOCK:
        for word in words:
            key = word.lower()
            memory["words"][key] = memory["words"].get(key, 0) + 1
        for sentence in sentences:
            memory["phrases"][sentence] = memory["phrases"].get(sentence, 0) + 1
    # save_memory() acquires MEMORY_LOCK itself and the lock is not
    # reentrant, so persisting must happen after the lock is released.
    if words or sentences:
        save_memory(memory)
def memory_correct_text(text, min_ratio=0.85, mem=None):
    """Correct *text* using previously learned words and phrases.

    Args:
        text: Transcript to correct.
        min_ratio: Similarity cutoff passed to difflib.get_close_matches when
            looking for a known word close to an unknown one.
        mem: Optional memory dict with "words" and "phrases" mappings.
            Defaults to the module-level ``memory``.

    Returns:
        The corrected text. Unknown words are replaced by their closest known
        word (capitalized when the original started upper-case). Then every
        known phrase of at least 8 characters that occurs case-insensitively
        in the text is rewritten to its stored spelling, longest first.
        Empty text or an empty memory returns the input unchanged.
    """
    store = memory if mem is None else mem
    known_words = store.get("words") or {}
    known_phrases = store.get("phrases") or {}
    if not text or (not known_words and not known_phrases):
        return text

    def fix_word(word):
        lowered = word.lower()
        if lowered in known_words:
            return word
        matches = get_close_matches(lowered, known_words.keys(), n=1, cutoff=min_ratio)
        if not matches:
            return word
        best = matches[0]
        return best.capitalize() if word[0].isupper() else best

    word_token = re.compile(r"^[A-Za-z0-9\-']+$")
    corrected = "".join(
        fix_word(tok) if word_token.match(tok) else tok
        for tok in re.split(r"(\W+)", text)
    )

    for phrase in sorted(known_phrases.keys(), key=lambda s: -len(s)):
        low_phrase = phrase.lower()
        if len(low_phrase) < 8:
            continue
        if low_phrase in corrected.lower():
            # A callable replacement inserts the phrase literally; passing the
            # phrase as a string would make re.sub interpret backslashes and
            # group references inside it.
            corrected = re.sub(
                re.escape(phrase),
                lambda _m, rep=phrase: rep,
                corrected,
                flags=re.IGNORECASE,
            )
    return corrected
# ---------- Utilities ----------
def save_as_word(text, filename=None):
    """Write *text* as a single paragraph into a .docx file and return its path.

    When *filename* is None a uniquely named file is created in the system
    temp directory. The resulting path is registered for automatic cleanup
    in both cases.
    """
    target = filename
    if target is None:
        suffix = uuid4().hex[:8]
        target = os.path.join(tempfile.gettempdir(), f"merged_transcripts_{suffix}.docx")
    document = Document()
    document.add_paragraph(text)
    document.save(target)
    register_temp_path(target)
    return target
def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
try:
cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
if fmt in ("s16le", "pcm_s16le", "mulaw"):
cmd += ["-f", fmt, "-ar", str(sr), "-ac", str(ch), "-i", input_path, out_path]
else:
cmd += ["-i", input_path, "-ar", str(sr), "-ac", str(ch), out_path]
proc = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
stdout_stderr = (proc.stdout or "") + (proc.stderr or "")
if proc.returncode == 0 and os.path.exists(out_path) and os.path.getsize(out_path) > MIN_WAV_SIZE:
return True, stdout_stderr
else:
try:
if os.path.exists(out_path):
os.unlink(out_path)
except Exception:
pass
return False, stdout_stderr
except Exception as e:
try:
if os.path.exists(out_path):
os.unlink(out_path)
except Exception:
pass
return False, str(e)
def convert_to_wav_if_needed(input_path):
    """Return the path of a WAV rendition of *input_path*.

    A path ending in ".wav" (case-insensitive) is returned as-is (as str).
    Otherwise pydub is tried first; if it fails or produces a file no larger
    than MIN_WAV_SIZE bytes, every FFMPEG_CANDIDATES entry is tried through
    _ffmpeg_convert. Temp files and the diagnostics directory are registered
    for automatic cleanup.

    Raises:
        Exception: when no strategy succeeds. The message names the
            diagnostics file, or reports that writing it failed.
    """
    input_path = str(input_path)
    if input_path.lower().endswith(".wav"):
        return input_path

    def _remove_quietly(path):
        try:
            if os.path.exists(path):
                os.unlink(path)
        except Exception:
            pass

    # --- Attempt 1: let pydub detect the format on its own ---
    pydub_trace = ""
    scratch = None
    try:
        scratch = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        scratch.close()
        AudioSegment.from_file(input_path).export(scratch.name, format="wav")
        usable = os.path.exists(scratch.name) and os.path.getsize(scratch.name) > MIN_WAV_SIZE
        if usable:
            register_temp_path(scratch.name)
            return scratch.name
        _remove_quietly(scratch.name)
    except Exception:
        pydub_trace = traceback.format_exc()
        if scratch is not None:
            _remove_quietly(scratch.name)

    # --- Attempt 2: brute-force the known raw-format candidates via ffmpeg ---
    diag_dir = tempfile.mkdtemp(prefix="dct_diag_")
    register_temp_path(diag_dir)
    diag_log = os.path.join(diag_dir, "conversion_diagnostics.txt")
    report = []
    for fmt, sr, ch in FFMPEG_CANDIDATES:
        candidate = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        candidate.close()
        register_temp_path(candidate.name)
        converted, detail = _ffmpeg_convert(input_path, candidate.name, fmt, sr, ch)
        report.append(f"TRY fmt={fmt} sr={sr} ch={ch} success={converted}\n{detail}\n")
        if not converted:
            _remove_quietly(candidate.name)
            continue
        # Success: the diagnostics log is a nicety, so write errors are ignored.
        try:
            with open(diag_log, "w", encoding="utf-8") as log:
                log.writelines([
                    "pydub auto error:\n",
                    pydub_trace + "\n\n",
                    "Successful ffmpeg candidate:\n",
                    f"fmt={fmt} sr={sr} ch={ch}\n\n",
                    "Diagnostics:\n",
                    "\n".join(report),
                ])
        except Exception:
            pass
        return candidate.name

    # --- Everything failed: gather as much context as possible, then raise ---
    try:
        probe = subprocess.run(
            ["ffprobe", "-v", "error", "-show_format", "-show_streams", input_path],
            capture_output=True,
            text=True,
            timeout=10,
        )
        report.append("FFPROBE:\n" + (probe.stdout.strip() or probe.stderr.strip()))
    except Exception as e:
        report.append("ffprobe failed: " + str(e))

    try:
        with open(input_path, "rb") as source:
            leading_bytes = source.read(512)
        report.append("HEX PREVIEW:\n" + leading_bytes.hex())
    except Exception as e:
        report.append("could not read head: " + str(e))

    try:
        with open(diag_log, "w", encoding="utf-8") as log:
            log.writelines([
                "pydub auto error:\n",
                pydub_trace + "\n\n",
                "Full diagnostics:\n\n",
                "\n\n".join(report),
            ])
    except Exception as e:
        raise Exception(f"Conversion failed; diagnostics write error: {e}")
    raise Exception(f"Could not convert file to WAV. Diagnostics saved to: {diag_log}")
# ---------- Model helper ----------
def whisper_available_models():
    """Return the set of model names offered to the user.

    With faster-whisper a fixed set is returned. With openai-whisper the
    result of ``whisper.available_models()`` is used when it is a list,
    tuple or set. Any error, or an unexpected return type, falls back to
    the same fixed set.
    """
    fallback = ["tiny", "base", "small", "medium", "large", "large-v3"]
    try:
        if USE_FASTER_WHISPER:
            return set(fallback)
        reported = whisper.available_models()
        if isinstance(reported, (list, tuple, set)):
            return set(reported)
    except Exception:
        pass
    return set(fallback)
# Computed once at module execution; safe_model_choices() filters against it.
AVAILABLE_MODEL_SET = whisper_available_models()
def safe_model_choices(prefer_default="small"):
    """Build the model dropdown options and pick a default.

    Returns:
        (choices, default): choices are the preferred model names that are in
        AVAILABLE_MODEL_SET, in preference order, or the full preferred list
        when none are available. default is *prefer_default* when it is among
        the choices, otherwise the first choice.
    """
    preferred_order = ["small", "medium", "large", "large-v3", "base", "tiny"]
    usable = [name for name in preferred_order if name in AVAILABLE_MODEL_SET] or preferred_order
    if prefer_default in usable:
        chosen = prefer_default
    else:
        chosen = usable[0]
    return usable, chosen
# ---------- worker used by ProcessPoolExecutor ----------
def _fmt_time(t):
h = int(t // 3600)
m = int((t % 3600) // 60)
s = int(t % 60)
ms = int((t - int(t)) * 1000)
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
def _segments_to_srt(segments):
lines = []
for i, seg in enumerate(segments, start=1):
start = seg.get("start", 0)
end = seg.get("end", 0)
text = seg.get("text", "").strip()
lines.append(str(i))
lines.append(f"{_fmt_time(start)} --> {_fmt_time(end)}")
lines.append(text)
lines.append("")
return "\n".join(lines)
def _worker_transcribe(args):
try:
(file_path, model_name, device_name, enable_memory, generate_srt, use_two_pass, fast_model, refine_threshold) = args
base = os.path.basename(file_path)
log_lines = []
device = None if device_name == "auto" else device_name
model = None
use_fw = False
try:
if USE_FASTER_WHISPER:
model = FasterWhisperModel(model_name, device=device if device else "cpu")
use_fw = True
log_lines.append(f"Worker: faster-whisper loaded {model_name}")
else:
import whisper as _wh
model = _wh.load_model(model_name)
use_fw = False
log_lines.append(f"Worker: whisper loaded {model_name}")
except Exception as e:
log_lines.append(f"Worker model load failed: {e}")
try:
if USE_FASTER_WHISPER:
model = FasterWhisperModel("small", device=device if device else "cpu")
use_fw = True
log_lines.append("Worker: fallback to faster-whisper small")
else:
model = whisper.load_model("small")
use_fw = False
log_lines.append("Worker: fallback whisper small")
except Exception as e2:
return {"file": base, "text_path": None, "srt_path": None, "log": "Model load failed: " + str(e2)}
try:
wav = convert_to_wav_if_needed(file_path)
log_lines.append(f"Converted to WAV: {os.path.basename(wav)}")
except Exception as e:
return {"file": base, "text_path": None, "srt_path": None, "log": "Conversion failed: " + str(e)}
try:
if use_fw:
segments, info = model.transcribe(wav, beam_size=5)
# faster-whisper segments objects differ; build text
text = "".join([getattr(seg, "text", "") for seg in segments]).strip()
srt_out = None
if generate_srt:
srt_lines = []
for idx, seg in enumerate(segments, start=1):
start = getattr(seg, "start", 0)
end = getattr(seg, "end", 0)
txt = getattr(seg, "text", "").strip()
srt_lines.append(str(idx))
srt_lines.append(f"{_fmt_time(start)} --> {_fmt_time(end)}")
srt_lines.append(txt)
srt_lines.append("")
srt_out = "\n".join(srt_lines)
else:
result = model.transcribe(wav)
text = result.get("text", "").strip()
srt_out = _segments_to_srt(result.get("segments")) if generate_srt and result.get("segments") else None
except Exception as e:
return {"file": base, "text_path": None, "srt_path": None, "log": "Transcription failed: " + str(e)}
if enable_memory and text:
text = memory_correct_text(text)
text = postprocess_transcript(text)
txt_tmp = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
txt_tmp.close()
register_temp_path(txt_tmp.name)
with open(txt_tmp.name, "w", encoding="utf-8") as fh:
fh.write(text)
srt_path = None
if generate_srt and srt_out:
srt_tmp = tempfile.NamedTemporaryFile(suffix=".srt", delete=False)
srt_tmp.close()
register_temp_path(srt_tmp.name)
with open(srt_tmp.name, "w", encoding="utf-8") as fh:
fh.write(srt_out)
srt_path = srt_tmp.name
try:
if wav and os.path.exists(wav) and not file_path.lower().endswith(".wav"):
os.unlink(wav)
except Exception:
pass
return {"file": base, "text_path": txt_tmp.name, "srt_path": srt_path, "log": "\n".join(log_lines)}
except Exception as e:
tb = traceback.format_exc()
return {"file": os.path.basename(file_path) if file_path else "unknown", "text_path": None, "srt_path": None, "log": f"Worker exception: {e}\n{tb}"}
# ---------- ZIP extraction & mapping ----------
def extract_zip_and_map(zip_path, zip_password=None):
    """
    Extract supported audio files from a ZIP into a per-run temp dir.

    Populates EXTRACT_MAP (friendly name -> extracted file path), sets
    LAST_EXTRACT_DIR to the extraction folder and LAST_EXTRACT_LIST to the
    sorted friendly names. All three globals are reset at the start of
    every call. Members whose base name repeats get a " (N)" suffix that is
    guaranteed not to collide with an existing friendly name.

    Args:
        zip_path: Path of the ZIP archive.
        zip_password: Optional password (str); ignored when falsy.

    Returns (friendly_list, logs_str). On failure or when nothing usable is
    found, friendly_list is empty and logs_str explains why.
    """
    global EXTRACT_MAP, LAST_EXTRACT_DIR, LAST_EXTRACT_LIST
    EXTRACT_MAP = {}
    LAST_EXTRACT_DIR = None
    LAST_EXTRACT_LIST = []
    run_id = uuid4().hex
    temp_extract_dir = os.path.join(tempfile.gettempdir(), f"extracted_audio_{run_id}")
    logs = []
    supported = {".mp3", ".wav", ".aac", ".flac", ".ogg", ".m4a", ".dat", ".dct"}
    try:
        os.makedirs(temp_extract_dir, exist_ok=True)
        with pyzipper.ZipFile(zip_path, "r") as zf:
            if zip_password:
                try:
                    zf.setpassword(zip_password.encode())
                except Exception:
                    logs.append("Warning: failed to set zip password (continuing).")
            count = {}
            for info in zf.infolist():
                if info.is_dir():
                    continue
                ext = os.path.splitext(info.filename)[1].lower()
                if ext not in supported:
                    continue
                try:
                    extracted = zf.extract(info, path=temp_extract_dir)
                except RuntimeError as e:
                    logs.append(f"Password required or incorrect for {info.filename}: {e}")
                    continue
                except Exception as e:
                    logs.append(f"Error extracting {info.filename}: {e}")
                    continue
                # Use the path extract() actually wrote. Member names are
                # sanitized during extraction, so re-joining the raw name
                # could point somewhere else (even outside the temp dir).
                fullp = os.path.normpath(extracted)
                if not os.path.exists(fullp):
                    continue
                base = os.path.basename(info.filename)
                key = base
                if key in EXTRACT_MAP:
                    name_only, extn = os.path.splitext(base)
                    idx = count.get(base, 1)
                    # Keep bumping the suffix until the friendly name is free
                    # so an earlier entry is never overwritten.
                    while key in EXTRACT_MAP:
                        idx += 1
                        key = f"{name_only} ({idx}){extn}"
                    count[base] = idx
                else:
                    count[base] = 1
                EXTRACT_MAP[key] = fullp
                logs.append(f"Extracted: {info.filename} -> {key}")
        if not EXTRACT_MAP:
            logs.append("No supported audio files found in ZIP.")
            # cleanup temp dir if empty
            try:
                if os.path.exists(temp_extract_dir) and not os.listdir(temp_extract_dir):
                    shutil.rmtree(temp_extract_dir)
            except Exception:
                pass
            return [], "\n".join(logs)
        friendly = sorted(EXTRACT_MAP.keys())
        LAST_EXTRACT_DIR = temp_extract_dir
        LAST_EXTRACT_LIST = friendly[:]
        register_temp_path(temp_extract_dir)
        return friendly, "\n".join(logs)
    except Exception as e:
        traceback.print_exc()
        try:
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)
        except Exception:
            pass
        return [], f"Extraction failed: {e}"
def download_extracted_folder():
    """
    Bundle everything under LAST_EXTRACT_DIR into a temp ZIP for download.

    Returns (zip_path, "OK") on success, or (None, message) when there is no
    extraction folder or the archive cannot be created. Files keep their
    path relative to the extraction folder inside the archive.
    """
    if not (LAST_EXTRACT_DIR and os.path.exists(LAST_EXTRACT_DIR)):
        return None, "No extracted folder available for download."
    try:
        holder = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
        holder.close()
        archive_path = holder.name
        register_temp_path(archive_path)
        with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            for folder, _subdirs, names in os.walk(LAST_EXTRACT_DIR):
                for name in names:
                    absolute = os.path.join(folder, name)
                    archive.write(absolute, arcname=os.path.relpath(absolute, LAST_EXTRACT_DIR))
        return archive_path, "OK"
    except Exception as e:
        return None, f"Failed to create ZIP: {e}"
def download_selected_extracted_files(selected_keys):
    """
    Build a temp ZIP holding only the extracted files named in *selected_keys*.

    Keys are looked up in EXTRACT_MAP; unknown keys and files that no longer
    exist are skipped. Each file is stored under its friendly name, falling
    back to the file's base name if writing with the friendly name fails.

    Returns (zip_path, "OK") or (None, message).
    """
    if not selected_keys:
        return None, "No files selected."
    lookups = ((name, EXTRACT_MAP.get(name)) for name in selected_keys)
    entries = [(name, path) for name, path in lookups if path and os.path.exists(path)]
    if not entries:
        return None, "No valid selected files found."
    holder = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    holder.close()
    archive_path = holder.name
    register_temp_path(archive_path)
    try:
        with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
            for name, path in entries:
                try:
                    archive.write(path, arcname=name)
                except Exception:
                    archive.write(path, arcname=os.path.basename(path))
        return archive_path, "OK"
    except Exception as e:
        return None, f"Failed to create selected ZIP: {e}"
# ---------- Merge uploaded text files into single Word file ----------
def merge_text_files_to_docx(uploaded_text_files):
    """
    Merge uploaded text files, in order, into one .docx.

    Accepts a single path, a dict with a "name" entry, or a list/tuple whose
    items are paths, such dicts, or objects with a ``name`` attribute (other
    items are dropped). Each existing file is read as UTF-8, falling back to
    latin-1 with replacement, and prefixed by a "--- <basename> ---" header.

    Returns (docx_path, "Merged") or (None, message).
    """
    if not uploaded_text_files:
        return None, "No files provided."

    def _item_to_path(item):
        # Returns None for items that carry no usable path.
        if isinstance(item, (str, os.PathLike)):
            return str(item)
        if isinstance(item, dict) and item.get("name"):
            return item["name"]
        if hasattr(item, "name"):
            return item.name
        return None

    if isinstance(uploaded_text_files, (str, os.PathLike)):
        uploaded_text_files = [str(uploaded_text_files)]
    elif isinstance(uploaded_text_files, dict) and uploaded_text_files.get("name"):
        uploaded_text_files = [uploaded_text_files["name"]]
    elif isinstance(uploaded_text_files, (list, tuple)):
        candidates = (_item_to_path(entry) for entry in uploaded_text_files)
        uploaded_text_files = [path for path in candidates if path is not None]

    sections = []
    for path in uploaded_text_files:
        if not os.path.exists(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as source:
                body = source.read()
        except Exception:
            with open(path, "r", encoding="latin-1", errors="replace") as source:
                body = source.read()
        sections.append(f"\n\n--- {os.path.basename(path)} ---\n\n" + body)

    if not sections:
        return None, "No readable text files."
    docx_path = save_as_word("\n".join(sections))
    return docx_path, "Merged"
# ---------- NEW: merge last batch transcripts ----------
def merge_last_batch_transcripts():
    """
    Combine the .txt transcripts recorded in LAST_BATCH_TRANSCRIPTS into one .docx.

    Entries without an existing text file are skipped. Each file is read as
    UTF-8 (then latin-1) with replacement; if both reads fail its body is
    empty. Every section starts with a "--- <friendly name> ---" header.

    Returns (path_or_None, message)
    """
    global LAST_BATCH_TRANSCRIPTS
    if not LAST_BATCH_TRANSCRIPTS:
        return None, "No last-batch transcripts available."
    sections = []
    for label, text_path, _srt_path in LAST_BATCH_TRANSCRIPTS:
        if not (text_path and os.path.exists(text_path)):
            continue
        body = ""
        for codec in ("utf-8", "latin-1"):
            try:
                with open(text_path, "r", encoding=codec, errors="replace") as source:
                    body = source.read()
                break
            except Exception:
                continue
        sections.append(f"\n\n--- {label} ---\n\n" + body)
    if not sections:
        return None, "No readable last-batch transcript files found."
    docx_path = save_as_word("\n".join(sections))
    return docx_path, f"Merged {len(sections)} files."
# ---------- Batch transcription generator (streaming) ----------
def batch_transcribe_parallel_generator(
    friendly_selected,
    uploaded_files,
    model_name,
    device_name,
    merge_flag,
    enable_mem,
    generate_srt,
    use_two_pass=False,
    fast_model="small",
    refine_threshold=-1.0,
    zip_password=None,
    auto_merge_per_file=True,
):
    """Transcribe many files in parallel, streaming progress as a generator.

    Args:
        friendly_selected: Friendly names (keys of EXTRACT_MAP) to include.
        uploaded_files: One upload or a list/tuple of uploads; each may be a
            path, a dict with a "name" entry, or an object with ``.name``.
        model_name, device_name, enable_mem, generate_srt, use_two_pass,
            fast_model, refine_threshold: forwarded to _worker_transcribe.
        merge_flag, auto_merge_per_file: when either is truthy a merged .docx
            of all transcripts is written and its path logged.
        zip_password: accepted for interface compatibility; not used here.

    Yields:
        ``(logs, transcripts, file_path_or_None, percent)`` tuples. The final
        tuple carries the per-file transcripts ZIP when at least one file was
        transcribed, otherwise the merged .docx (or None).

    Side effects:
        Resets LAST_BATCH_TRANSCRIPTS at the start and fills it with
        ``(friendly_name, txt_path, srt_path)`` tuples once the batch ends.
    """
    global LAST_BATCH_TRANSCRIPTS
    LAST_BATCH_TRANSCRIPTS = []  # reset at start
    logs = []
    transcripts = []
    per_file_paths = []

    def _upload_path(item):
        # Same upload shapes the merge/import helpers accept; str() of a file
        # object would not be a usable path.
        if isinstance(item, (str, os.PathLike)):
            return str(item)
        if isinstance(item, dict) and item.get("name"):
            return item["name"]
        name = getattr(item, "name", None)
        return name if name else str(item)

    def _read_text(path):
        try:
            with open(path, "r", encoding="utf-8") as fh:
                return fh.read()
        except Exception:
            with open(path, "r", encoding="latin-1", errors="replace") as fh:
                return fh.read()

    try:
        paths = []
        # gather selected extracted paths
        if friendly_selected:
            for key in friendly_selected:
                p = EXTRACT_MAP.get(key)
                if p:
                    paths.append(p)
                else:
                    logs.append(f"Warning: selected not found in extract map: {key}")
        # uploaded files
        if uploaded_files:
            if isinstance(uploaded_files, (list, tuple)):
                paths.extend(_upload_path(f) for f in uploaded_files)
            else:
                paths.append(_upload_path(uploaded_files))
        if not paths:
            logs.append("No files selected or uploaded.")
            yield "\n\n".join(logs), "", None, 100
            return

        total = len(paths)
        logs.append(f"Starting batch of {total} files with up to {MAX_WORKERS} workers.")
        yield "\n\n".join(logs), "", None, 2

        tasks = [
            (p, model_name, device_name, enable_mem, generate_srt, use_two_pass, fast_model, refine_threshold)
            for p in paths
        ]
        completed = 0
        with ProcessPoolExecutor(max_workers=min(MAX_WORKERS, total)) as exe:
            futs = {exe.submit(_worker_transcribe, t): t for t in tasks}
            for fut in as_completed(futs):
                completed += 1
                try:
                    res = fut.result()
                except Exception as e:
                    # A crashed worker must not abort the remaining files.
                    failed_name = os.path.basename(str(futs[fut][0]))
                    logs.append(f"[{completed}/{total}] {failed_name}: Worker crashed: {e}")
                    res = None
                if res is not None:
                    fname = res.get("file")
                    res_log = res.get("log", "")
                    logs.append(f"[{completed}/{total}] {fname}: {res_log}")
                    txtp = res.get("text_path")
                    srtp = res.get("srt_path")
                    if txtp:
                        txt_content = _read_text(txtp)
                        transcripts.append(f"FILE: {fname}\n{txt_content}\n")
                        per_file_paths.append((fname, txtp, srtp))
                pct = int(5 + (completed / total) * 90)
                yield "\n\n".join(logs), "\n\n".join(transcripts), None, pct

        # Save per-file transcript list into global for later merging/downloading
        LAST_BATCH_TRANSCRIPTS = per_file_paths[:]
        combined = "\n\n".join(transcripts)
        out_doc = None
        if merge_flag or auto_merge_per_file:
            try:
                out_doc = save_as_word(combined)
                logs.append(f"Merged saved: {out_doc}")
            except Exception as e:
                logs.append(f"Merge failed: {e}")

        # Create ZIP of per-file transcripts (not original audio)
        if per_file_paths:
            zip_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
            zip_tmp.close()
            register_temp_path(zip_tmp.name)
            with zipfile.ZipFile(zip_tmp.name, "w", compression=zipfile.ZIP_DEFLATED) as zf:
                for fname, txtp, srtp in per_file_paths:
                    try:
                        zf.write(txtp, arcname=f"{fname}.txt")
                    except Exception:
                        zf.write(txtp, arcname=os.path.basename(txtp))
                    if srtp and os.path.exists(srtp):
                        try:
                            zf.write(srtp, arcname=f"{fname}.srt")
                        except Exception:
                            zf.write(srtp, arcname=os.path.basename(srtp))
            logs.append(f"Per-file transcripts ZIP: {zip_tmp.name}")
            yield "\n\n".join(logs), combined, zip_tmp.name, 100
        else:
            yield "\n\n".join(logs), combined, out_doc, 100
    except Exception as e:
        tb = traceback.format_exc()
        logs.append(f"Batch error: {e}\n{tb}")
        yield "\n\n".join(logs), "\n\n".join(transcripts), None, 100
# ---------- Memory import helpers ----------
def _read_file_text_try_encodings(path):
encodings = ["utf-8", "utf-16", "latin-1"]
for enc in encodings:
try:
with open(path, "r", encoding=enc) as fh:
return fh.read(), enc
except UnicodeDecodeError:
continue
except Exception:
break
try:
with open(path, "rb") as fh:
raw = fh.read()
try:
text = raw.decode("utf-8")
return text, "utf-8(guessed)"
except Exception:
text = raw.decode("latin-1", errors="replace")
return text, "latin-1(replaced)"
except Exception:
return None, None
def _process_single_memory_text(text):
    """Merge one imported document into the correction memory.

    If *text* parses as a JSON object, its "words" and "phrases" mappings
    are added to the memory (word keys lower-cased; counts that cannot be
    converted to int count as 1). Otherwise, or if that fails part-way,
    the text is handled line by line: "key,count" lines and lines of at
    most three words go to "words", longer lines go to "phrases".

    Returns:
        Number of entries processed (not the sum of their counts).
    """

    def _as_count(raw):
        try:
            return int(raw)
        except Exception:
            return 1

    added = 0
    try:
        parsed = json.loads(text)
        if isinstance(parsed, dict):
            imported_words = parsed.get("words", {})
            imported_phrases = parsed.get("phrases", {})
            with MEMORY_LOCK:
                for key, value in imported_words.items():
                    amount = _as_count(value)
                    memory["words"][key.lower()] = memory["words"].get(key.lower(), 0) + amount
                    added += 1
                for key, value in imported_phrases.items():
                    amount = _as_count(value)
                    memory["phrases"][key] = memory["phrases"].get(key, 0) + amount
                    added += 1
            return added
    except Exception:
        # Not JSON (or malformed content): fall back to plain-text handling.
        pass

    entries = [row.strip() for row in text.splitlines() if row.strip()]
    with MEMORY_LOCK:
        for entry in entries:
            if "," in entry:
                fields = [piece.strip() for piece in entry.split(",", 1)]
                key = fields[0]
                amount = _as_count(fields[1])
                memory["words"][key.lower()] = memory["words"].get(key.lower(), 0) + amount
            elif len(entry.split()) <= 3:
                memory["words"][entry.lower()] = memory["words"].get(entry.lower(), 0) + 1
            else:
                memory["phrases"][entry] = memory["phrases"].get(entry, 0) + 1
            added += 1
    return added
def preview_zip_members_for_memory(zip_path):
    """List the file members of a ZIP archive.

    Returns:
        (members, log): members are the names of all non-directory entries;
        log reports how many were found, that none were found, or the error
        that prevented reading the archive.
    """
    names = []
    notes = []
    try:
        with zipfile.ZipFile(zip_path, "r") as archive:
            names.extend(entry.filename for entry in archive.infolist() if not entry.is_dir())
        notes.append(f"Found {len(names)} members." if names else "No members found in ZIP.")
    except Exception as e:
        notes.append(f"ZIP preview failed: {e}")
    return names, "\n".join(notes)
def import_memory_files_multiple(uploaded_files, zip_members_to_import=None):
    """Import words/phrases from uploaded files (and ZIP members) into memory.

    Args:
        uploaded_files: A path, a dict with a "name" entry, or a list/tuple
            of paths, such dicts, or objects with a ``name`` attribute.
        zip_members_to_import: Optional collection of member names; when
            truthy only those members of uploaded ZIP files are imported.

    Returns:
        A multi-line summary: total entries added, per-file details and any
        skipped/failed items. The memory is saved to disk once at the end.
    """
    if not uploaded_files:
        return "No files provided."

    def _decode_member(raw):
        # UTF-8 first. UTF-16 is only attempted when a byte-order mark is
        # present: without one almost any even-length byte string "decodes"
        # as UTF-16 and yields garbage instead of falling back to latin-1.
        try:
            return raw.decode("utf-8")
        except UnicodeDecodeError:
            pass
        if raw[:2] in (b"\xff\xfe", b"\xfe\xff"):
            try:
                return raw.decode("utf-16")
            except UnicodeDecodeError:
                pass
        return raw.decode("latin-1", errors="replace")

    if isinstance(uploaded_files, (str, os.PathLike)):
        uploaded_files = [str(uploaded_files)]
    elif isinstance(uploaded_files, dict) and uploaded_files.get("name"):
        uploaded_files = [uploaded_files["name"]]
    elif isinstance(uploaded_files, (list, tuple)):
        normalized = []
        for f in uploaded_files:
            if isinstance(f, (str, os.PathLike)):
                normalized.append(str(f))
            elif isinstance(f, dict) and f.get("name"):
                normalized.append(f["name"])
            elif hasattr(f, "name"):
                normalized.append(f.name)
        uploaded_files = normalized

    total_added = 0
    messages = []
    skipped = []
    for fp in uploaded_files:
        try:
            if not os.path.exists(fp):
                messages.append(f"Missing: {fp}")
                continue
            if fp.lower().endswith(".zip"):
                try:
                    with zipfile.ZipFile(fp, "r") as zf:
                        for info in zf.infolist():
                            if info.is_dir():
                                continue
                            name = info.filename
                            if zip_members_to_import and name not in zip_members_to_import:
                                continue
                            try:
                                with zf.open(info) as member:
                                    raw = member.read()
                                added = _process_single_memory_text(_decode_member(raw))
                                total_added += added
                                messages.append(f"Imported {added} from {name} in {os.path.basename(fp)}")
                            except Exception as e:
                                skipped.append(f"{name}: {e}")
                except zipfile.BadZipFile:
                    skipped.append(f"Bad zip: {fp}")
                # ZIP archives are never read as plain text themselves.
                continue
            text, used_enc = _read_file_text_try_encodings(fp)
            if text is None:
                skipped.append(fp)
                continue
            added = _process_single_memory_text(text)
            total_added += added
            messages.append(f"Imported {added} from {os.path.basename(fp)} (enc={used_enc})")
        except Exception as e:
            skipped.append(f"{fp}: {e}")

    save_memory(memory)
    summary = [f"Total entries added: {total_added}"]
    if messages:
        summary.append("Details:")
        summary.extend(messages)
    if skipped:
        summary.append("Skipped/failed:")
        summary.extend(skipped)
    return "\n".join(summary)
# ---------- Build Gradio UI ----------
print("DEBUG: building Gradio UI", flush=True)
# Options and default value for the model dropdowns, filtered against
# AVAILABLE_MODEL_SET by safe_model_choices().
available_choices, default_choice = safe_model_choices(prefer_default="small")
# CSS tweaks: small buttons and nicer layout.
# The :root block defines the light palette; the [data-theme="dark"] block
# overrides it when that attribute is set on the document element (done by
# the script injected right after gr.Blocks is opened).
CSS = """
:root{
--accent:#4f46e5;
--muted:#6b7280;
--card:#ffffff;
--bg:#f7f8fb;
--text:#0f172a;
--transcript-bg:#0f172a;
--transcript-color:#e6eef8;
}
[data-theme="dark"] {
--accent: #7c3aed;
--muted: #9ca3af;
--card: #0b1220;
--bg: #071022;
--text: #e6eef8;
--transcript-bg: #071026;
--transcript-color: #e6eef8;
}
body { background: var(--bg); color: var(--text); font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, "Helvetica Neue", Arial; }
.header { padding: 14px; border-radius: 10px; background: linear-gradient(90deg, rgba(79,70,229,0.08), rgba(99,102,241,0.02)); margin-bottom: 12px; display:flex;align-items:center;gap:12px; }
.app-icon { width:50px;height:50px;border-radius:10px;background:linear-gradient(135deg,var(--accent),#06b6d4);display:flex;align-items:center;justify-content:center;color:white;font-weight:700;font-size:20px; }
.card { background:var(--card); border-radius:10px; padding:12px; box-shadow: 0 6px 20px rgba(16,24,40,0.04); }
.transcript-area { white-space:pre-wrap; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, "Roboto Mono", monospace; background: var(--transcript-bg); color: var(--transcript-color); padding:12px; border-radius:8px; min-height:200px; }
.small-note { color:var(--muted); font-size:12px;}
.btn-row { display:flex; gap:8px; margin-top:8px; }
.gr-button.small { padding:6px 8px !important; font-size:12px !important; }
"""
with gr.Blocks(title="Whisper Transcriber — Parallel + Memory", css=CSS) as demo:
# set dark theme by default via injected JS
gr.HTML("""
<script>
(function() {
try {
const saved = localStorage.getItem('wt_theme');
if (saved) {
document.documentElement.setAttribute('data-theme', saved);
} else {
document.documentElement.setAttribute('data-theme', 'dark');
}
} catch (e) { console.warn('theme init failed', e); }
})();
</script>
""")
gr.Markdown("<h3>Whisper Transcriber — Parallel + Memory</h3>")
gr.Markdown("<div class='small-note'>Parallel batch transcription, memory correction, per-file transcript downloads. Use faster-whisper if available for faster CPU performance.</div>")
# Advanced toggle (hidden by default)
adv_toggle = gr.Checkbox(label="Advanced ▾", value=False)
# We'll put advanced controls inside this column and toggle visibility
with gr.Tabs():
# Single file tab
with gr.TabItem("Single File"):
with gr.Row():
with gr.Column(scale=1):
single_audio = gr.Audio(label="Upload audio", type="filepath")
model_sel_single = gr.Dropdown(choices=available_choices, value=default_choice, label="Model")
device_single = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
mem_single = gr.Checkbox(label="Use memory corrections", value=False)
srt_single = gr.Checkbox(label="Generate SRT", value=False)
trans_single_btn = gr.Button("Transcribe", elem_classes="small")
with gr.Column(scale=1):
single_trans_out = gr.Textbox(label="Transcript", lines=14, interactive=False)
# LOGS at bottom
single_logs = gr.Textbox(label="Logs", lines=6, interactive=False)
def _do_single(audio, model_name, device_name, mem_on, srt_on):
    """Transcribe one uploaded audio file and return ``(transcript, logs)``.

    Args:
        audio: Uploaded audio as a path string, or an object exposing ``.name``.
        model_name: Model choice forwarded to the worker.
        device_name: Device choice forwarded to the worker.
        mem_on: Whether memory corrections are requested.
        srt_on: Whether an SRT file is requested.

    Returns:
        A ``(transcript_text, log_text)`` pair. The transcript is empty when no
        audio was supplied, when the worker reported no ``text_path``, or when the
        transcript file could not be read; in the last case the read error is
        appended to the logs instead of being dropped silently.
    """
    if not audio:
        return "", "No audio supplied."

    # The upload may arrive as a plain path or as a file-like wrapper.
    if isinstance(audio, str):
        path = audio
    elif hasattr(audio, "name"):
        path = audio.name
    else:
        path = str(audio)

    # Positional worker payload. The trailing (False, "small", -1.0) presumably maps to
    # (use_two_pass, fast_model, refine_threshold) — confirm against _worker_transcribe.
    res = _worker_transcribe((path, model_name, device_name, mem_on, srt_on, False, "small", -1.0))
    logs = res.get("log", "")
    text_path = res.get("text_path")
    if not text_path:
        return "", logs

    try:
        with open(text_path, "r", encoding="utf-8", errors="replace") as fh:
            content = fh.read()
    except Exception as exc:
        # Best effort: keep the UI responsive, but tell the user why the transcript is empty.
        content = ""
        note = f"Could not read transcript file {text_path}: {exc}"
        logs = f"{logs}\n{note}" if logs else note
    return content, logs
# Single-file wiring: _do_single returns (transcript, logs) for the two outputs below
trans_single_btn.click(fn=_do_single, inputs=[single_audio, model_sel_single, device_single, mem_single, srt_single], outputs=[single_trans_out, single_logs])
# Batch tab
with gr.TabItem("Batch Transcribe"):
with gr.Row():
# Left column: inputs, ZIP extraction helpers, merge tools and run parameters
with gr.Column(scale=1):
# NOTE(review): type="filepath" on gr.File is the Gradio 4 spelling; Gradio 3.x used
# type="file" — confirm against the installed Gradio version.
batch_files = gr.File(label="Upload audio files", file_count="multiple", type="filepath")
batch_zip = gr.File(label="Or upload ZIP (optional)", file_count="single", type="filepath")
# NOTE(review): the password box is pre-filled with the hard-coded DEFAULT_ZIP_PASS
batch_zip_pass = gr.Textbox(label="ZIP password (if any)", value=DEFAULT_ZIP_PASS)
# Extract and populate list
batch_preview_btn = gr.Button("Extract & List ZIP files", elem_classes="small")
batch_preview_out = gr.Textbox(label="ZIP members (preview)", lines=6, interactive=False)
# Starts empty; its choices are filled by the "Extract & List ZIP files" callback
batch_select = gr.CheckboxGroup(choices=[], label="Select extracted files to include", interactive=True)
# select-all / clear buttons (small)
with gr.Row(elem_classes="btn-row"):
batch_select_all_btn = gr.Button("Select All", elem_classes="small")
batch_clear_select_btn = gr.Button("Clear", elem_classes="small")
batch_download_extracted_btn = gr.Button("Download Extracted (all)", elem_classes="small")
batch_download_selected_btn = gr.Button("Download Selected", elem_classes="small")
# Shared download slot for both "Download Extracted (all)" and "Download Selected"
batch_extracted_zip = gr.File(label="Downloaded extracted ZIP")
gr.Markdown("### Merge text files")
merge_text_files_input = gr.File(label="Upload text files to merge (.txt/.srt/.json)", file_count="multiple", type="filepath")
merge_text_btn = gr.Button("Merge uploaded text files -> DOCX", elem_classes="small")
merge_text_out = gr.File(label="Merged DOCX download")
# NEW: Merge last batch transcripts
merge_last_batch_btn = gr.Button("Merge Last Batch Transcripts", elem_classes="small")
merge_last_batch_status = gr.Textbox(label="Last-batch merge status", lines=2, interactive=False)
merge_last_batch_download = gr.File(label="Merged last-batch DOCX")
# Transcription parameters (basic)
batch_model = gr.Dropdown(choices=available_choices, value=default_choice, label="Model")
batch_mem = gr.Checkbox(label="Enable memory corrections", value=False)
batch_srt = gr.Checkbox(label="Generate SRTs", value=False)
auto_merge_per_file = gr.Checkbox(label="Auto-merge per-file transcripts", value=True)
# Advanced controls hidden by default; visibility follows the "Advanced" checkbox
advanced_col = gr.Column(visible=False)
with advanced_col:
batch_device = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
batch_use_two_pass = gr.Checkbox(label="Use two-pass refinement", value=False)
# NOTE(review): value="small" is used even if "small" was filtered out of the choices
batch_fast_model = gr.Dropdown(choices=[c for c in ["tiny", "base", "small"] if c in AVAILABLE_MODEL_SET], value="small", label="Fast model")
batch_refine_thresh = gr.Number(value=-1.0, label="Refine threshold", precision=2)
batch_merge = gr.Checkbox(label="Merge transcripts into DOCX after run", value=True)
# Start button
batch_run_btn = gr.Button("Start Batch (parallel)", elem_classes="small")
# Right column: combined transcript, progress indicator and download slots
with gr.Column(scale=1):
batch_combined_out = gr.Textbox(label="Combined transcripts", lines=12, interactive=False)
# Read-only slider used as a 0-100 progress display
batch_progress = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Progress (%)", interactive=False)
batch_zip_download = gr.File(label="Download per-file transcripts ZIP")
# NOTE(review): batch_doc_download is declared here but is not among the outputs of
# batch_run_btn.click in this file, so nothing visible ever fills it.
batch_doc_download = gr.File(label="Download merged DOCX (if created)")
# Logs at bottom
batch_logs_out = gr.Textbox(label="Logs", lines=8, interactive=False)
def _preview_zip_and_populate(zip_file, password):
    """
    Extract the uploaded ZIP through extract_zip_and_map and refresh the file selector.

    Returns a pair for (batch_select, batch_preview_out): a CheckboxGroup update whose
    choices are the extracted friendly names together with those names one per line, or
    an empty choice list with either a "no ZIP" message or the extraction log.
    """
    if not zip_file:
        return gr.update(choices=[]), "No ZIP provided."
    # The upload may be a file-like object exposing .name or a plain path value
    zip_path = zip_file.name if hasattr(zip_file, "name") else str(zip_file)
    names, extract_log = extract_zip_and_map(zip_path, password)
    if not names:
        return gr.update(choices=[]), extract_log
    return gr.update(choices=names), "\n".join(names)
batch_preview_btn.click(
    fn=_preview_zip_and_populate,
    inputs=[batch_zip, batch_zip_pass],
    outputs=[batch_select, batch_preview_out],
)
def _select_all_batch():
    """Tick every name in the module-level LAST_EXTRACT_LIST (nothing when it is empty)."""
    # LAST_EXTRACT_LIST holds the friendly names of the last extraction
    names = LAST_EXTRACT_LIST or []
    return gr.update(value=names)
batch_select_all_btn.click(
    fn=_select_all_batch,
    inputs=[],
    outputs=[batch_select],
)
def _clear_batch_select():
    """Return an update that leaves the target CheckboxGroup with nothing selected."""
    nothing_selected = []
    return gr.update(value=nothing_selected)
batch_clear_select_btn.click(
    fn=_clear_batch_select,
    inputs=[],
    outputs=[batch_select],
)
def _download_extracted_wrapper():
    """Return the archive path from download_extracted_folder, or None when there is none.

    The status message returned alongside the path is discarded because the button
    is wired to a single File output.
    """
    archive, _status = download_extracted_folder()
    return archive if archive else None
batch_download_extracted_btn.click(
    fn=_download_extracted_wrapper,
    inputs=[],
    outputs=[batch_extracted_zip],
)
def _download_selected_wrapper(selected):
    """Return the archive path for the ticked extracted files, or None when there is none.

    `selected` is the current value of the batch file selector; the status message
    returned by download_selected_extracted_files is discarded.
    """
    archive, _status = download_selected_extracted_files(selected)
    return archive if archive else None
batch_download_selected_btn.click(
    fn=_download_selected_wrapper,
    inputs=[batch_select],
    outputs=[batch_extracted_zip],
)
def _merge_texts(uploaded_texts):
    """Merge the uploaded text files into one DOCX and return its path.

    The callback feeds a single File output, so it always returns exactly one
    value: the DOCX path on success, or None when nothing was uploaded or the
    merge produced no file. (Previously the empty/failure paths returned a
    ``(None, message)`` tuple, which does not match the single output.) The
    explanatory message is printed to the server log instead.
    """
    if not uploaded_texts:
        print("merge texts: No files provided.")
        return None
    out_path, msg = merge_text_files_to_docx(uploaded_texts)
    if not out_path:
        print(f"merge texts: {msg}")
        return None
    return out_path
# merge_text_out is the only output wired here, so the callback has to hand back exactly one
# value per call (a DOCX path, or None when there is nothing to download).
merge_text_btn.click(fn=_merge_texts, inputs=[merge_text_files_input], outputs=[merge_text_out])
def _merge_last_batch_action():
    """
    Run merge_last_batch_transcripts and return (docx_path_or_None, status_message)
    for the download slot and the status box.
    """
    docx_path, status = merge_last_batch_transcripts()
    download = docx_path if docx_path else None
    return download, status
merge_last_batch_btn.click(
    fn=_merge_last_batch_action,
    inputs=[],
    outputs=[merge_last_batch_download, merge_last_batch_status],
)
# Reveal or hide the advanced panel whenever the "Advanced" checkbox changes
def _toggle_advanced(show):
    """Return a visibility update for the advanced column matching the checkbox state."""
    is_visible = True if show else False
    return gr.update(visible=is_visible)
adv_toggle.change(
    fn=_toggle_advanced,
    inputs=[adv_toggle],
    outputs=[advanced_col],
)
# Streaming wrapper: the callback itself must be a generator so Gradio can stream its yields
def _start_batch(friendly_selected, uploaded_files, zip_file, zip_pass, model_name, mem_flag, srt_flag, auto_merge_flag, device_name=None, two_pass=False, fast_model="small", refine_thresh=-1.0, merge_flag=True):
    """
    Forward the batch form values to batch_transcribe_parallel_generator and re-yield
    everything it produces.

    A single upload given as a dict with a "name" entry is wrapped into a one-element
    list; a missing device falls back to "auto".
    NOTE(review): zip_file is accepted (it is wired as an input) but is not forwarded.
    """
    single_upload = isinstance(uploaded_files, dict) and uploaded_files.get("name")
    uploads = [uploaded_files["name"]] if single_upload else uploaded_files
    device = "auto" if device_name is None else device_name
    yield from batch_transcribe_parallel_generator(
        friendly_selected,
        uploads,
        model_name,
        device,
        merge_flag,
        mem_flag,
        srt_flag,
        use_two_pass=two_pass,
        fast_model=fast_model,
        refine_threshold=refine_thresh,
        zip_password=zip_pass,
        auto_merge_per_file=auto_merge_flag,
    )
# Basic and advanced inputs are always passed together, whether or not the advanced panel is shown
batch_run_btn.click(
    fn=_start_batch,
    inputs=[
        batch_select, batch_files, batch_zip, batch_zip_pass, batch_model, batch_mem, batch_srt, auto_merge_per_file,
        batch_device, batch_use_two_pass, batch_fast_model, batch_refine_thresh, batch_merge,
    ],
    outputs=[batch_logs_out, batch_combined_out, batch_zip_download, batch_progress],
)
# Memory tab
with gr.TabItem("Memory"):
with gr.Row():
# Left column: upload/import controls and single-entry memory editing
with gr.Column(scale=1):
mem_upload = gr.File(label="Upload memory files or ZIP (multiple)", file_count="multiple", type="filepath")
mem_preview_zip_btn = gr.Button("Preview ZIP members (for selected ZIPs)", elem_classes="small")
mem_zip_preview_out = gr.Textbox(label="ZIP members (preview)", lines=4, interactive=False)
# Starts with no choices; its value is passed to the import callback as the member selection
mem_zip_select = gr.CheckboxGroup(choices=[], label="Select ZIP members to import", interactive=True)
mem_select_all_btn = gr.Button("Select All members", elem_classes="small")
mem_clear_select_btn = gr.Button("Clear selection", elem_classes="small")
mem_import_btn = gr.Button("Import selected files / uploaded files", elem_classes="small")
# Shared status box: import, add, clear and view all write their result here
mem_status = gr.Textbox(label="Import status", lines=6, interactive=False)
mem_textbox = gr.Textbox(label="Add single word/phrase", placeholder="Type word or phrase")
mem_add_btn = gr.Button("Add to memory", elem_classes="small")
mem_clear_btn = gr.Button("Clear memory", elem_classes="small")
mem_view_btn = gr.Button("View memory", elem_classes="small")
# Right column: static usage notes
with gr.Column(scale=1):
mem_help = gr.Markdown(
"- Upload multiple text/JSON files or ZIPs. Preview ZIP members and choose which members to import.\n"
"- Supported encodings: utf-8, utf-16, latin-1, fallback.\n"
"- JSON format: {\"words\":{\"word\":count}, \"phrases\":{\"phrase\":count}}"
)
# Logs at bottom
# NOTE(review): no callback in the visible code writes to mem_logs
mem_logs = gr.Textbox(label="Logs", lines=6, interactive=False)
def _mem_upload_paths(uploaded):
    """Normalise the memory upload value into a list of path strings."""
    if not uploaded:
        return []
    # A single upload may arrive bare (path, dict, or file-like object) instead of in a list
    if isinstance(uploaded, (str, dict)) or hasattr(uploaded, "name"):
        uploaded = [uploaded]
    paths = []
    for item in uploaded:
        if isinstance(item, dict):
            item = item.get("name")
        elif hasattr(item, "name"):
            item = item.name
        if item:
            paths.append(str(item))
    return paths
def _collect_zip_members(uploaded):
    """Return the de-duplicated member names of every uploaded .zip, in first-seen order."""
    members_total = []
    for path in _mem_upload_paths(uploaded):
        if path.lower().endswith(".zip"):
            members, _log = preview_zip_members_for_memory(path)
            members_total.extend(members or [])
    # Duplicate labels would make the CheckboxGroup ambiguous
    return list(dict.fromkeys(members_total))
def _preview_many_zip(uploaded):
    """
    List the members of the uploaded ZIPs and offer them in the member selector.

    Returns a pair for (mem_zip_select, mem_zip_preview_out): a CheckboxGroup update
    whose choices are the member names (nothing pre-selected) and the preview text.
    Previously only the preview text was produced, so the selector never received
    any choices and no member could be picked for import.
    """
    if not uploaded:
        return gr.update(choices=[], value=[]), "No files."
    members_total = _collect_zip_members(uploaded)
    if members_total:
        return gr.update(choices=members_total, value=[]), "\n".join(members_total)
    return gr.update(choices=[], value=[]), "No ZIPs found or no previewable members."
mem_preview_zip_btn.click(fn=_preview_many_zip, inputs=[mem_upload], outputs=[mem_zip_select, mem_zip_preview_out])
def _select_all_mem(uploaded=None):
    """
    Tick every member of the currently uploaded ZIPs.

    The members are derived from the uploads themselves rather than from
    LAST_EXTRACT_LIST, which belongs to the batch tab's extraction. Choices are
    refreshed too, so select-all also works before a preview was requested.
    """
    members = _collect_zip_members(uploaded)
    return gr.update(choices=members, value=members)
mem_select_all_btn.click(fn=_select_all_mem, inputs=[mem_upload], outputs=[mem_zip_select])
# The clear callback only resets the value, so the batch tab's helper is reused here
mem_clear_select_btn.click(fn=_clear_batch_select, inputs=[], outputs=[mem_zip_select])
def _import_mem(uploaded_files, selected_members):
    """
    Import the uploaded memory files, passing the ticked ZIP members along.

    Returns the status produced by import_memory_files_multiple, or an
    "Import failed: ..." message when that call raises.
    """
    try:
        return import_memory_files_multiple(
            uploaded_files,
            zip_members_to_import=selected_members,
        )
    except Exception as e:
        return f"Import failed: {e}"
mem_import_btn.click(
    fn=_import_mem,
    inputs=[mem_upload, mem_zip_select],
    outputs=[mem_status],
)
def _add_mem(entry):
    """
    Add one entry to the in-memory correction store and persist it.

    Entries of up to three whitespace-separated tokens are counted under
    memory["words"] (lower-cased); longer entries are counted under
    memory["phrases"] as typed. Returns a short status message.
    """
    e = entry.strip() if entry else ""
    if not e:
        return "No entry provided."
    is_phrase = len(e.split()) > 3
    with MEMORY_LOCK:
        if is_phrase:
            phrases = memory["phrases"]
            phrases[e] = phrases.get(e, 0) + 1
            save_memory(memory)
            return f"Added phrase: {e}"
        words = memory["words"]
        words[e.lower()] = words.get(e.lower(), 0) + 1
        save_memory(memory)
        return f"Added word: {e.lower()}"
def _clear_mem():
    """Replace the global memory with an empty store, persist it, and report success."""
    global memory
    empty_store = {"words": {}, "phrases": {}}
    # Rebind and save under the lock so concurrent additions see a consistent store
    with MEMORY_LOCK:
        memory = empty_store
        save_memory(memory)
    return "Memory cleared."
def _view_mem():
    """
    Render a text report of the memory store: the 30 highest-count words followed
    by the 20 highest-count phrases, one "entry: count" line each.
    """
    def _top(counts, limit):
        # Highest counts first; ties keep the dict's existing order (stable sort)
        ranked = sorted(counts.items(), key=lambda kv: -kv[1])
        return [f"{k}: {v}" for k, v in ranked[:limit]]

    w = memory.get("words", {})
    p = memory.get("phrases", {})
    report = ["WORDS (top 30):"]
    report.extend(_top(w, 30))
    report.append("")
    report.append("PHRASES (top 20):")
    report.extend(_top(p, 20))
    return "\n".join(report)
# Add / clear / view all report into mem_status, the same box the import action writes to
mem_add_btn.click(fn=_add_mem, inputs=[mem_textbox], outputs=[mem_status])
mem_clear_btn.click(fn=_clear_mem, inputs=[], outputs=[mem_status])
mem_view_btn.click(fn=_view_mem, inputs=[], outputs=[mem_status])
# Settings tab
with gr.TabItem("Settings"):
gr.Markdown("### Settings & tips")
# The f-strings below are evaluated once, when the UI is built, so they show start-up values
gr.Markdown(f"- Faster-whisper auto-detected: {USE_FASTER_WHISPER}")
gr.Markdown(f"- Max workers for parallel transcribe: {MAX_WORKERS}")
gr.Markdown("- If memory or RAM is limited, set MAX_WORKERS lower in code.")
# Auto-cleanup settings: the field starts at the current AUTO_CLEANUP_MINUTES; precision=0 keeps it a whole number
cleanup_minutes = gr.Number(value=AUTO_CLEANUP_MINUTES, label="Auto-cleanup minutes (temp files older than this will be removed)", precision=0)
cleanup_status = gr.Textbox(label="Cleanup status", lines=2, interactive=False)
def _set_cleanup_minutes(val):
    """
    Update the global AUTO_CLEANUP_MINUTES from the settings field.

    The value is converted with int() and raised to at least 1. Returns a
    confirmation message, or "Invalid value." when the conversion fails (in which
    case the global is left untouched).
    """
    global AUTO_CLEANUP_MINUTES
    try:
        v = max(1, int(val))
    except Exception:
        return "Invalid value."
    AUTO_CLEANUP_MINUTES = v
    return f"Auto-cleanup set to {v} minutes."
# Every edit of the cleanup field is applied immediately; the result message lands in cleanup_status
cleanup_minutes.change(fn=_set_cleanup_minutes, inputs=[cleanup_minutes], outputs=[cleanup_status])
# ---------- Launch ----------
if __name__ == "__main__":
# Port comes from the PORT environment variable, defaulting to 7860
port = int(os.environ.get("PORT", 7860))
print("DEBUG: launching on port", port)
# queue() is enabled before launching; server_name="0.0.0.0" listens on all network interfaces
demo.queue().launch(server_name="0.0.0.0", server_port=port)