# Hugging Face Space page residue (repo/commit metadata), preserved as a comment:
# arvind / app.py — uploaded by staraks — "Update app.py" (commit ec6edb7, verified)
# app.py
# Whisper Transcriber — Full corrected app.py (multi-tab, Audio Transcribe focused)
# Requirements: gradio, whisper, pydub, pyzipper, python-docx, ffmpeg installed.
import os
import sys
import json
import shutil
import tempfile
import subprocess
import traceback
import threading
import re
from difflib import get_close_matches
from pathlib import Path

# Force unbuffered output so container logs show prints immediately.
# NOTE(review): setting this after interpreter start does not re-buffer the
# already-running process; the explicit flush=True on prints is what matters.
os.environ["PYTHONUNBUFFERED"] = "1"
print("DEBUG: app.py bootstrap starting", flush=True)

# Third-party imports are grouped in one guarded block so a missing
# dependency produces a clear FATAL log line before re-raising.
try:
    import gradio as gr
    import whisper
    from pydub import AudioSegment
    import pyzipper
    from docx import Document
except Exception as e:
    print("FATAL: import error for third-party libs:", e, flush=True)
    traceback.print_exc()
    raise
# ---------- Config ----------
MEMORY_FILE = "memory.json"  # on-disk persistence for the correction memory

# RLock (not Lock): save_memory() acquires this lock itself, and several
# callers (e.g. the Memory tab's add/import handlers) call save_memory()
# while already holding it. With a plain threading.Lock that nested
# acquisition deadlocks the handler thread; RLock is reentrant for the
# owning thread and is otherwise a drop-in replacement.
MEMORY_LOCK = threading.RLock()

MIN_WAV_SIZE = 1024  # bytes; smaller converted WAVs are treated as failures

# Raw/headerless decode candidates tried by ffmpeg when pydub's
# auto-detection fails: (format name, sample rate, channel count).
FFMPEG_CANDIDATES = [
    ("s16le", 16000, 1),
    ("s16le", 44100, 2),
    ("pcm_s16le", 16000, 1),
    ("pcm_s16le", 44100, 2),
    ("mulaw", 8000, 1),
]

MODEL_CACHE = {}  # loaded Whisper models keyed by model name

# Scratch area for fine-tune dataset staging; created eagerly at import.
FINETUNE_WORKDIR = os.path.join(tempfile.gettempdir(), "finetune_workdir")
os.makedirs(FINETUNE_WORKDIR, exist_ok=True)
# ---------- Helpers: Memory & Postprocessing ----------
def load_memory():
    """Read the correction memory from MEMORY_FILE, falling back to empty.

    Returns a dict guaranteed to contain "words" and "phrases" keys. If the
    file is missing, unreadable, or malformed, a fresh empty store is
    written back to disk (best-effort) and returned.
    """
    try:
        if os.path.exists(MEMORY_FILE):
            with open(MEMORY_FILE, "r", encoding="utf-8") as fh:
                data = json.load(fh)
            if not isinstance(data, dict):
                raise ValueError("memory.json root not dict")
            for section in ("words", "phrases"):
                data.setdefault(section, {})
            return data
    except Exception:
        pass  # fall through to a fresh store

    mem = {"words": {}, "phrases": {}}
    try:
        with open(MEMORY_FILE, "w", encoding="utf-8") as fh:
            json.dump(mem, fh, ensure_ascii=False, indent=2)
    except Exception:
        pass  # persistence is best-effort; still return the empty store
    return mem
def save_memory(mem):
    """Persist `mem` to MEMORY_FILE under the global lock (best-effort).

    Serialization/IO errors are printed but never raised to the caller.
    """
    with MEMORY_LOCK:
        try:
            with open(MEMORY_FILE, "w", encoding="utf-8") as handle:
                json.dump(mem, handle, ensure_ascii=False, indent=2)
        except Exception:
            traceback.print_exc()
# Module-wide correction memory, loaded once at import time and mutated by
# the transcription/memory handlers below (guarded by MEMORY_LOCK).
memory = load_memory()

# Abbreviation -> expansion map used by expand_abbreviations(); keys are
# matched case-insensitively after stripping surrounding .,;: punctuation.
MEDICAL_ABBREVIATIONS = {
    "pt": "patient",
    "dx": "diagnosis",
    "hx": "history",
    "sx": "symptoms",
    "c/o": "complains of",
    "bp": "blood pressure",
    "hr": "heart rate",
    "o2": "oxygen",
    "r/o": "rule out",
    "adm": "admit",
    "disch": "discharge",
}

# Lower-case drug name -> canonical spelling used by normalize_drugs().
DRUG_NORMALIZATION = {
    "metformin": "Metformin",
    "aspirin": "Aspirin",
    "amoxicillin": "Amoxicillin",
}
def expand_abbreviations(text):
    """Expand known medical abbreviations token-by-token.

    Whitespace is preserved exactly (the split keeps separators), and any
    leading/trailing .,;: punctuation around an abbreviation is kept around
    the expansion. Previously, leading punctuation (e.g. ".pt") was silently
    dropped because the rewrite regex only captured trailing punctuation.
    """
    tokens = re.split(r"(\s+)", text)
    out = []
    for tok in tokens:
        # Key lookup ignores case and surrounding .,;: punctuation.
        key = tok.lower().strip(".,;:")
        expansion = MEDICAL_ABBREVIATIONS.get(key)
        if expansion is None:
            out.append(tok)
            continue
        # Capture both leading and trailing punctuation so neither is lost.
        m = re.match(r"([.,;:]*)([A-Za-z0-9/]+)([.,;:]*)$", tok)
        if m:
            out.append(m.group(1) + expansion + m.group(3))
        else:
            # Unexpected token shape; leave it untouched rather than mangle it.
            out.append(tok)
    return "".join(out)
