Spaces:
Running
Running
# app.py
# Whisper Transcriber — Gradio 3.x compatible complete file with UI improvements:
# - small buttons, advanced toggle, download selected extracted files,
# - auto-merge per-file transcripts, auto cleanup of temp files after N minutes
# Requirements: gradio (3.x), pydub, pyzipper, python-docx, ffmpeg, whisper or faster-whisper
| import os | |
| import sys | |
| import json | |
| import shutil | |
| import tempfile | |
| import subprocess | |
| import traceback | |
| import threading | |
| import re | |
| import zipfile | |
| from difflib import get_close_matches | |
| from uuid import uuid4 | |
| from pathlib import Path | |
| from concurrent.futures import ProcessPoolExecutor, as_completed | |
| import multiprocessing | |
| import time | |
| # Force unbuffered prints | |
| os.environ["PYTHONUNBUFFERED"] = "1" | |
| try: | |
| import gradio as gr | |
| except Exception as e: | |
| print("FATAL: gradio import failed:", e) | |
| raise | |
| # try faster-whisper first for CPU speedups | |
| USE_FASTER_WHISPER = False | |
| try: | |
| from faster_whisper import WhisperModel as FasterWhisperModel | |
| USE_FASTER_WHISPER = True | |
| print("INFO: faster-whisper detected.") | |
| except Exception: | |
| try: | |
| import whisper | |
| except Exception: | |
| print("FATAL: Neither faster-whisper nor whisper available. Install whisper or faster-whisper.") | |
| raise | |
| from pydub import AudioSegment | |
| import pyzipper | |
| from docx import Document | |
# ---------- Config ----------
# Persistent vocabulary store (JSON: {"words": {...}, "phrases": {...}}).
MEMORY_FILE = "memory.json"
# Serializes access to the in-process memory dict and the memory file.
MEMORY_LOCK = threading.Lock()
# A converted WAV smaller than this is treated as a failed conversion.
MIN_WAV_SIZE = 1024
# Raw-format candidates tried in order when pydub auto-detection fails:
# (ffmpeg input format, sample rate, channel count).
FFMPEG_CANDIDATES = [
    ("s16le", 16000, 1),
    ("s16le", 44100, 2),
    ("pcm_s16le", 16000, 1),
    ("pcm_s16le", 44100, 2),
    ("mulaw", 8000, 1),
]
MODEL_CACHE = {}  # NOTE(review): not populated anywhere in this chunk — possibly vestigial
EXTRACT_MAP = {}  # friendly_name -> path
LAST_EXTRACT_DIR = None  # path to last extraction folder (for download)
LAST_EXTRACT_LIST = []  # friendly names for last extraction (for select all)
DEFAULT_ZIP_PASS = "dietcoke1"  # NOTE(review): hard-coded ZIP password — consider env/config
# NEW: last batch transcripts (set by batch generator). Each item: (friendly_name, txt_path, srt_path)
LAST_BATCH_TRANSCRIPTS = []
CPU_COUNT = max(1, multiprocessing.cpu_count())
MAX_WORKERS = min(4, CPU_COUNT)  # tune for your environment
# Auto-cleanup configuration (minutes); can be changed in settings UI
AUTO_CLEANUP_MINUTES = 30
# Temp registry for cleanup: entries are tuples (path, created_timestamp)
_TEMP_REGISTRY_LOCK = threading.Lock()
_TEMP_REGISTRY = []
def register_temp_path(p):
    """Record *p* in the temp registry so the cleanup thread can delete it later."""
    try:
        entry = (str(p), time.time())
        with _TEMP_REGISTRY_LOCK:
            _TEMP_REGISTRY.append(entry)
    except Exception:
        # Best effort: a registration failure must never break the caller.
        pass
def cleanup_temp_worker(interval_seconds=60):
    """Background thread to cleanup temp files older than AUTO_CLEANUP_MINUTES.

    Every *interval_seconds* it removes expired entries from _TEMP_REGISTRY
    (under the lock) and then deletes the corresponding files/dirs outside
    the lock, ignoring any deletion errors. Runs forever; started as a
    daemon thread at import time.
    """
    while True:
        try:
            # Entries registered before this instant are eligible for deletion.
            cutoff = time.time() - (AUTO_CLEANUP_MINUTES * 60)
            to_delete = []
            with _TEMP_REGISTRY_LOCK:
                remaining = []
                for p, ts in _TEMP_REGISTRY:
                    if ts < cutoff:
                        to_delete.append(p)
                    else:
                        remaining.append((p, ts))
                # In-place slice assignment keeps the shared list object intact.
                _TEMP_REGISTRY[:] = remaining
            # Deletion happens after releasing the lock to keep it short-held.
            for p in to_delete:
                try:
                    if os.path.isdir(p):
                        shutil.rmtree(p)
                    elif os.path.exists(p):
                        os.unlink(p)
                except Exception:
                    # ignore deletion errors
                    pass
        except Exception:
            pass
        time.sleep(interval_seconds)
# Start cleanup thread as daemon
# Daemon: dies with the process, so files registered near shutdown may survive.
_cleanup_thread = threading.Thread(target=cleanup_temp_worker, daemon=True)
_cleanup_thread.start()
| # ---------- Memory & postprocessing ---------- | |
def load_memory():
    """Load the word/phrase memory from MEMORY_FILE.

    Falls back to (and best-effort persists) an empty structure when the
    file is missing, unreadable, or malformed.
    """
    try:
        if os.path.exists(MEMORY_FILE):
            with open(MEMORY_FILE, "r", encoding="utf-8") as fh:
                data = json.load(fh)
            if not isinstance(data, dict):
                raise ValueError("memory.json root not dict")
            for section in ("words", "phrases"):
                data.setdefault(section, {})
            return data
    except Exception:
        pass  # fall through to a fresh, empty memory
    fresh = {"words": {}, "phrases": {}}
    try:
        # Seed the file so later saves have somewhere to go; ignore failure.
        with open(MEMORY_FILE, "w", encoding="utf-8") as fh:
            json.dump(fresh, fh, ensure_ascii=False, indent=2)
    except Exception:
        pass
    return fresh
def save_memory(mem):
    """Persist *mem* to MEMORY_FILE as UTF-8 JSON (lock-guarded, best effort)."""
    with MEMORY_LOCK:
        try:
            payload = json.dumps(mem, ensure_ascii=False, indent=2)
            with open(MEMORY_FILE, "w", encoding="utf-8") as fh:
                fh.write(payload)
        except Exception:
            traceback.print_exc()
# Module-level memory loaded once at import; mutated only under MEMORY_LOCK.
memory = load_memory()
# Lower-cased medical shorthand -> expansion, used by expand_abbreviations().
MEDICAL_ABBREVIATIONS = {
    "pt": "patient",
    "dx": "diagnosis",
    "hx": "history",
    "sx": "symptoms",
    "c/o": "complains of",
    "bp": "blood pressure",
    "hr": "heart rate",
    "o2": "oxygen",
    "r/o": "rule out",
    "adm": "admit",
    "disch": "discharge",
}
# Lower-cased drug name -> canonical capitalization, used by normalize_drugs().
DRUG_NORMALIZATION = {
    "metformin": "Metformin",
    "aspirin": "Aspirin",
    "amoxicillin": "Amoxicillin",
}
def expand_abbreviations(text):
    """Replace known medical shorthand tokens with their full expansions.

    Whitespace runs are preserved exactly; trailing punctuation attached to
    a shorthand token is re-appended after the expansion.
    """
    pieces = []
    for token in re.split(r"(\s+)", text):
        normalized = token.lower().strip(".,;:")
        expansion = MEDICAL_ABBREVIATIONS.get(normalized)
        if expansion is None:
            pieces.append(token)
            continue
        match = re.match(r"([A-Za-z0-9/]+)([.,;:]*)", token)
        suffix = (match.group(2) or "") if match else ""
        pieces.append(expansion + suffix)
    return "".join(pieces)