def normalize_drugs(text):
    """Rewrite known drug names to their canonical capitalization.

    Matching is whole-word and case-insensitive; substitutions are applied
    one drug at a time in table order.
    """
    for raw_name, canonical in DRUG_NORMALIZATION.items():
        pattern = rf"\b{raw_name}\b"
        text = re.sub(pattern, canonical, text, flags=re.IGNORECASE)
    return text
def punctuation_and_capitalization(text):
    """Ensure the text ends with a terminator and each sentence starts
    uppercase.

    Only the first character of each sentence is uppercased; the rest is
    left alone. The previous str.capitalize() call lowercased everything
    after the first letter, destroying acronyms ("BP" -> "Bp") and the
    canonical drug capitalization applied by normalize_drugs().
    """
    text = text.strip()
    if not text:
        return text
    # Append a period when the text doesn't already end in . ? or !
    if not re.search(r"[.?!]\s*$", text):
        text = text.rstrip() + "."
    # Split keeps the terminator+space separators so they can be re-joined.
    parts = re.split(r"([.?!]\s+)", text)
    out = []
    for p in parts:
        if p and not re.match(r"[.?!]\s+", p):
            out.append(p[0].upper() + p[1:])  # uppercase first char only
        else:
            out.append(p)
    return "".join(out)
def postprocess_transcript(text):
    """Run the full cleanup pipeline over a raw transcript.

    Collapses whitespace, expands medical abbreviations, canonicalizes drug
    names, then fixes punctuation/capitalization. Falsy input is returned
    unchanged.
    """
    if not text:
        return text
    cleaned = re.sub(r"\s+", " ", text).strip()
    for stage in (expand_abbreviations, normalize_drugs, punctuation_and_capitalization):
        cleaned = stage(cleaned)
    return cleaned
def extract_words_and_phrases(text):
    """Split text into word tokens and sentence-level phrases.

    Returns (words, sentences): words are alphanumeric/hyphen/apostrophe
    runs; sentences are chunks split after . ? or !, stripped, empties
    dropped.
    """
    word_tokens = [w for w in re.findall(r"[A-Za-z0-9\-']+", text) if w.strip()]
    phrase_list = []
    for chunk in re.split(r"(?<=[.?!])\s+", text):
        chunk = chunk.strip()
        if chunk:
            phrase_list.append(chunk)
    return word_tokens, phrase_list
def update_memory_with_transcript(transcript):
    """Fold every word and sentence of `transcript` into the global
    correction memory, persisting to disk when anything changed.
    """
    global memory
    words, sentences = extract_words_and_phrases(transcript)
    changed = False
    with MEMORY_LOCK:
        for w in words:
            lw = w.lower()
            # word counts are case-folded; phrase counts keep original casing
            memory["words"][lw] = memory["words"].get(lw, 0) + 1
            changed = True
        for s in sentences:
            memory["phrases"][s] = memory["phrases"].get(s, 0) + 1
            changed = True
    # Persist outside the lock: save_memory() acquires MEMORY_LOCK itself.
    if changed:
        save_memory(memory)
def memory_correct_text(text, min_ratio=0.85):
    """Correct `text` against the learned memory.

    Word pass: each alphanumeric token not already in memory is replaced by
    its closest memory word (difflib ratio >= min_ratio), preserving an
    initial capital. Phrase pass: remembered phrases (8+ chars, longest
    first) found case-insensitively are rewritten with their stored casing.
    Returns the input unchanged when memory is empty.
    """
    if not text or (not memory.get("words") and not memory.get("phrases")):
        return text

    def fix_word(w):
        lw = w.lower()
        if lw in memory["words"]:
            return w  # already known; keep as typed
        candidates = get_close_matches(lw, memory["words"].keys(), n=1, cutoff=min_ratio)
        if candidates:
            cand = candidates[0]
            if w and w[0].isupper():
                return cand.capitalize()
            return cand
        return w

    # Split keeps the non-word separators so the text can be reassembled.
    tokens = re.split(r"(\W+)", text)
    corrected_tokens = []
    for tok in tokens:
        if re.match(r"^[A-Za-z0-9\-']+$", tok):
            corrected_tokens.append(fix_word(tok))
        else:
            corrected_tokens.append(tok)
    corrected = "".join(corrected_tokens)

    # Longest phrases first so shorter overlapping phrases don't clobber them.
    for phrase in sorted(memory.get("phrases", {}).keys(), key=lambda s: -len(s)):
        low_phrase = phrase.lower()
        if len(low_phrase) < 8:
            continue  # short phrases are too noisy to rewrite
        if low_phrase in corrected.lower():
            # NOTE(review): `phrase` is used as the regex replacement string;
            # backslashes in a remembered phrase would be interpreted as
            # escapes — confirm phrases can never contain backslashes.
            corrected = re.sub(re.escape(phrase), phrase, corrected, flags=re.IGNORECASE)
    return corrected
# ---------- File utilities ----------
def save_as_word(text, filename=None):
    """Write `text` as a single paragraph into a .docx file.

    When `filename` is None the document goes to
    <tempdir>/merged_transcripts.docx. Returns the path written.
    """
    target = filename if filename is not None else os.path.join(
        tempfile.gettempdir(), "merged_transcripts.docx"
    )
    document = Document()
    document.add_paragraph(text)
    document.save(target)
    return target