def normalize_drugs(text):
    """Rewrite known drug names to their canonical capitalization."""
    for pattern, canonical in DRUG_NORMALIZATION.items():
        text = re.sub(rf"\b{pattern}\b", canonical, text, flags=re.IGNORECASE)
    return text
def punctuation_and_capitalization(text):
    """Ensure *text* ends with sentence punctuation and each sentence starts
    with a capital letter.

    Returns the input unchanged if it is empty after stripping.
    """
    text = text.strip()
    if not text:
        return text
    if not re.search(r"[.?!]\s*$", text):
        text = text.rstrip() + "."
    # Split keeps the ".?! + whitespace" separators so the text rejoins losslessly.
    parts = re.split(r"([.?!]\s+)", text)
    out = []
    for p in parts:
        if p and not re.match(r"[.?!]\s+", p):
            # BUG FIX: str.capitalize() lowercased the rest of the sentence,
            # undoing casing applied earlier (e.g. normalize_drugs' "Metformin").
            # Upper-case only the first character and keep the rest intact.
            out.append(p[0].upper() + p[1:])
        else:
            out.append(p)
    return "".join(out)
def postprocess_transcript(text):
    """Normalize a raw transcript: collapse whitespace, expand medical
    abbreviations, canonicalize drug names, then fix punctuation/casing."""
    if not text:
        return text
    cleaned = re.sub(r"\s+", " ", text).strip()
    for step in (expand_abbreviations, normalize_drugs, punctuation_and_capitalization):
        cleaned = step(cleaned)
    return cleaned
def extract_words_and_phrases(text):
    """Return (word_tokens, sentences) extracted from *text*.

    Words are alphanumeric runs (hyphen/apostrophe allowed); sentences are
    split after terminal punctuation.
    """
    word_tokens = [w for w in re.findall(r"[A-Za-z0-9\-']+", text) if w.strip()]
    sentences = []
    for chunk in re.split(r"(?<=[.?!])\s+", text):
        chunk = chunk.strip()
        if chunk:
            sentences.append(chunk)
    return word_tokens, sentences
def update_memory_with_transcript(transcript):
    """Fold every word and sentence of *transcript* into the memory counts,
    persisting to disk when anything was added."""
    global memory
    words, sentences = extract_words_and_phrases(transcript)
    with MEMORY_LOCK:
        for word in words:
            key = word.lower()
            memory["words"][key] = memory["words"].get(key, 0) + 1
        for sentence in sentences:
            memory["phrases"][sentence] = memory["phrases"].get(sentence, 0) + 1
    changed = bool(words or sentences)
    # save_memory acquires MEMORY_LOCK itself, so it must run outside the lock.
    if changed:
        save_memory(memory)
def memory_correct_text(text, min_ratio=0.85):
    """Fuzzy-correct *text* against the learned memory.

    Unknown words are replaced by their closest remembered word (difflib
    similarity >= min_ratio). Remembered phrases found case-insensitively
    are rewritten with their remembered casing. Returns *text* unchanged
    when it is empty or the memory is empty.
    """
    if not text or (not memory.get("words") and not memory.get("phrases")):
        return text
    def fix_word(w):
        lw = w.lower()
        if lw in memory["words"]:
            return w  # already known: keep the author's casing
        candidates = get_close_matches(lw, memory["words"].keys(), n=1, cutoff=min_ratio)
        if candidates:
            cand = candidates[0]
            # Preserve a leading capital from the original token.
            if w and w[0].isupper():
                return cand.capitalize()
            return cand
        return w
    # Split keeps the non-word separators so the text rejoins losslessly.
    tokens = re.split(r"(\W+)", text)
    corrected_tokens = []
    for tok in tokens:
        if re.match(r"^[A-Za-z0-9\-']+$", tok):
            corrected_tokens.append(fix_word(tok))
        else:
            corrected_tokens.append(tok)
    corrected = "".join(corrected_tokens)
    # Longest phrases first so a short phrase cannot clobber a longer match.
    for phrase in sorted(memory.get("phrases", {}).keys(), key=lambda s: -len(s)):
        low_phrase = phrase.lower()
        if len(low_phrase) < 8:
            continue  # too short to be a reliable phrase match
        if low_phrase in corrected.lower():
            corrected = re.sub(re.escape(phrase), phrase, corrected, flags=re.IGNORECASE)
    return corrected
| # ---------- Utilities ---------- | |
def save_as_word(text, filename=None):
    """Write *text* as a single paragraph into a .docx and return its path.

    When *filename* is omitted a uniquely named temp file is used and
    registered for auto-cleanup.
    """
    if filename is None:
        filename = os.path.join(
            tempfile.gettempdir(), f"merged_transcripts_{uuid4().hex[:8]}.docx"
        )
    document = Document()
    document.add_paragraph(text)
    document.save(filename)
    register_temp_path(filename)
    return filename
def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
    """Convert *input_path* to WAV at *out_path* via ffmpeg.

    Returns (success, combined_output_or_error). Any partial output file is
    removed on failure.
    """
    def _discard_partial():
        try:
            if os.path.exists(out_path):
                os.unlink(out_path)
        except Exception:
            pass

    cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
    if fmt in ("s16le", "pcm_s16le", "mulaw"):
        # Headerless/raw input: format, rate and channels must be forced.
        cmd += ["-f", fmt, "-ar", str(sr), "-ac", str(ch), "-i", input_path, out_path]
    else:
        cmd += ["-i", input_path, "-ar", str(sr), "-ac", str(ch), out_path]
    try:
        proc = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
    except Exception as e:
        _discard_partial()
        return False, str(e)
    output = (proc.stdout or "") + (proc.stderr or "")
    if proc.returncode == 0 and os.path.exists(out_path) and os.path.getsize(out_path) > MIN_WAV_SIZE:
        return True, output
    _discard_partial()
    return False, output
def convert_to_wav_if_needed(input_path):
    """Return the path of a WAV rendition of *input_path*.

    Strategy:
      1. If the name already ends in .wav, return it unchanged (extension
         trusted; content is not inspected).
      2. Try pydub/ffmpeg auto-detection.
      3. Fall back to forcing each raw-format candidate in FFMPEG_CANDIDATES.
    On total failure, writes a diagnostics file (pydub traceback, ffmpeg
    logs, ffprobe output, hex preview of the header) and raises with its
    location.
    """
    input_path = str(input_path)
    lower = input_path.lower()
    if lower.endswith(".wav"):
        return input_path
    auto_err = ""
    tmp = None
    # --- attempt 1: pydub auto-detection ---
    try:
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        AudioSegment.from_file(input_path).export(tmp.name, format="wav")
        if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > MIN_WAV_SIZE:
            register_temp_path(tmp.name)
            return tmp.name
        else:
            # Output too small to be real audio: discard and fall through.
            try:
                os.unlink(tmp.name)
            except Exception:
                pass
    except Exception:
        auto_err = traceback.format_exc()
        try:
            if tmp and os.path.exists(tmp.name):
                os.unlink(tmp.name)
        except Exception:
            pass
    # --- attempt 2: brute-force raw-format candidates via ffmpeg ---
    diag_dir = tempfile.mkdtemp(prefix="dct_diag_")
    register_temp_path(diag_dir)
    diag_log = os.path.join(diag_dir, "conversion_diagnostics.txt")
    diagnostics = []
    for fmt, sr, ch in FFMPEG_CANDIDATES:
        out_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        out_wav.close()
        register_temp_path(out_wav.name)
        success, debug = _ffmpeg_convert(input_path, out_wav.name, fmt, sr, ch)
        diagnostics.append(f"TRY fmt={fmt} sr={sr} ch={ch} success={success}\n{debug}\n")
        if success:
            # Record what worked (best effort) and return the converted file.
            try:
                with open(diag_log, "w", encoding="utf-8") as fh:
                    fh.write("pydub auto error:\n")
                    fh.write(auto_err + "\n\n")
                    fh.write("Successful ffmpeg candidate:\n")
                    fh.write(f"fmt={fmt} sr={sr} ch={ch}\n\n")
                    fh.write("Diagnostics:\n")
                    fh.write("\n".join(diagnostics))
            except Exception:
                pass
            return out_wav.name
        else:
            try:
                if os.path.exists(out_wav.name):
                    os.unlink(out_wav.name)
            except Exception:
                pass
    # --- all candidates failed: gather diagnostics for the raised error ---
    try:
        fp = subprocess.run(
            ["ffprobe", "-v", "error", "-show_format", "-show_streams", input_path],
            capture_output=True,
            text=True,
            timeout=10,
        )
        diagnostics.append("FFPROBE:\n" + (fp.stdout.strip() or fp.stderr.strip()))
    except Exception as e:
        diagnostics.append("ffprobe failed: " + str(e))
    try:
        # First 512 bytes as hex helps identify proprietary containers (.dct/.dat).
        with open(input_path, "rb") as fh:
            head = fh.read(512)
        diagnostics.append("HEX PREVIEW:\n" + head.hex())
    except Exception as e:
        diagnostics.append("could not read head: " + str(e))
    try:
        with open(diag_log, "w", encoding="utf-8") as fh:
            fh.write("pydub auto error:\n")
            fh.write(auto_err + "\n\n")
            fh.write("Full diagnostics:\n\n")
            fh.write("\n\n".join(diagnostics))
    except Exception as e:
        raise Exception(f"Conversion failed; diagnostics write error: {e}")
    raise Exception(f"Could not convert file to WAV. Diagnostics saved to: {diag_log}")