# ---------- Conversion helpers ----------
def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
    """Run one ffmpeg conversion attempt to WAV.

    For the raw candidates ("s16le", "pcm_s16le", "mulaw") the input is
    force-decoded as headerless audio via `-f fmt -ar sr -ac ch`; for any
    other fmt, ffmpeg's own detection is used and rate/channels apply to
    the output. Success requires exit code 0 AND an output larger than
    MIN_WAV_SIZE bytes; on failure any partial output file is removed.

    Returns (success: bool, combined stdout+stderr / error text).
    NOTE(review): "pcm_s16le" is a codec name, not an ffmpeg demuxer, so
    `-f pcm_s16le` likely always fails — confirm those candidates are useful.
    """
    try:
        cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
        if fmt in ("s16le", "pcm_s16le", "mulaw"):
            # raw/headerless input: caller supplies the decode parameters
            cmd += ["-f", fmt, "-ar", str(sr), "-ac", str(ch), "-i", input_path, out_path]
        else:
            cmd += ["-i", input_path, "-ar", str(sr), "-ac", str(ch), out_path]
        proc = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
        stdout_stderr = (proc.stdout or "") + (proc.stderr or "")
        if proc.returncode == 0 and os.path.exists(out_path) and os.path.getsize(out_path) > MIN_WAV_SIZE:
            return True, stdout_stderr
        else:
            try:
                if os.path.exists(out_path):
                    os.unlink(out_path)
            except Exception:
                pass
            return False, stdout_stderr
    except Exception as e:
        # covers subprocess.TimeoutExpired and a missing ffmpeg binary
        try:
            if os.path.exists(out_path):
                os.unlink(out_path)
        except Exception:
            pass
        return False, str(e)
def convert_to_wav_if_needed(input_path):
    """Return a path to WAV audio for `input_path`.

    Strategy:
      1. Paths already ending in .wav are returned unchanged (no re-encode).
      2. Try pydub/ffmpeg auto-detection into a fresh temp WAV.
      3. Fall back to force-decoding each raw layout in FFMPEG_CANDIDATES.

    On success via (2) or (3) the returned path is a new temp file the
    caller is expected to delete. On total failure, diagnostics (ffmpeg
    output, ffprobe info, and a hex preview of the file head) are written
    to a temp dir and an Exception naming that log file is raised.

    NOTE(review): step 1 trusts the extension only; a mislabeled .wav is
    passed through unverified — confirm acceptable.
    """
    input_path = str(input_path)
    lower = input_path.lower()
    if lower.endswith(".wav"):
        return input_path
    auto_err = ""
    tmp = None
    try:
        # NamedTemporaryFile is closed immediately; only the name is used.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        AudioSegment.from_file(input_path).export(tmp.name, format="wav")
        if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > MIN_WAV_SIZE:
            return tmp.name
        else:
            try:
                os.unlink(tmp.name)
            except Exception:
                pass
    except Exception:
        auto_err = traceback.format_exc()
        try:
            if tmp and os.path.exists(tmp.name):
                os.unlink(tmp.name)
        except Exception:
            pass
    # pydub auto-detection failed — brute-force the raw-format candidates.
    diag_dir = tempfile.mkdtemp(prefix="dct_diag_")
    diag_log = os.path.join(diag_dir, "conversion_diagnostics.txt")
    diagnostics = []
    for fmt, sr, ch in FFMPEG_CANDIDATES:
        out_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        out_wav.close()
        success, debug = _ffmpeg_convert(input_path, out_wav.name, fmt, sr, ch)
        diagnostics.append(f"TRY fmt={fmt} sr={sr} ch={ch} success={success}\n{debug}\n")
        if success:
            # Best-effort: record which candidate worked for later debugging.
            try:
                with open(diag_log, "w", encoding="utf-8") as fh:
                    fh.write("pydub auto error:\n")
                    fh.write(auto_err + "\n\n")
                    fh.write("Successful ffmpeg candidate:\n")
                    fh.write(f"fmt={fmt} sr={sr} ch={ch}\n\n")
                    fh.write("Diagnostics:\n")
                    fh.write("\n".join(diagnostics))
            except Exception:
                pass
            return out_wav.name
        else:
            try:
                if os.path.exists(out_wav.name):
                    os.unlink(out_wav.name)
            except Exception:
                pass
    # Every candidate failed: gather container info and a hex dump.
    try:
        fp = subprocess.run(
            ["ffprobe", "-v", "error", "-show_format", "-show_streams", input_path],
            capture_output=True,
            text=True,
            timeout=10,
        )
        diagnostics.append("FFPROBE:\n" + (fp.stdout.strip() or fp.stderr.strip()))
    except Exception as e:
        diagnostics.append("ffprobe failed: " + str(e))
    try:
        with open(input_path, "rb") as fh:
            head = fh.read(512)
        diagnostics.append("HEX PREVIEW:\n" + head.hex())
    except Exception as e:
        diagnostics.append("could not read head: " + str(e))
    try:
        with open(diag_log, "w", encoding="utf-8") as fh:
            fh.write("pydub auto error:\n")
            fh.write(auto_err + "\n\n")
            fh.write("Full diagnostics:\n\n")
            fh.write("\n\n".join(diagnostics))
    except Exception as e:
        raise Exception(f"Conversion failed; diagnostics write error: {e}")
    raise Exception(f"Could not convert file to WAV. Diagnostics saved to: {diag_log}")
# ---------- Whisper model loader ----------
def get_whisper_model(name, device=None):
    """Fetch a Whisper model from the process-wide cache, loading on miss.

    Some whisper releases reject the `device` keyword with TypeError; in
    that case we retry with the plain one-argument loader.
    """
    cached = MODEL_CACHE.get(name)
    if cached is not None:
        return cached
    print(f"DEBUG: loading whisper model '{name}'", flush=True)
    try:
        if device:
            loaded = whisper.load_model(name, device=device)
        else:
            loaded = whisper.load_model(name)
    except TypeError:
        # some whisper versions don't accept device arg
        loaded = whisper.load_model(name)
    MODEL_CACHE[name] = loaded
    return loaded
# ---------- ZIP extraction helper ----------
def extract_zip_list(zip_file, zip_password):
    """Extract supported audio entries from a (possibly AES-encrypted) zip.

    Files land under <tempdir>/extracted_audio, which is cleared on each
    call. Returns (list of extracted paths, newline-joined log text); any
    top-level failure returns ([], error message) instead of raising.
    """
    temp_extract_dir = os.path.join(tempfile.gettempdir(), "extracted_audio")
    try:
        if os.path.exists(temp_extract_dir):
            try:
                shutil.rmtree(temp_extract_dir)
            except Exception:
                pass
        os.makedirs(temp_extract_dir, exist_ok=True)
        extracted = []
        logs = []
        with pyzipper.ZipFile(zip_file, "r") as zf:
            if zip_password:
                try:
                    zf.setpassword(zip_password.encode())
                except Exception:
                    logs.append("Warning: failed to set zip password (unexpected).")
            # extensions accepted as audio payloads
            exts = [".mp3", ".wav", ".aac", ".flac", ".ogg", ".m4a", ".dat", ".dct"]
            for info in zf.infolist():
                if info.is_dir():
                    continue
                _, ext = os.path.splitext(info.filename)
                if ext.lower() in exts:
                    try:
                        zf.extract(info, path=temp_extract_dir)
                    except RuntimeError as e:
                        # pyzipper signals a wrong/missing password this way
                        logs.append(f"Password required/incorrect for {info.filename}: {e}")
                        continue
                    except pyzipper.BadZipFile:
                        logs.append(f"Bad zip entry: {info.filename}")
                        continue
                    except Exception as e:
                        logs.append(f"Error extracting {info.filename}: {e}")
                        continue
                    # NOTE(review): extract() sanitizes hostile entry names, so
                    # this reconstructed path may not match for entries with
                    # '..' or absolute paths — such entries are silently
                    # skipped by the exists() check. Confirm intended.
                    p = os.path.normpath(os.path.join(temp_extract_dir, info.filename))
                    if os.path.exists(p):
                        extracted.append(p)
                        logs.append(f"Extracted: {info.filename}")
        if not extracted:
            logs.append("No supported audio files found in zip.")
            return [], "\n".join(logs)
        return extracted, "\n".join(logs)
    except Exception as e:
        traceback.print_exc()
        return [], f"Extraction failed: {e}"
# ---------- Simple single-file transcriber ----------
def transcribe_single(audio_path, model_name="small", enable_memory=False, device_choice="auto"):
    """Transcribe one audio file with Whisper.

    Converts to WAV when needed, runs the model, applies (optional) memory
    corrections plus the postprocessing pipeline, and (optionally) feeds the
    result back into the correction memory.

    Returns (player_path, transcript_text, logs_text); on failure returns
    (None, "", error+traceback text) instead of raising.
    """
    logs = []
    transcript_text = ""
    wav = None
    path = None
    try:
        if not audio_path:
            return None, "No audio provided.", "No file provided."
        path = str(audio_path)
        device = None if device_choice == "auto" else device_choice
        model = get_whisper_model(model_name, device=device)
        logs.append(f"Loaded model: {model_name}")
        wav = convert_to_wav_if_needed(path)
        logs.append(f"Converted to WAV: {os.path.basename(wav)}")
        result = model.transcribe(wav)
        text = result.get("text", "").strip()
        if enable_memory:
            text = memory_correct_text(text)
        text = postprocess_transcript(text)
        transcript_text = text
        if enable_memory:
            try:
                update_memory_with_transcript(text)
                logs.append("Memory updated.")
            except Exception:
                pass  # memory persistence is best-effort
        return path, transcript_text, "\n".join(logs)
    except Exception as e:
        tb = traceback.format_exc()
        return None, "", f"Error: {e}\n{tb}"
    finally:
        # Clean up the temporary WAV even when transcription raised
        # (previously the temp file leaked on any exception).
        if wav and wav != path and os.path.exists(wav):
            try:
                os.unlink(wav)
            except Exception:
                pass