| # ---------- Model helper ---------- | |
def whisper_available_models():
    """Best-effort set of model names supported by the active backend."""
    fallback = set(["tiny", "base", "small", "medium", "large", "large-v3"])
    try:
        if USE_FASTER_WHISPER:
            # faster-whisper exposes no discovery API; advertise the common set.
            return set(["tiny", "base", "small", "medium", "large", "large-v3"])
        models = whisper.available_models()
        if isinstance(models, (list, tuple, set)):
            return set(models)
    except Exception:
        pass
    return fallback
# Computed once at import; safe_model_choices() filters its dropdown against it.
AVAILABLE_MODEL_SET = whisper_available_models()
def safe_model_choices(prefer_default="small"):
    """Return (dropdown_choices, default_choice) limited to available models.

    Falls back to the full preferred order when nothing matches, and to the
    first choice when *prefer_default* is unavailable.
    """
    preferred_order = ["small", "medium", "large", "large-v3", "base", "tiny"]
    choices = [name for name in preferred_order if name in AVAILABLE_MODEL_SET]
    if not choices:
        choices = preferred_order
    if prefer_default in choices:
        return choices, prefer_default
    return choices, choices[0]
| # ---------- worker used by ProcessPoolExecutor ---------- | |
def _fmt_time(t):
    """Format a duration in seconds as an SRT timestamp: HH:MM:SS,mmm."""
    whole = int(t)
    hours, rem = divmod(whole, 3600)
    minutes, seconds = divmod(rem, 60)
    millis = int((t - whole) * 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"
def _segments_to_srt(segments):
    """Render whisper segment dicts ({"start", "end", "text"}) as SRT text."""
    lines = []
    for index, seg in enumerate(segments, start=1):
        start_ts = _fmt_time(seg.get("start", 0))
        end_ts = _fmt_time(seg.get("end", 0))
        # Each cue: number, time range, text, then a blank separator line.
        lines.extend([str(index), f"{start_ts} --> {end_ts}", seg.get("text", "").strip(), ""])
    return "\n".join(lines)
def _worker_transcribe(args):
    """Process-pool worker: transcribe one audio file end-to-end.

    *args* is a tuple:
        (file_path, model_name, device_name, enable_memory, generate_srt,
         use_two_pass, fast_model, refine_threshold)
    (use_two_pass / fast_model / refine_threshold are accepted but unused.)

    Returns {"file", "text_path", "srt_path", "log"}; on any failure
    text_path/srt_path are None and "log" carries the reason.
    """
    file_path = None
    try:
        (file_path, model_name, device_name, enable_memory, generate_srt,
         use_two_pass, fast_model, refine_threshold) = args
        base = os.path.basename(file_path)
        log_lines = []
        device = None if device_name == "auto" else device_name
        model = None
        use_fw = False
        # --- load the requested model, falling back to "small" ---
        try:
            if USE_FASTER_WHISPER:
                model = FasterWhisperModel(model_name, device=device if device else "cpu")
                use_fw = True
                log_lines.append(f"Worker: faster-whisper loaded {model_name}")
            else:
                import whisper as _wh
                model = _wh.load_model(model_name)
                use_fw = False
                log_lines.append(f"Worker: whisper loaded {model_name}")
        except Exception as e:
            log_lines.append(f"Worker model load failed: {e}")
            try:
                if USE_FASTER_WHISPER:
                    model = FasterWhisperModel("small", device=device if device else "cpu")
                    use_fw = True
                    log_lines.append("Worker: fallback to faster-whisper small")
                else:
                    model = whisper.load_model("small")
                    use_fw = False
                    log_lines.append("Worker: fallback whisper small")
            except Exception as e2:
                return {"file": base, "text_path": None, "srt_path": None, "log": "Model load failed: " + str(e2)}
        # --- convert input to WAV ---
        try:
            wav = convert_to_wav_if_needed(file_path)
            log_lines.append(f"Converted to WAV: {os.path.basename(wav)}")
        except Exception as e:
            return {"file": base, "text_path": None, "srt_path": None, "log": "Conversion failed: " + str(e)}
        # --- transcribe ---
        try:
            if use_fw:
                segments, info = model.transcribe(wav, beam_size=5)
                # BUG FIX: faster-whisper returns a lazy generator. The original
                # code exhausted it while building `text`, so the SRT loop below
                # saw zero segments and every SRT came out empty. Materialize
                # the segments once so both passes see all of them.
                segments = list(segments)
                text = "".join([getattr(seg, "text", "") for seg in segments]).strip()
                srt_out = None
                if generate_srt:
                    srt_lines = []
                    for idx, seg in enumerate(segments, start=1):
                        start = getattr(seg, "start", 0)
                        end = getattr(seg, "end", 0)
                        txt = getattr(seg, "text", "").strip()
                        srt_lines.append(str(idx))
                        srt_lines.append(f"{_fmt_time(start)} --> {_fmt_time(end)}")
                        srt_lines.append(txt)
                        srt_lines.append("")
                    srt_out = "\n".join(srt_lines)
            else:
                result = model.transcribe(wav)
                text = result.get("text", "").strip()
                srt_out = _segments_to_srt(result.get("segments")) if generate_srt and result.get("segments") else None
        except Exception as e:
            return {"file": base, "text_path": None, "srt_path": None, "log": "Transcription failed: " + str(e)}
        # --- post-process and write outputs ---
        if enable_memory and text:
            text = memory_correct_text(text)
        text = postprocess_transcript(text)
        txt_tmp = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        txt_tmp.close()
        register_temp_path(txt_tmp.name)
        with open(txt_tmp.name, "w", encoding="utf-8") as fh:
            fh.write(text)
        srt_path = None
        if generate_srt and srt_out:
            srt_tmp = tempfile.NamedTemporaryFile(suffix=".srt", delete=False)
            srt_tmp.close()
            register_temp_path(srt_tmp.name)
            with open(srt_tmp.name, "w", encoding="utf-8") as fh:
                fh.write(srt_out)
            srt_path = srt_tmp.name
        # Remove the intermediate WAV (never delete an original .wav input).
        try:
            if wav and os.path.exists(wav) and not file_path.lower().endswith(".wav"):
                os.unlink(wav)
        except Exception:
            pass
        return {"file": base, "text_path": txt_tmp.name, "srt_path": srt_path, "log": "\n".join(log_lines)}
    except Exception as e:
        tb = traceback.format_exc()
        # file_path is pre-initialized to None so a failed unpack cannot
        # raise NameError here (the original could).
        return {"file": os.path.basename(file_path) if file_path else "unknown", "text_path": None, "srt_path": None, "log": f"Worker exception: {e}\n{tb}"}
| # ---------- ZIP extraction & mapping ---------- | |
def extract_zip_and_map(zip_path, zip_password=None):
    """
    Extract ZIP into a per-run temp dir, populate EXTRACT_MAP (friendly name -> file path),
    and set LAST_EXTRACT_DIR to the extraction folder for download.
    Returns (friendly_list, logs_str)
    """
    global EXTRACT_MAP, LAST_EXTRACT_DIR, LAST_EXTRACT_LIST
    # Each extraction fully replaces the previous run's module-level state.
    EXTRACT_MAP = {}
    LAST_EXTRACT_DIR = None
    LAST_EXTRACT_LIST = []
    run_id = uuid4().hex
    temp_extract_dir = os.path.join(tempfile.gettempdir(), f"extracted_audio_{run_id}")
    logs = []
    try:
        os.makedirs(temp_extract_dir, exist_ok=True)
        with pyzipper.ZipFile(zip_path, "r") as zf:
            if zip_password:
                try:
                    zf.setpassword(zip_password.encode())
                except Exception:
                    logs.append("Warning: failed to set zip password (continuing).")
            count = {}  # basename -> duplicate counter for friendly-name de-dup
            supported = [".mp3", ".wav", ".aac", ".flac", ".ogg", ".m4a", ".dat", ".dct"]
            for info in zf.infolist():
                if info.is_dir():
                    continue
                _, ext = os.path.splitext(info.filename)
                if ext.lower() not in supported:
                    continue
                try:
                    zf.extract(info, path=temp_extract_dir)
                except RuntimeError as e:
                    # Wrong/missing password surfaces as RuntimeError here.
                    logs.append(f"Password required or incorrect for {info.filename}: {e}")
                    continue
                except Exception as e:
                    logs.append(f"Error extracting {info.filename}: {e}")
                    continue
                # NOTE(review): assumes extract() wrote exactly this joined path;
                # for unusual member names the sanitized path may differ — verify.
                fullp = os.path.normpath(os.path.join(temp_extract_dir, info.filename))
                if not os.path.exists(fullp):
                    continue
                base = os.path.basename(info.filename)
                key = base
                if key in EXTRACT_MAP:
                    # Duplicate basename: suffix with an index, e.g. "a (2).mp3".
                    idx = count.get(base, 1) + 1
                    count[base] = idx
                    name_only, extn = os.path.splitext(base)
                    key = f"{name_only} ({idx}){extn}"
                else:
                    count[base] = 1
                EXTRACT_MAP[key] = fullp
                logs.append(f"Extracted: {info.filename} -> {key}")
        if not EXTRACT_MAP:
            logs.append("No supported audio files found in ZIP.")
            # cleanup temp dir if empty
            try:
                if os.path.exists(temp_extract_dir) and not os.listdir(temp_extract_dir):
                    shutil.rmtree(temp_extract_dir)
            except Exception:
                pass
            return [], "\n".join(logs)
        friendly = sorted(EXTRACT_MAP.keys())
        LAST_EXTRACT_DIR = temp_extract_dir
        LAST_EXTRACT_LIST = friendly[:]
        register_temp_path(temp_extract_dir)
        return friendly, "\n".join(logs)
    except Exception as e:
        traceback.print_exc()
        try:
            if os.path.exists(temp_extract_dir):
                shutil.rmtree(temp_extract_dir)
        except Exception:
            pass
        return [], f"Extraction failed: {e}"
def download_extracted_folder():
    """Zip LAST_EXTRACT_DIR for download.

    Returns (zip_path, "OK") on success, (None, message) otherwise.
    """
    global LAST_EXTRACT_DIR
    if not LAST_EXTRACT_DIR or not os.path.exists(LAST_EXTRACT_DIR):
        return None, "No extracted folder available for download."
    try:
        handle = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
        handle.close()
        register_temp_path(handle.name)
        with zipfile.ZipFile(handle.name, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            # Archive names mirror each file's path relative to the extract root.
            for root, _dirs, files in os.walk(LAST_EXTRACT_DIR):
                for name in files:
                    full = os.path.join(root, name)
                    zf.write(full, arcname=os.path.relpath(full, LAST_EXTRACT_DIR))
        return handle.name, "OK"
    except Exception as e:
        return None, f"Failed to create ZIP: {e}"
def download_selected_extracted_files(selected_keys):
    """Build a ZIP containing only the selected extracted files.

    Returns (zip_path, "OK") or (None, message) when nothing usable was
    selected.
    """
    if not selected_keys:
        return None, "No files selected."
    entries = [(key, EXTRACT_MAP.get(key)) for key in selected_keys]
    entries = [(key, path) for key, path in entries if path and os.path.exists(path)]
    if not entries:
        return None, "No valid selected files found."
    handle = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    handle.close()
    register_temp_path(handle.name)
    try:
        with zipfile.ZipFile(handle.name, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for key, path in entries:
                try:
                    # The friendly (de-duplicated) name doubles as the archive name.
                    zf.write(path, arcname=key)
                except Exception:
                    zf.write(path, arcname=os.path.basename(path))
        return handle.name, "OK"
    except Exception as e:
        return None, f"Failed to create selected ZIP: {e}"
| # ---------- Merge uploaded text files into single Word file ---------- | |
def merge_text_files_to_docx(uploaded_text_files):
    """
    Accepts a list of uploaded text file paths (or single path), merges them in order into one .docx and returns path.
    """
    # Returns (docx_path, message) on success, (None, message) otherwise.
    if not uploaded_text_files:
        return None, "No files provided."
    # Normalize the shapes Gradio may hand us: a single path, a dict with a
    # "name" key, or a list of paths / dicts / file-like objects.
    if isinstance(uploaded_text_files, (str, os.PathLike)):
        uploaded_text_files = [str(uploaded_text_files)]
    elif isinstance(uploaded_text_files, dict) and uploaded_text_files.get("name"):
        uploaded_text_files = [uploaded_text_files["name"]]
    elif isinstance(uploaded_text_files, (list, tuple)):
        normalized = []
        for f in uploaded_text_files:
            if isinstance(f, (str, os.PathLike)):
                normalized.append(str(f))
            elif isinstance(f, dict) and f.get("name"):
                normalized.append(f["name"])
            elif hasattr(f, "name"):
                normalized.append(f.name)
        uploaded_text_files = normalized
    combined = []
    for p in uploaded_text_files:
        if not os.path.exists(p):
            continue  # silently skip missing entries
        try:
            with open(p, "r", encoding="utf-8") as fh:
                txt = fh.read()
        except Exception:
            # Retry with a lenient single-byte decoding on decode failure.
            with open(p, "r", encoding="latin-1", errors="replace") as fh:
                txt = fh.read()
        # Separate each file's content with a header naming its source.
        header = f"\n\n--- {os.path.basename(p)} ---\n\n"
        combined.append(header + txt)
    if not combined:
        return None, "No readable text files."
    merged_text = "\n".join(combined)
    out_path = save_as_word(merged_text)
    return out_path, "Merged"
| # ---------- NEW: merge last batch transcripts ---------- | |
def merge_last_batch_transcripts():
    """
    Merge txt transcripts created by the last batch run (LAST_BATCH_TRANSCRIPTS) into a single .docx.
    Returns (path_or_None, message)
    """
    global LAST_BATCH_TRANSCRIPTS
    if not LAST_BATCH_TRANSCRIPTS:
        return None, "No last-batch transcripts available."
    sections = []
    for fname, txt_path, _srt_path in LAST_BATCH_TRANSCRIPTS:
        if not txt_path or not os.path.exists(txt_path):
            continue
        content = ""
        # Try utf-8 first, then latin-1; an unreadable file contributes "".
        for enc in ("utf-8", "latin-1"):
            try:
                with open(txt_path, "r", encoding=enc, errors="replace") as fh:
                    content = fh.read()
                break
            except Exception:
                content = ""
        sections.append(f"\n\n--- {fname} ---\n\n" + content)
    if not sections:
        return None, "No readable last-batch transcript files found."
    out_path = save_as_word("\n".join(sections))
    return out_path, f"Merged {len(sections)} files."
| # ---------- Batch transcription generator (streaming) ---------- | |
def batch_transcribe_parallel_generator(
    friendly_selected,
    uploaded_files,
    model_name,
    device_name,
    merge_flag,
    enable_mem,
    generate_srt,
    use_two_pass=False,
    fast_model="small",
    refine_threshold=-1.0,
    zip_password=None,
    auto_merge_per_file=True,
):
    """Streamed batch transcription driver for the Gradio UI.

    Yields (logs_text, transcripts_text, download_path_or_None, progress_pct)
    tuples as files finish. Work is fanned out to a ProcessPoolExecutor
    running _worker_transcribe; finished files are recorded in
    LAST_BATCH_TRANSCRIPTS for later merging/downloading.
    (zip_password is accepted but unused here.)
    """
    global LAST_BATCH_TRANSCRIPTS
    LAST_BATCH_TRANSCRIPTS = []  # reset at start
    logs = []
    transcripts = []
    per_file_paths = []  # (friendly_name, txt_path, srt_path) per finished file
    try:
        paths = []
        # gather selected extracted paths
        if friendly_selected:
            for key in friendly_selected:
                p = EXTRACT_MAP.get(key)
                if p:
                    paths.append(p)
                else:
                    logs.append(f"Warning: selected not found in extract map: {key}")
        # uploaded files
        if uploaded_files:
            if isinstance(uploaded_files, (list, tuple)):
                for f in uploaded_files:
                    paths.append(str(f))
            else:
                paths.append(str(uploaded_files))
        if not paths:
            logs.append("No files selected or uploaded.")
            yield "\n\n".join(logs), "", None, 100
            return
        total = len(paths)
        logs.append(f"Starting batch of {total} files with up to {MAX_WORKERS} workers.")
        yield "\n\n".join(logs), "", None, 2
        # One argument tuple per file, matching _worker_transcribe's signature.
        tasks = []
        for p in paths:
            tasks.append((p, model_name, device_name, enable_mem, generate_srt, use_two_pass, fast_model, refine_threshold))
        completed = 0
        with ProcessPoolExecutor(max_workers=min(MAX_WORKERS, total)) as exe:
            futs = {exe.submit(_worker_transcribe, t): t for t in tasks}
            for fut in as_completed(futs):
                res = fut.result()
                completed += 1
                fname = res.get("file")
                res_log = res.get("log", "")
                logs.append(f"[{completed}/{total}] {fname}: {res_log}")
                txtp = res.get("text_path")
                srtp = res.get("srt_path")
                if txtp:
                    try:
                        with open(txtp, "r", encoding="utf-8") as fh:
                            txt_content = fh.read()
                    except Exception:
                        with open(txtp, "r", encoding="latin-1", errors="replace") as fh:
                            txt_content = fh.read()
                    transcripts.append(f"FILE: {fname}\n{txt_content}\n")
                    per_file_paths.append((fname, txtp, srtp))
                # Progress: 5% reserved for setup, the rest spread over files.
                pct = int(5 + (completed / total) * 90)
                yield "\n\n".join(logs), "\n\n".join(transcripts), None, pct
        # Save per-file transcript list into global for later merging/downloading
        LAST_BATCH_TRANSCRIPTS = per_file_paths[:]
        combined = "\n\n".join(transcripts)
        out_doc = None
        if merge_flag or auto_merge_per_file:
            try:
                out_doc = save_as_word(combined)
                logs.append(f"Merged saved: {out_doc}")
            except Exception as e:
                logs.append(f"Merge failed: {e}")
        # Create ZIP of per-file transcripts (not original audio)
        if per_file_paths:
            zip_tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
            zip_tmp.close()
            register_temp_path(zip_tmp.name)
            with zipfile.ZipFile(zip_tmp.name, "w", compression=zipfile.ZIP_DEFLATED) as zf:
                for fname, txtp, srtp in per_file_paths:
                    arc_txt = f"{fname}.txt"
                    try:
                        zf.write(txtp, arcname=arc_txt)
                    except Exception:
                        zf.write(txtp, arcname=os.path.basename(txtp))
                    if srtp and os.path.exists(srtp):
                        arc_srt = f"{fname}.srt"
                        try:
                            zf.write(srtp, arcname=arc_srt)
                        except Exception:
                            zf.write(srtp, arcname=os.path.basename(srtp))
            logs.append(f"Per-file transcripts ZIP: {zip_tmp.name}")
            # NOTE(review): when a ZIP is produced, the merged .docx path is not
            # yielded; presumably the UI retrieves it via merge_last_batch_transcripts.
            yield "\n\n".join(logs), combined, zip_tmp.name, 100
        else:
            yield "\n\n".join(logs), combined, out_doc, 100
    except Exception as e:
        tb = traceback.format_exc()
        logs.append(f"Batch error: {e}\n{tb}")
        yield "\n\n".join(logs), "\n\n".join(transcripts), None, 100
| # ---------- Memory import helpers ---------- | |
def _read_file_text_try_encodings(path):
    """Read *path* as text, trying utf-8, utf-16, then latin-1.

    Returns (text, encoding_label), or (None, None) when the file cannot
    be read at all.
    """
    for enc in ("utf-8", "utf-16", "latin-1"):
        try:
            with open(path, "r", encoding=enc) as fh:
                return fh.read(), enc
        except UnicodeDecodeError:
            continue  # wrong encoding: try the next candidate
        except Exception:
            break  # non-decode failure (missing file, permissions, ...)
    # Last resort: read raw bytes and decode manually.
    try:
        with open(path, "rb") as fh:
            raw = fh.read()
    except Exception:
        return None, None
    try:
        return raw.decode("utf-8"), "utf-8(guessed)"
    except Exception:
        return raw.decode("latin-1", errors="replace"), "latin-1(replaced)"
def _process_single_memory_text(text):
    """Import memory entries from *text* (a JSON export or plain lines).

    JSON form: {"words": {word: count}, "phrases": {phrase: count}}.
    Plain form: one entry per line, optionally "entry,count"; entries of up
    to three words count as words, longer lines as phrases.
    Returns the number of entries added/merged into the global memory.
    """
    added = 0
    # First attempt: a structured JSON memory export.
    try:
        parsed = json.loads(text)
        if isinstance(parsed, dict):
            with MEMORY_LOCK:
                for key, val in parsed.get("words", {}).items():
                    try:
                        count = int(val)
                    except Exception:
                        count = 1
                    lk = key.lower()
                    memory["words"][lk] = memory["words"].get(lk, 0) + count
                    added += 1
                for key, val in parsed.get("phrases", {}).items():
                    try:
                        count = int(val)
                    except Exception:
                        count = 1
                    memory["phrases"][key] = memory["phrases"].get(key, 0) + count
                    added += 1
            return added
    except Exception:
        pass  # not JSON: fall back to line-oriented parsing
    # Second attempt: plain text, one entry per line.
    with MEMORY_LOCK:
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            if "," in line:
                key, _, count_str = (p.strip() for p in line.partition(","))
                try:
                    count = int(count_str)
                except Exception:
                    count = 1
                memory["words"][key.lower()] = memory["words"].get(key.lower(), 0) + count
                added += 1
            elif len(line.split()) <= 3:
                memory["words"][line.lower()] = memory["words"].get(line.lower(), 0) + 1
                added += 1
            else:
                memory["phrases"][line] = memory["phrases"].get(line, 0) + 1
                added += 1
    return added
def preview_zip_members_for_memory(zip_path):
    """List the non-directory member names of a ZIP chosen for memory import.

    Returns (member_names, log_text). On a bad/missing archive the member
    list is empty and the log carries the error message.
    """
    members = []
    logs = []
    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            for info in zf.infolist():
                if info.is_dir():
                    continue
                # Removed an unused `os.path.splitext` local from the original;
                # the name is collected as-is.
                members.append(info.filename)
        if not members:
            logs.append("No members found in ZIP.")
        else:
            logs.append(f"Found {len(members)} members.")
    except Exception as e:
        logs.append(f"ZIP preview failed: {e}")
    return members, "\n".join(logs)
def import_memory_files_multiple(uploaded_files, zip_members_to_import=None):
    """Import memory entries from one or more uploaded files.

    Accepts a single path, a Gradio file dict, or a list of either; ZIP
    archives are expanded member-by-member (optionally filtered to
    *zip_members_to_import*). Each text payload is fed through
    _process_single_memory_text and the memory is saved once at the end.
    Returns a human-readable summary string.
    """
    if not uploaded_files:
        return "No files provided."
    # Normalize the shapes Gradio may hand us into a flat list of paths.
    if isinstance(uploaded_files, (str, os.PathLike)):
        uploaded_files = [str(uploaded_files)]
    elif isinstance(uploaded_files, dict) and uploaded_files.get("name"):
        uploaded_files = [uploaded_files["name"]]
    elif isinstance(uploaded_files, (list, tuple)):
        normalized = []
        for f in uploaded_files:
            if isinstance(f, (str, os.PathLike)):
                normalized.append(str(f))
            elif isinstance(f, dict) and f.get("name"):
                normalized.append(f["name"])
            elif hasattr(f, "name"):
                normalized.append(f.name)
        uploaded_files = normalized
    total_added = 0
    messages = []
    skipped = []
    for fp in uploaded_files:
        try:
            if not os.path.exists(fp):
                messages.append(f"Missing: {fp}")
                continue
            if fp.lower().endswith(".zip"):
                try:
                    with zipfile.ZipFile(fp, "r") as zf:
                        for info in zf.infolist():
                            if info.is_dir():
                                continue
                            name = info.filename
                            # Optional allow-list of members chosen in the UI.
                            if zip_members_to_import and name not in zip_members_to_import:
                                continue
                            try:
                                with zf.open(info) as member:
                                    raw = member.read()
                                # Try encodings in order; latin-1 never fails,
                                # so `text` always ends up set.
                                text = None
                                for enc in ("utf-8", "utf-16", "latin-1"):
                                    try:
                                        text = raw.decode(enc)
                                        break
                                    except Exception:
                                        text = None
                                if text is None:
                                    text = raw.decode("latin-1", errors="replace")
                                added = _process_single_memory_text(text)
                                total_added += added
                                messages.append(f"Imported {added} from {name} in {os.path.basename(fp)}")
                            except Exception as e:
                                skipped.append(f"{name}: {e}")
                    continue
                except zipfile.BadZipFile:
                    skipped.append(f"Bad zip: {fp}")
                    continue
            # Plain text file path.
            text, used_enc = _read_file_text_try_encodings(fp)
            if text is None:
                skipped.append(fp)
                continue
            added = _process_single_memory_text(text)
            total_added += added
            messages.append(f"Imported {added} from {os.path.basename(fp)} (enc={used_enc})")
        except Exception as e:
            skipped.append(f"{fp}: {e}")
    # Persist once after all files are merged in.
    save_memory(memory)
    summary = [f"Total entries added: {total_added}"]
    if messages:
        summary.append("Details:")
        summary.extend(messages)
    if skipped:
        summary.append("Skipped/failed:")
        summary.extend(skipped)
    return "\n".join(summary)
# ---------- Build Gradio UI ----------
print("DEBUG: building Gradio UI", flush=True)
# Model dropdown choices constrained to what is actually installed/usable;
# helper defined earlier in the file.
available_choices, default_choice = safe_model_choices(prefer_default="small")
# CSS tweaks: small buttons and nicer layout.
# Light/dark palettes via CSS custom properties; [data-theme="dark"] is set
# by the injected JS in the Blocks below.
CSS = """
:root{
--accent:#4f46e5;
--muted:#6b7280;
--card:#ffffff;
--bg:#f7f8fb;
--text:#0f172a;
--transcript-bg:#0f172a;
--transcript-color:#e6eef8;
}
[data-theme="dark"] {
--accent: #7c3aed;
--muted: #9ca3af;
--card: #0b1220;
--bg: #071022;
--text: #e6eef8;
--transcript-bg: #071026;
--transcript-color: #e6eef8;
}
body { background: var(--bg); color: var(--text); font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, "Helvetica Neue", Arial; }
.header { padding: 14px; border-radius: 10px; background: linear-gradient(90deg, rgba(79,70,229,0.08), rgba(99,102,241,0.02)); margin-bottom: 12px; display:flex;align-items:center;gap:12px; }
.app-icon { width:50px;height:50px;border-radius:10px;background:linear-gradient(135deg,var(--accent),#06b6d4);display:flex;align-items:center;justify-content:center;color:white;font-weight:700;font-size:20px; }
.card { background:var(--card); border-radius:10px; padding:12px; box-shadow: 0 6px 20px rgba(16,24,40,0.04); }
.transcript-area { white-space:pre-wrap; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, "Roboto Mono", monospace; background: var(--transcript-bg); color: var(--transcript-color); padding:12px; border-radius:8px; min-height:200px; }
.small-note { color:var(--muted); font-size:12px;}
.btn-row { display:flex; gap:8px; margin-top:8px; }
.gr-button.small { padding:6px 8px !important; font-size:12px !important; }
"""
# ---------- UI: root layout, theme bootstrap, and Single File tab ----------
with gr.Blocks(title="Whisper Transcriber — Parallel + Memory", css=CSS) as demo:
    # set dark theme by default via injected JS; a previously saved choice
    # in localStorage ('wt_theme') wins over the default.
    gr.HTML("""
<script>
(function() {
try {
const saved = localStorage.getItem('wt_theme');
if (saved) {
document.documentElement.setAttribute('data-theme', saved);
} else {
document.documentElement.setAttribute('data-theme', 'dark');
}
} catch (e) { console.warn('theme init failed', e); }
})();
</script>
""")
    gr.Markdown("<h3>Whisper Transcriber — Parallel + Memory</h3>")
    gr.Markdown("<div class='small-note'>Parallel batch transcription, memory correction, per-file transcript downloads. Use faster-whisper if available for faster CPU performance.</div>")
    # Advanced toggle (hidden by default); drives visibility of the Batch tab's
    # advanced column via adv_toggle.change further below.
    adv_toggle = gr.Checkbox(label="Advanced ▾", value=False)
    # We'll put advanced controls inside this column and toggle visibility
    with gr.Tabs():
        # Single file tab
        with gr.TabItem("Single File"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Inputs: one audio file plus model/device/feature switches.
                    single_audio = gr.Audio(label="Upload audio", type="filepath")
                    model_sel_single = gr.Dropdown(choices=available_choices, value=default_choice, label="Model")
                    device_single = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
                    mem_single = gr.Checkbox(label="Use memory corrections", value=False)
                    srt_single = gr.Checkbox(label="Generate SRT", value=False)
                    trans_single_btn = gr.Button("Transcribe", elem_classes="small")
                with gr.Column(scale=1):
                    single_trans_out = gr.Textbox(label="Transcript", lines=14, interactive=False)
            # LOGS at bottom
            single_logs = gr.Textbox(label="Logs", lines=6, interactive=False)
| def _do_single(audio, model_name, device_name, mem_on, srt_on): | |
| if not audio: | |
| return "", "No audio supplied." | |
| path = audio if isinstance(audio, str) else (audio.name if hasattr(audio, "name") else str(audio)) | |
| res = _worker_transcribe((path, model_name, device_name, mem_on, srt_on, False, "small", -1.0)) | |
| if res.get("text_path"): | |
| try: | |
| with open(res["text_path"], "r", encoding="utf-8", errors="replace") as fh: | |
| content = fh.read() | |
| except Exception: | |
| content = "" | |
| else: | |
| content = "" | |
| logs = res.get("log", "") | |
| return content, logs | |
            # Wire the single-file transcription button.
            trans_single_btn.click(fn=_do_single, inputs=[single_audio, model_sel_single, device_single, mem_single, srt_single], outputs=[single_trans_out, single_logs])
        # Batch tab
        with gr.TabItem("Batch Transcribe"):
            with gr.Row():
                with gr.Column(scale=1):
                    batch_files = gr.File(label="Upload audio files", file_count="multiple", type="filepath")
                    batch_zip = gr.File(label="Or upload ZIP (optional)", file_count="single", type="filepath")
                    batch_zip_pass = gr.Textbox(label="ZIP password (if any)", value=DEFAULT_ZIP_PASS)
                    # Extract and populate list
                    batch_preview_btn = gr.Button("Extract & List ZIP files", elem_classes="small")
                    batch_preview_out = gr.Textbox(label="ZIP members (preview)", lines=6, interactive=False)
                    # Choices are filled at runtime by _preview_zip_and_populate.
                    batch_select = gr.CheckboxGroup(choices=[], label="Select extracted files to include", interactive=True)
                    # select-all / clear buttons (small)
                    with gr.Row(elem_classes="btn-row"):
                        batch_select_all_btn = gr.Button("Select All", elem_classes="small")
                        batch_clear_select_btn = gr.Button("Clear", elem_classes="small")
                        batch_download_extracted_btn = gr.Button("Download Extracted (all)", elem_classes="small")
                        batch_download_selected_btn = gr.Button("Download Selected", elem_classes="small")
                    batch_extracted_zip = gr.File(label="Downloaded extracted ZIP")
                    gr.Markdown("### Merge text files")
                    merge_text_files_input = gr.File(label="Upload text files to merge (.txt/.srt/.json)", file_count="multiple", type="filepath")
                    merge_text_btn = gr.Button("Merge uploaded text files -> DOCX", elem_classes="small")
                    merge_text_out = gr.File(label="Merged DOCX download")
                    # NEW: Merge last batch transcripts
                    merge_last_batch_btn = gr.Button("Merge Last Batch Transcripts", elem_classes="small")
                    merge_last_batch_status = gr.Textbox(label="Last-batch merge status", lines=2, interactive=False)
                    merge_last_batch_download = gr.File(label="Merged last-batch DOCX")
                    # Transcription parameters (basic)
                    batch_model = gr.Dropdown(choices=available_choices, value=default_choice, label="Model")
                    batch_mem = gr.Checkbox(label="Enable memory corrections", value=False)
                    batch_srt = gr.Checkbox(label="Generate SRTs", value=False)
                    auto_merge_per_file = gr.Checkbox(label="Auto-merge per-file transcripts", value=True)
                    # Advanced controls hidden by default; shown via adv_toggle.change.
                    advanced_col = gr.Column(visible=False)
                    with advanced_col:
                        batch_device = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
                        batch_use_two_pass = gr.Checkbox(label="Use two-pass refinement", value=False)
                        batch_fast_model = gr.Dropdown(choices=[c for c in ["tiny", "base", "small"] if c in AVAILABLE_MODEL_SET], value="small", label="Fast model")
                        batch_refine_thresh = gr.Number(value=-1.0, label="Refine threshold", precision=2)
                        batch_merge = gr.Checkbox(label="Merge transcripts into DOCX after run", value=True)
                    # Start button
                    batch_run_btn = gr.Button("Start Batch (parallel)", elem_classes="small")
                with gr.Column(scale=1):
                    # Streaming outputs updated by the batch generator.
                    batch_combined_out = gr.Textbox(label="Combined transcripts", lines=12, interactive=False)
                    batch_progress = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Progress (%)", interactive=False)
                    batch_zip_download = gr.File(label="Download per-file transcripts ZIP")
                    batch_doc_download = gr.File(label="Download merged DOCX (if created)")
            # Logs at bottom
            batch_logs_out = gr.Textbox(label="Logs", lines=8, interactive=False)
| def _preview_zip_and_populate(zip_file, password): | |
| """ | |
| Extract the zip, populate EXTRACT_MAP and return updated CheckboxGroup choices + preview text. | |
| """ | |
| if not zip_file: | |
| return gr.update(choices=[]), "No ZIP provided." | |
| path = zip_file.name if hasattr(zip_file, "name") else str(zip_file) | |
| friendly, logs = extract_zip_and_map(path, password) | |
| if friendly: | |
| return gr.update(choices=friendly), "\n".join(friendly) | |
| return gr.update(choices=[]), logs | |
            # Extract the ZIP, then fill the selection checkboxes with friendly names.
            batch_preview_btn.click(fn=_preview_zip_and_populate, inputs=[batch_zip, batch_zip_pass], outputs=[batch_select, batch_preview_out])
            def _select_all_batch():
                """Select every file recorded by the most recent extraction."""
                # uses LAST_EXTRACT_LIST set by extract
                global LAST_EXTRACT_LIST
                if LAST_EXTRACT_LIST:
                    return gr.update(value=LAST_EXTRACT_LIST)
                return gr.update(value=[])
            batch_select_all_btn.click(fn=_select_all_batch, inputs=[], outputs=[batch_select])
            def _clear_batch_select():
                """Empty the extracted-file selection (also reused by the Memory tab)."""
                return gr.update(value=[])
            batch_clear_select_btn.click(fn=_clear_batch_select, inputs=[], outputs=[batch_select])
            def _download_extracted_wrapper():
                """Zip the whole last extraction folder; helper's message is discarded."""
                zip_path, msg = download_extracted_folder()
                if zip_path:
                    return zip_path
                return None
            batch_download_extracted_btn.click(fn=_download_extracted_wrapper, inputs=[], outputs=[batch_extracted_zip])
            def _download_selected_wrapper(selected):
                """Zip only the selected extracted files; helper's message is discarded."""
                zip_path, msg = download_selected_extracted_files(selected)
                if zip_path:
                    return zip_path
                return None
            batch_download_selected_btn.click(fn=_download_selected_wrapper, inputs=[batch_select], outputs=[batch_extracted_zip])
| def _merge_texts(uploaded_texts): | |
| if not uploaded_texts: | |
| return None, "No files provided." | |
| out_path, msg = merge_text_files_to_docx(uploaded_texts) | |
| if out_path: | |
| return out_path | |
| return None, msg | |
            merge_text_btn.click(fn=_merge_texts, inputs=[merge_text_files_input], outputs=[merge_text_out])
            def _merge_last_batch_action():
                """
                Merge last batch transcripts (global LAST_BATCH_TRANSCRIPTS) into docx and return file path.
                """
                # Two outputs are wired below: the DOCX file and a status message.
                path, msg = merge_last_batch_transcripts()
                if path:
                    return path, msg
                return None, msg
            merge_last_batch_btn.click(fn=_merge_last_batch_action, inputs=[], outputs=[merge_last_batch_download, merge_last_batch_status])
            # show/hide advanced panel when adv_toggle changes
            def _toggle_advanced(show):
                """Map the checkbox state onto the advanced column's visibility."""
                return gr.update(visible=bool(show))
            adv_toggle.change(fn=_toggle_advanced, inputs=[adv_toggle], outputs=[advanced_col])
            # wrapper generator — Gradio expects the function itself to be a generator that yields streaming tuples
            def _start_batch(friendly_selected, uploaded_files, zip_file, zip_pass, model_name, mem_flag, srt_flag, auto_merge_flag, device_name=None, two_pass=False, fast_model="small", refine_thresh=-1.0, merge_flag=True):
                """Stream (logs, combined_text, zip_path, progress) tuples from the parallel batch run."""
                # normalize uploaded_files: a single Gradio file dict becomes a one-item list
                up = uploaded_files
                if isinstance(up, dict) and up.get("name"):
                    up = [up["name"]]
                gen = batch_transcribe_parallel_generator(
                    friendly_selected,
                    up,
                    model_name,
                    device_name if device_name is not None else "auto",
                    merge_flag,
                    mem_flag,
                    srt_flag,
                    use_two_pass=two_pass,
                    fast_model=fast_model,
                    refine_threshold=refine_thresh,
                    zip_password=zip_pass,
                    auto_merge_per_file=auto_merge_flag,
                )
                # Re-yield so this wrapper is itself a generator (required for streaming UI updates).
                for item in gen:
                    yield item
            # Depending on whether advanced is shown, pass extra params. We connect both basic and advanced inputs
            batch_run_btn.click(
                fn=_start_batch,
                inputs=[batch_select, batch_files, batch_zip, batch_zip_pass, batch_model, batch_mem, batch_srt, auto_merge_per_file,
                        batch_device, batch_use_two_pass, batch_fast_model, batch_refine_thresh, batch_merge],
                outputs=[batch_logs_out, batch_combined_out, batch_zip_download, batch_progress],
            )
        # Memory tab
        with gr.TabItem("Memory"):
            with gr.Row():
                with gr.Column(scale=1):
                    mem_upload = gr.File(label="Upload memory files or ZIP (multiple)", file_count="multiple", type="filepath")
                    mem_preview_zip_btn = gr.Button("Preview ZIP members (for selected ZIPs)", elem_classes="small")
                    mem_zip_preview_out = gr.Textbox(label="ZIP members (preview)", lines=4, interactive=False)
                    mem_zip_select = gr.CheckboxGroup(choices=[], label="Select ZIP members to import", interactive=True)
                    mem_select_all_btn = gr.Button("Select All members", elem_classes="small")
                    mem_clear_select_btn = gr.Button("Clear selection", elem_classes="small")
                    mem_import_btn = gr.Button("Import selected files / uploaded files", elem_classes="small")
                    mem_status = gr.Textbox(label="Import status", lines=6, interactive=False)
                    mem_textbox = gr.Textbox(label="Add single word/phrase", placeholder="Type word or phrase")
                    mem_add_btn = gr.Button("Add to memory", elem_classes="small")
                    mem_clear_btn = gr.Button("Clear memory", elem_classes="small")
                    mem_view_btn = gr.Button("View memory", elem_classes="small")
                with gr.Column(scale=1):
                    mem_help = gr.Markdown(
                        "- Upload multiple text/JSON files or ZIPs. Preview ZIP members and choose which members to import.\n"
                        "- Supported encodings: utf-8, utf-16, latin-1, fallback.\n"
                        "- JSON format: {\"words\":{\"word\":count}, \"phrases\":{\"phrase\":count}}"
                    )
            # Logs at bottom
            mem_logs = gr.Textbox(label="Logs", lines=6, interactive=False)
            def _preview_many_zip(uploaded):
                """List members of every uploaded ZIP; non-ZIP uploads are ignored."""
                if not uploaded:
                    return "No files."
                if isinstance(uploaded, dict) and uploaded.get("name"):
                    uploaded = [uploaded["name"]]
                members_total = []
                for f in uploaded:
                    if f and str(f).lower().endswith(".zip"):
                        # NOTE(review): the per-zip log string is discarded here.
                        members, log = preview_zip_members_for_memory(str(f))
                        members_total.extend(members)
                if members_total:
                    return "\n".join(members_total)
                return "No ZIPs found or no previewable members."
            mem_preview_zip_btn.click(fn=_preview_many_zip, inputs=[mem_upload], outputs=[mem_zip_preview_out])
            def _select_all_mem():
                """Select all members from the last extraction list.

                NOTE(review): LAST_EXTRACT_LIST is populated by the Batch tab's
                ZIP extraction, not by the memory ZIP preview above, so this may
                not match the previewed members — confirm intended behavior.
                """
                # try to use preview box content (not ideal) — but we stored last extract list globally as LAST_EXTRACT_LIST
                global LAST_EXTRACT_LIST
                if LAST_EXTRACT_LIST:
                    return gr.update(value=LAST_EXTRACT_LIST)
                return gr.update(value=[])
            mem_select_all_btn.click(fn=_select_all_mem, inputs=[], outputs=[mem_zip_select])
            # Reuses the Batch tab's clear helper to empty the selection.
            mem_clear_select_btn.click(fn=_clear_batch_select, inputs=[], outputs=[mem_zip_select])
            def _import_mem(uploaded_files, selected_members):
                """Run the import and surface any failure as a status string."""
                try:
                    status = import_memory_files_multiple(uploaded_files, zip_members_to_import=selected_members)
                    return status
                except Exception as e:
                    return f"Import failed: {e}"
            mem_import_btn.click(fn=_import_mem, inputs=[mem_upload, mem_zip_select], outputs=[mem_status])
            def _add_mem(entry):
                """Add one entry: <=3 words goes to the word map (lowercased), else to phrases."""
                if not entry or not entry.strip():
                    return "No entry provided."
                e = entry.strip()
                with MEMORY_LOCK:
                    if len(e.split()) <= 3:
                        memory["words"][e.lower()] = memory["words"].get(e.lower(), 0) + 1
                        save_memory(memory)
                        return f"Added word: {e.lower()}"
                    else:
                        memory["phrases"][e] = memory["phrases"].get(e, 0) + 1
                        save_memory(memory)
                        return f"Added phrase: {e}"
            def _clear_mem():
                """Reset the module-global memory store and persist the empty state."""
                global memory
                with MEMORY_LOCK:
                    memory = {"words": {}, "phrases": {}}
                    save_memory(memory)
                return "Memory cleared."
            def _view_mem():
                """Render the top 30 words and top 20 phrases, ordered by count."""
                w = memory.get("words", {})
                p = memory.get("phrases", {})
                out_lines = []
                out_lines.append("WORDS (top 30):")
                for k, v in sorted(w.items(), key=lambda kv: -kv[1])[:30]:
                    out_lines.append(f"{k}: {v}")
                out_lines.append("")
                out_lines.append("PHRASES (top 20):")
                for k, v in sorted(p.items(), key=lambda kv: -kv[1])[:20]:
                    out_lines.append(f"{k}: {v}")
                return "\n".join(out_lines)
            mem_add_btn.click(fn=_add_mem, inputs=[mem_textbox], outputs=[mem_status])
            mem_clear_btn.click(fn=_clear_mem, inputs=[], outputs=[mem_status])
            mem_view_btn.click(fn=_view_mem, inputs=[], outputs=[mem_status])
        # Settings tab
        with gr.TabItem("Settings"):
            gr.Markdown("### Settings & tips")
            gr.Markdown(f"- Faster-whisper auto-detected: {USE_FASTER_WHISPER}")
            gr.Markdown(f"- Max workers for parallel transcribe: {MAX_WORKERS}")
            gr.Markdown("- If memory or RAM is limited, set MAX_WORKERS lower in code.")
            # Auto-cleanup settings
            cleanup_minutes = gr.Number(value=AUTO_CLEANUP_MINUTES, label="Auto-cleanup minutes (temp files older than this will be removed)", precision=0)
            cleanup_status = gr.Textbox(label="Cleanup status", lines=2, interactive=False)
            def _set_cleanup_minutes(val):
                """Clamp the value to >= 1 minute and update the module-global setting."""
                global AUTO_CLEANUP_MINUTES
                try:
                    v = int(val)
                    if v < 1:
                        v = 1
                    AUTO_CLEANUP_MINUTES = v
                    return f"Auto-cleanup set to {v} minutes."
                except Exception:
                    # Non-numeric input (e.g. cleared field) — report, keep old value.
                    return "Invalid value."
            cleanup_minutes.change(fn=_set_cleanup_minutes, inputs=[cleanup_minutes], outputs=[cleanup_status])
# ---------- Launch ----------
if __name__ == "__main__":
    # Hosting platforms (e.g. HF Spaces) provide PORT; fall back to Gradio's default 7860.
    port = int(os.environ.get("PORT", 7860))
    print("DEBUG: launching on port", port)
    # queue() enables the generator-based streaming outputs used by the batch tab.
    demo.queue().launch(server_name="0.0.0.0", server_port=port)