# ---------- Fine-tune helpers (include old-files support) ----------
def _collect_old_files_into(dst_dir, old_dir_path):
msgs = []
copied = 0
try:
if not os.path.isdir(old_dir_path):
return 0, f"Old-files path is not a directory: {old_dir_path}"
for root, _, files in os.walk(old_dir_path):
for f in files:
if f.lower().endswith((".wav", ".mp3", ".flac", ".m4a", ".ogg")):
src_audio = os.path.join(root, f)
base = os.path.splitext(f)[0]
possible_txt = os.path.join(root, base + ".txt")
rel_subdir = os.path.relpath(root, old_dir_path)
target_subdir = os.path.join(dst_dir, rel_subdir)
os.makedirs(target_subdir, exist_ok=True)
target_audio = os.path.join(target_subdir, f)
shutil.copy2(src_audio, target_audio)
if os.path.exists(possible_txt):
shutil.copy2(possible_txt, os.path.join(target_subdir, base + ".txt"))
msgs.append(f"Copied pair: {os.path.join(rel_subdir, f)} + .txt")
else:
msgs.append(f"Copied audio (no transcript found): {os.path.join(rel_subdir, f)}")
copied += 1
return copied, "\n".join(msgs)
except Exception as e:
traceback.print_exc()
return copied, f"Error copying old files: {e}"
def prepare_finetune_dataset(uploaded_zip_or_dir, include_old_files=False, old_files_dir=""):
    """Stage a fine-tuning dataset under FINETUNE_WORKDIR/data and ensure a
    manifest exists at FINETUNE_WORKDIR/manifest.tsv.

    Accepts a ZIP path, a directory path, or a Gradio upload object. If the
    dataset ships none of the known manifest files, one is built as
    "<audio_path>\\t<transcript>" lines from same-basename .txt files
    (empty transcript when missing, counted as a warning).

    Returns (status message, manifest path) — manifest path is "" on error.
    """
    dst = os.path.join(FINETUNE_WORKDIR, "data")
    # Start from a clean staging directory on every call.
    try:
        if os.path.exists(dst):
            shutil.rmtree(dst)
        os.makedirs(dst, exist_ok=True)
    except Exception as e:
        return f"Failed to prepare workdir: {e}", ""
    # Normalize the shapes Gradio may hand us: str/PathLike, file object
    # with .name, or a dict with a "name" key.
    path = None
    try:
        if uploaded_zip_or_dir:
            if isinstance(uploaded_zip_or_dir, (str, os.PathLike)):
                path = str(uploaded_zip_or_dir)
            elif hasattr(uploaded_zip_or_dir, "name"):
                path = uploaded_zip_or_dir.name
            elif isinstance(uploaded_zip_or_dir, dict) and uploaded_zip_or_dir.get("name"):
                path = uploaded_zip_or_dir["name"]
    except Exception as e:
        return f"Unable to determine uploaded path: {e}", ""
    # extract or copy uploaded dataset if provided
    if path and os.path.isfile(path) and path.lower().endswith(".zip"):
        try:
            with pyzipper.ZipFile(path, "r") as zf:
                zf.extractall(dst)
        except Exception as e:
            return f"Failed to extract ZIP: {e}", ""
    elif path and os.path.isdir(path):
        try:
            for item in os.listdir(path):
                s = os.path.join(path, item)
                d = os.path.join(dst, item)
                if os.path.isdir(s):
                    shutil.copytree(s, d)
                else:
                    shutil.copy2(s, d)
        except Exception as e:
            return f"Failed to copy dataset dir: {e}", ""
    # include old files if requested
    old_msgs = ""
    if include_old_files and old_files_dir:
        # Same shape-normalization as above for the old-files source.
        old_path = None
        if isinstance(old_files_dir, (str, os.PathLike)):
            old_path = str(old_files_dir)
        elif hasattr(old_files_dir, "name"):
            old_path = old_files_dir.name
        elif isinstance(old_files_dir, dict) and old_files_dir.get("name"):
            old_path = old_files_dir["name"]
        if old_path:
            copied, msg = _collect_old_files_into(dst, old_path)
            old_msgs = f"\nOld-files: copied {copied} audio files.\nDetails:\n{msg}"
    # Find an existing manifest under any of the recognized names, or build
    # one from audio + sibling .txt transcripts.
    transcripts_candidates = [
        os.path.join(dst, "transcripts.tsv"),
        os.path.join(dst, "metadata.tsv"),
        os.path.join(dst, "manifest.tsv"),
        os.path.join(dst, "transcripts.txt"),
        os.path.join(dst, "manifest.jsonl"),
    ]
    manifest_path = os.path.join(FINETUNE_WORKDIR, "manifest.tsv")
    found = False
    for tpath in transcripts_candidates:
        if os.path.exists(tpath):
            try:
                shutil.copy2(tpath, manifest_path)
                found = True
                break
            except Exception:
                pass
    missing_transcripts = 0
    if not found:
        audio_files = []
        for root, _, files in os.walk(dst):
            for f in files:
                if f.lower().endswith((".wav", ".mp3", ".flac", ".m4a", ".ogg")):
                    audio_files.append(os.path.join(root, f))
        if not audio_files:
            return f"No audio files found in dataset.{old_msgs}", ""
        entries = []
        for a in audio_files:
            base = os.path.splitext(a)[0]
            t_candidate = base + ".txt"
            transcript = ""
            if os.path.exists(t_candidate):
                try:
                    with open(t_candidate, "r", encoding="utf-8") as fh:
                        # flatten multi-line transcripts to a single TSV field
                        transcript = fh.read().strip().replace("\n", " ")
                except Exception:
                    transcript = ""
            else:
                missing_transcripts += 1
            entries.append(f"{a}\t{transcript}")
        try:
            with open(manifest_path, "w", encoding="utf-8") as fh:
                fh.write("\n".join(entries))
            found = True
        except Exception as e:
            return f"Failed to write manifest: {e}{old_msgs}", ""
    if not found:
        return f"Failed to locate or build manifest.{old_msgs}", ""
    status_msg = f"Dataset prepared. Manifest: {manifest_path}{old_msgs}"
    if missing_transcripts > 0:
        status_msg += f"\nWarning: {missing_transcripts} audio files have no matching .txt transcript (empty transcripts saved)."
    return status_msg, manifest_path
def start_finetune(manifest_path, base_model, epochs, batch_size, lr, output_dir):
    """Launch fine_tune.py as a detached subprocess.

    Stdout/stderr go to <outdir>/finetune_stdout.log. Returns a status
    string (started with PID, script missing, or generic failure) — never
    raises.
    """
    outdir = output_dir or os.path.join(FINETUNE_WORKDIR, "output")
    os.makedirs(outdir, exist_ok=True)
    START_CMD = [
        sys.executable,
        "fine_tune.py",
        "--manifest",
        manifest_path,
        "--base_model",
        base_model,
        "--epochs",
        str(epochs),
        "--batch_size",
        str(batch_size),
        "--lr",
        str(lr),
        "--output_dir",
        outdir,
    ]
    log_path = os.path.join(outdir, "finetune_stdout.log")
    try:
        # Close the parent's handle right after Popen: the child inherits a
        # duplicated fd and keeps writing to the log. (Previously the handle
        # leaked in the parent for the life of the process.)
        with open(log_path, "a", encoding="utf-8") as logfile:
            proc = subprocess.Popen(START_CMD, stdout=logfile, stderr=logfile, cwd=os.getcwd())
        return f"Fine-tune started (PID={proc.pid}). Logs: {log_path}"
    except FileNotFoundError as e:
        return f"Training script not found: {e}. Put 'fine_tune.py' in project root or change START_CMD."
    except Exception as e:
        return f"Failed to start fine-tune: {e}"
def tail_finetune_logs(logpath, lines=200):
    """Return the last `lines` lines of a training log.

    Missing files yield "No logs yet."; read errors yield a message rather
    than raising. Undecodable bytes are ignored.
    """
    try:
        if not os.path.exists(logpath):
            return "No logs yet."
        with open(logpath, "r", encoding="utf-8", errors="ignore") as fh:
            content = fh.read()
        return "\n".join(content.splitlines()[-lines:])
    except Exception as e:
        return f"Failed to read logs: {e}"
# ---------- UI CSS ----------
# Custom styling injected via gr.Blocks(css=CSS): header banner with app
# icon, white "card" containers, and a dark monospace transcript area.
CSS = """
:root{
--accent:#4f46e5;
--muted:#6b7280;
--card:#ffffff;
--bg:#f7f8fb;
}
body { background: var(--bg); font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, "Helvetica Neue", Arial; }
.header { padding: 18px 24px; border-radius: 12px; background: linear-gradient(90deg, rgba(79,70,229,0.12), rgba(99,102,241,0.04)); margin-bottom: 18px; display:flex;align-items:center;gap:16px; }
.app-icon { width:62px;height:62px;border-radius:12px;background:linear-gradient(135deg,var(--accent),#06b6d4);display:flex;align-items:center;justify-content:center;color:white;font-weight:700;font-size:24px; }
.header-title h1 { margin:0;font-size:20px;}
.header-sub { color:var(--muted); margin-top:4px;font-size:13px;}
.card { background:var(--card); border-radius:12px; padding:14px; box-shadow: 0 6px 20px rgba(16,24,40,0.06); }
.transcript-area { white-space:pre-wrap; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, "Roboto Mono", monospace; background:#0f172a; color:#e6eef8; padding:12px; border-radius:10px; min-height:220px; }
.small-note { color:var(--muted); font-size:12px;}
"""
# ---------- Build UI ----------
# Five-tab Gradio Blocks app. All handler functions defined inside the
# `with` blocks are ordinary module-level functions (with-statements do not
# create scopes), which is why the Settings tab can call _view_mem below.
print("DEBUG: building Gradio Blocks", flush=True)
with gr.Blocks(title="Whisper Transcriber", css=CSS) as demo:
    # Header
    with gr.Row(elem_classes="header"):
        with gr.Column(scale=0):
            gr.HTML("<div class='app-icon'>WT</div>")
        with gr.Column():
            gr.HTML("<h1 style='margin:0'>Whisper Transcriber</h1>")
            gr.Markdown("<div class='header-sub'>Transcribe, batch, memory & fine-tune — multi-tab UI</div>")
    with gr.Tabs():
        # Audio Transcribe Tab
        with gr.TabItem("Audio Transcribe"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Quick Single Audio Transcribe")
                        single_audio = gr.Audio(label="Upload or record audio", type="filepath")
                        with gr.Row():
                            model_select = gr.Dropdown(choices=["small","medium","large","large-v3","base"], value="large-v3", label="Model")
                            device_select = gr.Dropdown(choices=["auto","cpu","cuda"], value="auto", label="Device")
                        with gr.Row():
                            mem_toggle = gr.Checkbox(label="Enable correction memory", value=False)
                            format_choice = gr.Dropdown(choices=["Plain","SOAP (medical)"], value="Plain", label="Format")
                        transcribe_btn = gr.Button("Transcribe", variant="primary")
                        gr.Markdown("<div class='small-note'>Tip: choose large-v3 if your environment supports it.</div>")
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Player & Transcript")
                        audio_preview = gr.Audio(label="Player", interactive=False)
                        transcript_out = gr.Textbox(label="Transcript", lines=14, interactive=False, elem_classes="transcript-area")
                        transcript_logs = gr.Textbox(label="Logs", lines=6, interactive=False)

            # Click handler: run the single-file pipeline, optionally
            # reshaping the transcript into a very rough SOAP note.
            def _do_single_transcribe(audio_file, model_name, device_choice, enable_memory, fmt_choice):
                player_path, transcript, logs = transcribe_single(audio_file, model_name=model_name, enable_memory=enable_memory, device_choice=device_choice)
                # NOTE(review): the dropdown's value is "SOAP (medical)", so
                # this comparison never matches and the SOAP branch is dead —
                # confirm whether it should test "SOAP (medical)" instead.
                if fmt_choice == "SOAP":
                    sentences = re.split(r"(?<=[.?!])\s+", transcript)
                    subj = sentences[0] if sentences else ""
                    obj = sentences[1] if len(sentences) > 1 else ""
                    soap = f"S: {subj}\nO: {obj}\nA: Assessment pending\nP: Plan: follow up"
                    transcript = soap
                return player_path, transcript, logs

            transcribe_btn.click(fn=_do_single_transcribe, inputs=[single_audio, model_select, device_select, mem_toggle, format_choice], outputs=[audio_preview, transcript_out, transcript_logs])
        # Batch Transcribe Tab
        with gr.TabItem("Batch Transcribe"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Batch / ZIP workflow")
                        batch_files = gr.File(label="Upload multiple audio files (optional)", file_count="multiple", type="filepath")
                        batch_zip = gr.File(label="Or upload ZIP with audio (optional)", file_count="single", type="filepath")
                        zip_password = gr.Textbox(label="ZIP password (optional)")
                        with gr.Row():
                            batch_model = gr.Dropdown(choices=["small","medium","large","large-v3","base"], value="small", label="Model")
                            batch_device = gr.Dropdown(choices=["auto","cpu","cuda"], value="auto", label="Device")
                        batch_merge = gr.Checkbox(label="Merge all transcripts into one .docx", value=True)
                        batch_mem = gr.Checkbox(label="Enable memory corrections", value=False)
                        batch_extract_btn = gr.Button("Extract ZIP & List Files")
                        batch_extract_logs = gr.Textbox(label="Extraction logs", lines=6, interactive=False)
                        batch_select = gr.CheckboxGroup(choices=[], label="Select extracted files to transcribe", interactive=True)
                        batch_trans_btn = gr.Button("Start Batch Transcription", variant="primary")
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Output")
                        batch_trans_out = gr.Textbox(label="Transcript (combined)", lines=16, interactive=False)
                        batch_logs = gr.Textbox(label="Logs", lines=10, interactive=False)
                        batch_download = gr.File(label="Merged .docx (when available)")

            # NOTE(review): returning a plain list into a CheckboxGroup sets
            # its *value*, not its choices — populating the selectable file
            # list may require gr.update(choices=...); confirm against the
            # Gradio version in use.
            def _extract_zip_for_ui(zip_file, password):
                if not zip_file:
                    return [], "No zip provided."
                zip_path = zip_file.name if hasattr(zip_file, "name") else str(zip_file)
                extracted, logs = extract_zip_list(zip_path, password)
                short_logs = logs + "\n\nFiles:\n" + "\n".join([os.path.basename(p) for p in extracted])
                return extracted, short_logs

            batch_extract_btn.click(fn=_extract_zip_for_ui, inputs=[batch_zip, zip_password], outputs=[batch_select, batch_extract_logs])

            # Transcribe every selected/uploaded file sequentially; failures
            # are logged per file instead of aborting the batch.
            def _batch_transcribe(selected_check, uploaded_files, model_name, device_name, merge_flag, enable_mem):
                paths = []
                if selected_check:
                    paths.extend(selected_check)
                if uploaded_files:
                    if isinstance(uploaded_files, (list, tuple)):
                        for x in uploaded_files:
                            paths.append(str(x))
                    else:
                        paths.append(str(uploaded_files))
                if not paths:
                    return "", "No files selected or uploaded.", None
                logs = []
                transcripts = []
                out_doc = None
                for p in paths:
                    try:
                        _, txt, lg = transcribe_single(p, model_name=model_name, enable_memory=enable_mem, device_choice=device_name)
                        logs.append(lg)
                        transcripts.append(f"FILE: {os.path.basename(str(p))}\n{txt}\n")
                    except Exception as e:
                        logs.append(f"Failed {p}: {e}")
                combined = "\n\n".join(transcripts)
                if merge_flag:
                    try:
                        out_doc = save_as_word(combined)
                        logs.append(f"Merged saved: {out_doc}")
                    except Exception as e:
                        logs.append(f"Merge failed: {e}")
                return combined, "\n".join(logs), out_doc

            batch_trans_btn.click(fn=_batch_transcribe, inputs=[batch_select, batch_files, batch_model, batch_device, batch_merge, batch_mem], outputs=[batch_trans_out, batch_logs, batch_download])
        # Memory Tab
        with gr.TabItem("Memory"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Correction Memory")
                        mem_upload = gr.File(label="Import memory (JSON or text)", file_count="single", type="filepath")
                        mem_import_btn = gr.Button("Import Memory")
                        mem_add_text = gr.Textbox(label="Add word / phrase", placeholder="Type and click Add")
                        mem_add_btn = gr.Button("Add to Memory")
                        mem_clear_btn = gr.Button("Clear Memory")
                        mem_view_btn = gr.Button("View Memory")
                        mem_status = gr.Textbox(label="Memory status / preview", lines=12, interactive=False)

            # Import either a JSON dump ({"words": {...}, "phrases": {...}})
            # or a plain text file with one "word[,count]" entry per line.
            def _import_mem(uploaded):
                if not uploaded:
                    return "No file provided."
                path = uploaded.name if hasattr(uploaded, "name") else str(uploaded)
                try:
                    with open(path, "r", encoding="utf-8") as fh:
                        raw = fh.read()
                    parsed = None
                    try:
                        parsed = json.loads(raw)
                    except Exception:
                        parsed = None
                    if isinstance(parsed, dict):
                        with MEMORY_LOCK:
                            for k, v in parsed.get("words", {}).items():
                                memory["words"][k.lower()] = memory["words"].get(k.lower(), 0) + int(v)
                            for k, v in parsed.get("phrases", {}).items():
                                memory["phrases"][k] = memory["phrases"].get(k, 0) + int(v)
                            # NOTE(review): save_memory() re-acquires
                            # MEMORY_LOCK; this only works if the lock is
                            # reentrant — confirm.
                            save_memory(memory)
                        return f"Imported JSON memory (words={len(parsed.get('words', {}))}, phrases={len(parsed.get('phrases', {}))})."
                    lines = [l.strip() for l in raw.splitlines() if l.strip()]
                    added = 0
                    with MEMORY_LOCK:
                        for line in lines:
                            if "," in line:
                                k, c = line.split(",", 1)
                                try:
                                    cnt = int(c)
                                except:
                                    cnt = 1
                                memory["words"][k.lower()] = memory["words"].get(k.lower(), 0) + cnt
                            else:
                                memory["words"][line.lower()] = memory["words"].get(line.lower(), 0) + 1
                            added += 1
                        save_memory(memory)
                    return f"Imported {added} entries."
                except Exception as e:
                    return f"Import failed: {e}"

            # Heuristic: short entries (<= 3 words) are words, longer are phrases.
            def _add_mem(entry):
                if not entry or not entry.strip():
                    return "No entry provided."
                e = entry.strip()
                with MEMORY_LOCK:
                    if len(e.split()) <= 3:
                        memory["words"][e.lower()] = memory["words"].get(e.lower(), 0) + 1
                        # NOTE(review): save_memory() re-acquires MEMORY_LOCK
                        # while it is held here — requires a reentrant lock.
                        save_memory(memory)
                        return f"Added word: {e.lower()}"
                    else:
                        memory["phrases"][e] = memory["phrases"].get(e, 0) + 1
                        save_memory(memory)
                        return f"Added phrase: {e}"

            def _clear_mem():
                # Reset the in-process store, then persist the empty state.
                global memory
                with MEMORY_LOCK:
                    memory = {"words": {}, "phrases": {}}
                save_memory(memory)
                return "Memory cleared."

            def _view_mem():
                # Render the most frequent entries for display.
                w = memory.get("words", {})
                p = memory.get("phrases", {})
                out = []
                out.append("WORDS (top 30):")
                for k, v in sorted(w.items(), key=lambda kv: -kv[1])[:30]:
                    out.append(f"{k}: {v}")
                out.append("")
                out.append("PHRASES (top 20):")
                for k, v in sorted(p.items(), key=lambda kv: -kv[1])[:20]:
                    out.append(f"{k}: {v}")
                return "\n".join(out)

            mem_import_btn.click(fn=_import_mem, inputs=[mem_upload], outputs=[mem_status])
            mem_add_btn.click(fn=_add_mem, inputs=[mem_add_text], outputs=[mem_status])
            mem_clear_btn.click(fn=_clear_mem, inputs=[], outputs=[mem_status])
            mem_view_btn.click(fn=_view_mem, inputs=[], outputs=[mem_status])
        # Fine-tune Tab
        with gr.TabItem("Fine-tune"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Prepare & Launch Fine-tune")
                        ft_upload = gr.File(label="Upload dataset ZIP (optional)", file_count="single", type="filepath")
                        ft_include_old = gr.Checkbox(label="Include old audio+transcript folder", value=False)
                        ft_old = gr.File(label="Old files folder (optional)", file_count="single", type="filepath")
                        ft_prepare_btn = gr.Button("Prepare dataset")
                        ft_manifest_box = gr.Textbox(label="Prepare status / manifest", lines=4, interactive=False)
                        ft_base_model = gr.Dropdown(choices=["small","base","medium","large","large-v3"], value="small", label="Base model")
                        ft_epochs = gr.Slider(minimum=1, maximum=100, value=3, step=1, label="Epochs")
                        ft_batch = gr.Number(label="Batch size", value=8)
                        ft_lr = gr.Number(label="Learning rate", value=1e-5, precision=8)
                        ft_output_dir = gr.Textbox(label="Output dir (optional)", value="", placeholder="Leave blank to use temp output")
                        ft_start_btn = gr.Button("Start Fine-tune")
                        ft_stop_btn = gr.Button("Stop Fine-tune")
                        ft_start_status = gr.Textbox(label="Start/Stop status", interactive=False, lines=4)
                        ft_tail_btn = gr.Button("Tail training logs")
                        ft_logs = gr.Textbox(label="Training logs (tail)", interactive=False, lines=12)
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Notes")
                        gr.Markdown("- Old-files folder should contain audio files and matching .txt transcripts with the same basename.")
                        gr.Markdown("- The app prepares a manifest and calls your `fine_tune.py` training script (you must provide it).")

            def _prepare_action(ft_upload_file, include_old, old_dir):
                # Only the status string is surfaced; the manifest path is
                # re-derived from FINETUNE_WORKDIR when starting.
                status, manifest = prepare_finetune_dataset(ft_upload_file, include_old_files=include_old, old_files_dir=old_dir)
                return status

            def _start_action(manifest_text, base_model, epochs, batch_size, lr, output_dir):
                # NOTE(review): manifest_text is ignored; the path is always
                # rebuilt from FINETUNE_WORKDIR — confirm intended.
                manifest_guess = os.path.join(FINETUNE_WORKDIR, "manifest.tsv")
                if not os.path.exists(manifest_guess):
                    return "Manifest not found. Prepare dataset first or manually provide manifest."
                status = start_finetune(manifest_guess, base_model, int(epochs), int(batch_size), float(lr), output_dir)
                return status

            ft_prepare_btn.click(fn=_prepare_action, inputs=[ft_upload, ft_include_old, ft_old], outputs=[ft_manifest_box])
            ft_start_btn.click(fn=_start_action, inputs=[ft_manifest_box, ft_base_model, ft_epochs, ft_batch, ft_lr, ft_output_dir], outputs=[ft_start_status])
            # NOTE(review): stop/tail are placeholders; tail_finetune_logs()
            # exists above but is never wired in — confirm.
            ft_stop_btn.click(fn=lambda: "Stop not implemented in placeholder", inputs=[], outputs=[ft_start_status])
            ft_tail_btn.click(fn=lambda: "Tail logs not implemented in placeholder", inputs=[], outputs=[ft_logs])
        # Settings Tab
        with gr.TabItem("Settings"):
            with gr.Row():
                with gr.Column():
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Runtime & tips")
                        gr.Markdown("- Use large-v3 only if your whisper package supports it.")
                        gr.Markdown("- Extraction writes to system temp `extracted_audio`. Re-extracting overwrites it.")
                        gr.Markdown("- Provide your `fine_tune.py` for real fine-tuning.")
                with gr.Column():
                    with gr.Group(elem_classes="card"):
                        gr.Markdown("### Diagnostics")
                        diag_btn = gr.Button("Show memory summary")
                        diag_out = gr.Textbox(label="Diagnostics", lines=12, interactive=False)
                        # _view_mem is module-level (defined in the Memory tab)
                        diag_btn.click(fn=lambda: _view_mem(), inputs=[], outputs=[diag_out])
# ---------- Launch ----------
if __name__ == "__main__":
    # Bind to all interfaces; honor the platform-assigned PORT (e.g. Spaces).
    port = int(os.getenv("PORT", "7860"))
    print("DEBUG: launching Gradio on port", port, flush=True)
    try:
        demo.queue().launch(server_name="0.0.0.0", server_port=port)
    except Exception as launch_err:
        print("FATAL: demo.launch failed:", launch_err, flush=True)
        traceback.print_exc()
        raise