arvind / app.py
staraks's picture
Update app.py
b7cd478 verified
raw
history blame
28.4 kB
# app.py
# Whisper transcription app - HYBRID conversion (pydub + small ffmpeg fallback)
# Clean, single-version file for Hugging Face Spaces.
import os
import sys
import json
import shutil
import tempfile
import subprocess
import traceback
import threading
import re
from difflib import get_close_matches
# Force unbuffered output so container logs show prints immediately
os.environ["PYTHONUNBUFFERED"] = "1"
print("DEBUG: app.py bootstrap starting", flush=True)
# Third-party imports (must be installed in the environment)
try:
from docx import Document
import whisper
import gradio as gr
import pyzipper
from pydub import AudioSegment
except Exception as e:
print("FATAL: import error for third-party libs:", e, flush=True)
traceback.print_exc()
raise
print("DEBUG: imports OK", flush=True)
# ---------- Config ----------
# Path of the JSON file that persists the correction memory across runs.
MEMORY_FILE = "memory.json"
# Guards reads/writes of the shared in-memory `memory` dict and its JSON file.
MEMORY_LOCK = threading.Lock()
MIN_WAV_SIZE = 1024  # bytes; converted WAVs at or below this are treated as failed
# Raw-audio interpretations tried in order by the ffmpeg fallback when pydub
# cannot decode a file: (ffmpeg "-f" format name, sample rate, channel count).
# NOTE(review): "pcm_s16le" is an ffmpeg *codec* name, not a demuxer/format
# name, so those candidates likely always fail — confirm with `ffmpeg -formats`.
FFMPEG_CANDIDATES = [
    ("s16le", 16000, 1),
    ("s16le", 44100, 2),
    ("pcm_s16le", 16000, 1),
    ("pcm_s16le", 44100, 2),
    ("mulaw", 8000, 1),
]
# ----------------------------
# ---------- Memory helpers ----------
def load_memory():
    """Load the persisted correction memory from MEMORY_FILE.

    Returns a dict guaranteed to contain "words" and "phrases" keys. On any
    read/parse problem a fresh empty memory is written back (best-effort)
    and returned.
    """
    try:
        if os.path.exists(MEMORY_FILE):
            with open(MEMORY_FILE, "r", encoding="utf-8") as handle:
                loaded = json.load(handle)
            if not isinstance(loaded, dict):
                raise ValueError("memory.json root not dict")
            loaded.setdefault("words", {})
            loaded.setdefault("phrases", {})
            return loaded
    except Exception:
        pass  # corrupt or unreadable file: fall through to a fresh memory
    fresh = {"words": {}, "phrases": {}}
    try:
        with open(MEMORY_FILE, "w", encoding="utf-8") as handle:
            json.dump(fresh, handle, ensure_ascii=False, indent=2)
    except Exception:
        pass  # e.g. read-only filesystem; keep the in-memory copy anyway
    return fresh
def save_memory(mem):
    """Persist *mem* to MEMORY_FILE under the module lock (best-effort)."""
    with MEMORY_LOCK:
        try:
            with open(MEMORY_FILE, "w", encoding="utf-8") as handle:
                json.dump(mem, handle, ensure_ascii=False, indent=2)
        except Exception:
            # Persistence is best-effort; report the failure and carry on.
            traceback.print_exc()
# Module-level memory shared by the transcription and memory-management handlers.
memory = load_memory()
print(
    "DEBUG: memory loaded (words=%d phrases=%d)"
    % (len(memory.get("words", {})), len(memory.get("phrases", {}))),
    flush=True,
)
# ---------- Postprocessing ----------
# Lower-cased clinical shorthand -> expansion, used by expand_abbreviations().
MEDICAL_ABBREVIATIONS = {
    "pt": "patient",
    "dx": "diagnosis",
    "hx": "history",
    "sx": "symptoms",
    "c/o": "complains of",
    "bp": "blood pressure",
    "hr": "heart rate",
    "o2": "oxygen",
    "r/o": "rule out",
    "adm": "admit",
    "disch": "discharge",
}
# Lower-cased drug name -> canonical capitalization, used by normalize_drugs().
DRUG_NORMALIZATION = {
    "metformin": "Metformin",
    "aspirin": "Aspirin",
    "amoxicillin": "Amoxicillin",
}
def expand_abbreviations(text, abbreviations=None):
    """Expand known shorthand tokens in *text*, preserving whitespace and
    trailing punctuation (e.g. "pt," -> "patient,").

    abbreviations: optional mapping of lower-cased shorthand -> expansion.
    Defaults to the module-level MEDICAL_ABBREVIATIONS table, so existing
    callers are unaffected; passing a table generalizes the helper to any
    domain vocabulary.
    """
    table = MEDICAL_ABBREVIATIONS if abbreviations is None else abbreviations
    out = []
    # Split on whitespace but keep the separators so spacing is preserved.
    for token in re.split(r"(\s+)", text):
        key = token.lower().strip(".,;:")
        if key in table:
            # Carry over any trailing punctuation from the original token.
            trailing = ""
            m = re.match(r"([A-Za-z0-9/]+)([.,;:]*)", token)
            if m:
                trailing = m.group(2) or ""
            out.append(table[key] + trailing)
        else:
            out.append(token)
    return "".join(out)
def normalize_drugs(text, normalization=None):
    """Rewrite known drug names in *text* to their canonical capitalization.

    normalization: optional mapping of lower-cased drug name -> canonical
    form. Defaults to the module-level DRUG_NORMALIZATION table, keeping
    existing callers unchanged while letting tests/other domains supply
    their own table.
    """
    table = DRUG_NORMALIZATION if normalization is None else normalization
    for name, canonical in table.items():
        # Whole-word, case-insensitive replacement.
        text = re.sub(rf"\b{name}\b", canonical, text, flags=re.IGNORECASE)
    return text
def punctuation_and_capitalization(text):
    """Ensure *text* ends with terminal punctuation and each sentence starts
    with an uppercase letter.

    Fix over the previous version: str.capitalize() lowercased everything
    after the first character of each sentence, destroying mid-sentence
    capitalization (drug names just canonicalized by normalize_drugs,
    acronyms, proper nouns). Only the first character is changed now.
    """
    text = text.strip()
    if not text:
        return text
    # Append a period when the transcript ends without terminal punctuation.
    if not re.search(r"[.?!]\s*$", text):
        text = text.rstrip() + "."
    # The split keeps the ".?! + whitespace" separators so we rejoin losslessly.
    parts = re.split(r"([.?!]\s+)", text)
    out = []
    for p in parts:
        if p and not re.match(r"[.?!]\s+", p):
            out.append(p[0].upper() + p[1:])
        else:
            out.append(p)
    return "".join(out)
def postprocess_transcript(text, format_soap=False):
    """Clean a raw transcript: collapse whitespace, expand abbreviations,
    normalize drug names, and fix punctuation/capitalization.

    When format_soap is True, a rough SOAP-note skeleton built from the
    first two sentences is returned instead of the flat text.
    """
    if not text:
        return text
    cleaned = re.sub(r"\s+", " ", text).strip()
    cleaned = expand_abbreviations(cleaned)
    cleaned = normalize_drugs(cleaned)
    cleaned = punctuation_and_capitalization(cleaned)
    if not format_soap:
        return cleaned
    sentences = re.split(r"(?<=[.?!])\s+", cleaned)
    subj = sentences[0] if sentences else ""
    obj = sentences[1] if len(sentences) > 1 else ""
    # Flag an assessment section only when a diagnostic keyword is present.
    assessment = ""
    lowered = cleaned.lower()
    if any(kw in lowered for kw in ("diagnosis", "dx", "rule out", "r/o", "probable")):
        assessment = "Assessment: " + subj
    return (
        f"S: {subj}\nO: {obj}\nA: {assessment}\nP: Plan: follow up as indicated."
    )
# ---------- Memory utilities ----------
def extract_words_and_phrases(text):
    """Return ([word tokens], [sentences]) extracted from *text*."""
    word_list = [w for w in re.findall(r"[A-Za-z0-9\-']+", text) if w.strip()]
    sentence_list = [s.strip() for s in re.split(r"(?<=[.?!])\s+", text) if s.strip()]
    return word_list, sentence_list
def update_memory_with_transcript(transcript):
    """Fold *transcript*'s words and sentences into the shared memory counts
    and persist the memory file when anything changed.

    Fix: previously `changed` was only set when a brand-new word/phrase was
    added, so increment-only updates to already-known entries were never
    written to disk. Any counted token now marks the memory as changed.

    NOTE: the file is written directly here rather than via save_memory()
    because we already hold MEMORY_LOCK and threading.Lock is not re-entrant.
    """
    global memory
    words, sentences = extract_words_and_phrases(transcript)
    changed = False
    with MEMORY_LOCK:
        for w in words:
            key = w.lower()
            memory["words"][key] = memory["words"].get(key, 0) + 1
            changed = True
        for s in sentences:
            key = s.strip()
            memory["phrases"][key] = memory["phrases"].get(key, 0) + 1
            changed = True
        if changed:
            try:
                with open(MEMORY_FILE, "w", encoding="utf-8") as fh:
                    json.dump(memory, fh, ensure_ascii=False, indent=2)
            except Exception:
                pass  # persistence is best-effort
def memory_correct_text(text, min_ratio=0.85):
    """Correct *text* against the learned memory.

    Unknown words are fuzzy-matched (difflib, cutoff=min_ratio) against
    remembered words; remembered phrases (8+ chars) have their canonical
    casing restored. Returns the corrected text.
    """
    if not text or (not memory.get("words") and not memory.get("phrases")):
        return text

    known_words = memory["words"]

    def fix_word(word):
        lowered = word.lower()
        if lowered in known_words:
            return word
        matches = get_close_matches(lowered, known_words.keys(), n=1, cutoff=min_ratio)
        if not matches:
            return word
        best = matches[0]
        # Preserve a leading capital from the original token.
        return best.capitalize() if word and word[0].isupper() else best

    rebuilt = "".join(
        fix_word(tok) if re.match(r"^[A-Za-z0-9\-']+$", tok) else tok
        for tok in re.split(r"(\W+)", text)
    )
    # Longest phrases first so shorter phrases cannot clobber longer matches.
    for phrase in sorted(memory.get("phrases", {}), key=len, reverse=True):
        if len(phrase) < 8:
            continue
        if phrase.lower() in rebuilt.lower():
            rebuilt = re.sub(re.escape(phrase), phrase, rebuilt, flags=re.IGNORECASE)
    return rebuilt
# ---------- Memory management UI helpers ----------
def import_memory_file(uploaded):
    """Merge an uploaded memory file (JSON dict or plain text) into memory.

    JSON files must hold {"words": {...}, "phrases": {...}} with counts.
    Text files are parsed line by line: "word,count" pairs, short lines
    (<= 3 tokens) as words, longer lines as phrases. Returns a
    human-readable status string.

    Fix: save_memory() acquires the non-reentrant MEMORY_LOCK itself, so it
    must be invoked only after the merge block has released the lock.
    """
    global memory
    if not uploaded:
        return "No file provided."
    try:
        # Gradio may hand us a path string, a file-like with .name, or a dict.
        if isinstance(uploaded, (str, os.PathLike)):
            path = str(uploaded)
        elif hasattr(uploaded, "name"):
            path = uploaded.name
        elif isinstance(uploaded, dict) and uploaded.get("name"):
            path = uploaded["name"]
        else:
            return "Unable to determine uploaded file path."
        with open(path, "r", encoding="utf-8") as fh:
            raw = fh.read()
        # First attempt: a structured JSON memory export.
        try:
            parsed = json.loads(raw)
            if isinstance(parsed, dict):
                with MEMORY_LOCK:
                    parsed_words = parsed.get("words", {})
                    parsed_phrases = parsed.get("phrases", {})
                    for word, count in parsed_words.items():
                        low = word.lower()
                        memory["words"][low] = memory["words"].get(low, 0) + int(count)
                    for phrase, count in parsed_phrases.items():
                        memory["phrases"][phrase] = memory["phrases"].get(phrase, 0) + int(count)
                save_memory(memory)
                return f"Imported JSON memory (words={len(parsed_words)}, phrases={len(parsed_phrases)})."
        except Exception:
            pass  # not JSON (or malformed) -> fall through to text parsing
        added_words = 0
        added_phrases = 0
        with MEMORY_LOCK:
            for line in (ln.strip() for ln in raw.splitlines()):
                if not line:
                    continue
                if "," in line:
                    # "word,count" pair; malformed counts default to 1.
                    word_part, count_part = (p.strip() for p in line.split(",", 1))
                    key = word_part.lower()
                    try:
                        cnt = int(count_part)
                    except Exception:
                        cnt = 1
                    memory["words"][key] = memory["words"].get(key, 0) + cnt
                    added_words += 1
                elif len(line.split()) <= 3:
                    key = line.lower()
                    memory["words"][key] = memory["words"].get(key, 0) + 1
                    added_words += 1
                else:
                    memory["phrases"][line] = memory["phrases"].get(line, 0) + 1
                    added_phrases += 1
        save_memory(memory)
        return f"Imported {added_words} words and {added_phrases} phrases from file."
    except Exception as e:
        traceback.print_exc()
        return f"Import failed: {e}"
def add_memory_entry(entry):
    """Add a word (<= 3 tokens) or phrase to memory, persist, and report.

    Fix: save_memory() acquires the non-reentrant MEMORY_LOCK itself, so
    persistence happens after the update block releases the lock.
    """
    global memory
    if not entry or not entry.strip():
        return "No entry provided."
    e = entry.strip()
    with MEMORY_LOCK:
        if len(e.split()) <= 3:
            key = e.lower()
            memory["words"][key] = memory["words"].get(key, 0) + 1
            message = f"Added/updated word: '{key}'."
        else:
            memory["phrases"][e] = memory["phrases"].get(e, 0) + 1
            message = f"Added/updated phrase: '{e}'."
    save_memory(memory)
    return message
def clear_memory():
    """Reset the shared memory to empty and persist the cleared state.

    Fix: save_memory() acquires the non-reentrant MEMORY_LOCK itself, so it
    is called only after the reset releases the lock.
    """
    global memory
    with MEMORY_LOCK:
        memory = {"words": {}, "phrases": {}}
    save_memory(memory)
    return "Memory cleared."
def view_memory(limit=2000):
    """Render the 50 most frequent words and phrases as a text report,
    truncated to *limit* characters."""

    def top_entries(table):
        # Descending by count; ties keep insertion order (stable sort).
        ranked = sorted(table.items(), key=lambda kv: kv[1], reverse=True)[:50]
        return [f"{k}: {v}" for k, v in ranked]

    lines = ["WORDS (top 50):"]
    lines += top_entries(memory.get("words", {}))
    lines.append("")
    lines.append("PHRASES (top 50):")
    lines += top_entries(memory.get("phrases", {}))
    report = "\n".join(lines)
    if len(report) > limit:
        report = report[:limit] + "\n...truncated..."
    return report
# ---------- File utilities ----------
def save_as_word(text, filename=None):
    """Write *text* into a .docx file and return its path.

    When *filename* is None, merged_transcripts.docx in the system temp
    directory is used (and overwritten on every call).
    """
    if filename is None:
        filename = os.path.join(tempfile.gettempdir(), "merged_transcripts.docx")
    document = Document()
    document.add_paragraph(text)
    document.save(filename)
    return filename
# ---------- improved ffmpeg convert ----------
def _remove_quietly(path):
    """Best-effort removal of a partial/failed output file."""
    try:
        if os.path.exists(path):
            os.unlink(path)
    except Exception:
        pass
def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
    """Attempt one ffmpeg conversion of *input_path* to WAV at *out_path*.

    fmt/sr/ch describe how to interpret the input when it is headerless raw
    audio. Returns (success, combined stdout+stderr or error string). A
    failed attempt removes any partial output file.

    Fix: ffmpeg's raw 16-bit PCM *demuxer* is named "s16le"; "pcm_s16le" is
    the codec name and is rejected by "-f", so that candidate always failed.
    It is now mapped to the correct demuxer name.
    """
    try:
        cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
        if fmt in ("s16le", "pcm_s16le", "mulaw"):
            demuxer = "s16le" if fmt == "pcm_s16le" else fmt
            cmd += ["-f", demuxer, "-ar", str(sr), "-ac", str(ch), "-i", input_path, out_path]
        else:
            cmd += ["-i", input_path, "-ar", str(sr), "-ac", str(ch), out_path]
        proc = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
        output = (proc.stdout or "") + (proc.stderr or "")
        if proc.returncode == 0 and os.path.exists(out_path) and os.path.getsize(out_path) > MIN_WAV_SIZE:
            return True, output
        _remove_quietly(out_path)
        return False, output
    except Exception as e:
        _remove_quietly(out_path)
        return False, str(e)
def convert_to_wav_if_needed(input_path):
    """Return a path to a WAV version of *input_path*.

    Already-.wav inputs are returned unchanged. Otherwise pydub (ffmpeg
    auto-detection) is tried first; if that fails, each raw-audio
    interpretation in FFMPEG_CANDIDATES is attempted. On total failure a
    diagnostics file is written and an Exception naming its path is raised.
    The caller owns (and should delete) any temporary WAV returned.
    """
    input_path = str(input_path)
    lower = input_path.lower()
    if lower.endswith(".wav"):
        # Assume .wav files are already readable as-is.
        return input_path
    auto_err = ""
    tmp = None
    # --- Attempt 1: pydub / ffmpeg container auto-detection ---
    try:
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        AudioSegment.from_file(input_path).export(tmp.name, format="wav")
        if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > MIN_WAV_SIZE:
            return tmp.name
        else:
            # Output too small to be real audio; discard it.
            try:
                os.unlink(tmp.name)
            except Exception:
                pass
    except Exception:
        auto_err = traceback.format_exc()
        try:
            if tmp and os.path.exists(tmp.name):
                os.unlink(tmp.name)
        except Exception:
            pass
    # --- Attempt 2: explicit raw-format candidates via ffmpeg ---
    diag_dir = tempfile.mkdtemp(prefix="dct_diag_")
    diag_log = os.path.join(diag_dir, "conversion_diagnostics.txt")
    diagnostics = []
    for fmt, sr, ch in FFMPEG_CANDIDATES:
        out_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        out_wav.close()
        success, debug = _ffmpeg_convert(input_path, out_wav.name, fmt, sr, ch)
        diagnostics.append(f"TRY fmt={fmt} sr={sr} ch={ch} success={success}\n{debug}\n")
        if success:
            # Record which candidate worked (best-effort) before returning.
            try:
                with open(diag_log, "w", encoding="utf-8") as fh:
                    fh.write("pydub auto error:\n")
                    fh.write(auto_err + "\n\n")
                    fh.write("Successful ffmpeg candidate:\n")
                    fh.write(f"fmt={fmt} sr={sr} ch={ch}\n\n")
                    fh.write("Diagnostics:\n")
                    fh.write("\n".join(diagnostics))
            except Exception:
                pass
            return out_wav.name
        else:
            try:
                if os.path.exists(out_wav.name):
                    os.unlink(out_wav.name)
            except Exception:
                pass
    # --- All attempts failed: collect diagnostics for the raised error ---
    try:
        fp = subprocess.run(
            ["ffprobe", "-v", "error", "-show_format", "-show_streams", input_path],
            capture_output=True,
            text=True,
            timeout=10,
        )
        diagnostics.append("FFPROBE:\n" + (fp.stdout.strip() or fp.stderr.strip()))
    except Exception as e:
        diagnostics.append("ffprobe failed: " + str(e))
    try:
        # First 512 bytes as hex, to help identify the real container format.
        with open(input_path, "rb") as fh:
            head = fh.read(512)
        diagnostics.append("HEX PREVIEW:\n" + head.hex())
    except Exception as e:
        diagnostics.append("could not read head: " + str(e))
    try:
        with open(diag_log, "w", encoding="utf-8") as fh:
            fh.write("pydub auto error:\n")
            fh.write(auto_err + "\n\n")
            fh.write("Full diagnostics:\n\n")
            fh.write("\n\n".join(diagnostics))
    except Exception as e:
        raise Exception(f"Conversion failed; diagnostics write error: {e}")
    raise Exception(f"Could not convert file to WAV. Diagnostics saved to: {diag_log}")
# ---------- Whisper model cache ----------
# Models already loaded in this process, keyed by model name ("small", ...).
MODEL_CACHE = {}
def get_whisper_model(name):
    """Return Whisper model *name*, loading it on first use and caching it."""
    if name not in MODEL_CACHE:
        print(f"DEBUG: loading whisper model '{name}'", flush=True)
        MODEL_CACHE[name] = whisper.load_model(name)
    return MODEL_CACHE[name]
# ---------- Main transcription generator ----------
def transcribe_multiple(
    audio_files,
    model_name,
    advanced_options,
    merge_checkbox,
    zip_file=None,
    zip_password=None,
    enable_memory=False,
):
    """Transcribe each provided audio file (and/or ZIP of audio files).

    Generator driving the streaming Gradio handler; every yield is a tuple
    (log_text, transcripts_text, merged_docx_path_or_None, progress_percent).

    audio_files: list/tuple of paths, or a single path string.
    advanced_options: dict of extra kwargs forwarded to model.transcribe().
    merge_checkbox: when truthy, save all transcripts into one .docx.
    zip_file / zip_password: optional (possibly encrypted) ZIP of audio.
    enable_memory: apply memory-based correction and learn from results.
    """
    log = []
    transcripts = []
    word_file_path = None
    # Fixed extraction dir in the system temp; cleaned before and after use.
    temp_extract_dir = os.path.join(tempfile.gettempdir(), "extracted_audio")
    extracted_audio_paths = []
    yield "", "", None, 0
    if os.path.exists(temp_extract_dir):
        try:
            shutil.rmtree(temp_extract_dir)
            log.append(f"Cleaned previous temp dir: {temp_extract_dir}")
        except Exception:
            pass
    # ---- Stage 1: extract supported audio entries from the ZIP, if any ----
    if zip_file:
        log.append(f"Processing zip: {zip_file}")
        yield "\n\n".join(log), "\n\n".join(transcripts), None, 2
        try:
            os.makedirs(temp_extract_dir, exist_ok=True)
            with pyzipper.ZipFile(zip_file, "r") as zf:
                if zip_password:
                    try:
                        zf.setpassword(zip_password.encode())
                    except Exception:
                        log.append("Failed to set zip password (unexpected).")
                # Archive entries with these extensions are treated as audio.
                exts = [
                    ".mp3",
                    ".wav",
                    ".aac",
                    ".flac",
                    ".ogg",
                    ".m4a",
                    ".dat",
                    ".dct",
                ]
                count = 0
                for info in zf.infolist():
                    if info.is_dir():
                        continue
                    _, ext = os.path.splitext(info.filename)
                    if ext.lower() in exts:
                        try:
                            zf.extract(info, path=temp_extract_dir)
                        except RuntimeError as e:
                            # Raised on a missing or wrong archive password.
                            log.append(f"Password required or incorrect for {info.filename}: {e}")
                            continue
                        except pyzipper.BadZipFile:
                            log.append(f"Bad zip entry: {info.filename}")
                            continue
                        except Exception as e:
                            log.append(f"Error extracting {info.filename}: {e}")
                            continue
                        p = os.path.normpath(os.path.join(temp_extract_dir, info.filename))
                        if os.path.exists(p):
                            extracted_audio_paths.append(p)
                            count += 1
                            log.append(f"Extracted: {info.filename}")
                if count == 0:
                    log.append("No supported audio in zip.")
                    try:
                        shutil.rmtree(temp_extract_dir)
                    except Exception:
                        pass
                    yield "\n\n".join(log), "\n\n".join(transcripts), None, 100
                    return
        except pyzipper.BadZipFile:
            log.append("Invalid zip file.")
            try:
                shutil.rmtree(temp_extract_dir)
            except Exception:
                pass
            yield "\n\n".join(log), "\n\n".join(transcripts), None, 100
            return
        except Exception as e:
            log.append(f"Zip processing error: {e}")
            try:
                shutil.rmtree(temp_extract_dir)
            except Exception:
                pass
            yield "\n\n".join(log), "\n\n".join(transcripts), None, 100
            return
    # ---- Stage 2: collect the full list of input paths ----
    paths = []
    if extracted_audio_paths:
        paths.extend(extracted_audio_paths)
    if audio_files:
        if isinstance(audio_files, (list, tuple)):
            for a in audio_files:
                if a:
                    paths.append(a)
        elif isinstance(audio_files, str):
            paths.append(audio_files)
    if not paths:
        log.append("No audio files provided.")
        yield "\n\n".join(log), "\n\n".join(transcripts), None, 100
        return
    yield "\n\n".join(log), "\n\n".join(transcripts), None, 5
    try:
        model = get_whisper_model(model_name)
        log.append(f"Loaded Whisper model: {model_name}")
    except Exception as e:
        log.append(f"Failed to load model {model_name}: {e}")
        yield "\n\n".join(log), "\n\n".join(transcripts), None, 100
        return
    # ---- Stage 3: convert + transcribe each file, streaming progress ----
    total = len(paths)
    idx = 0
    for p in paths:
        idx += 1
        log.append(f"Processing file ({idx}/{total}): {p}")
        yield "\n\n".join(log), "\n\n".join(transcripts), None, int(5 + (idx - 1) * 80 / max(1, total))
        wav = None
        try:
            wav = convert_to_wav_if_needed(p)
            log.append(f"Converted to WAV: {wav}")
        except Exception as e:
            # Conversion failure: record it and move on to the next file.
            log.append(f"Conversion failed for {p}: {e}")
            transcripts.append(f"FILE: {os.path.basename(p)}\nERROR: Conversion failed: {e}")
            yield "\n\n".join(log), "\n\n".join(transcripts), None, int(5 + idx * 80 / max(1, total))
            continue
        try:
            whisper_opts = {}
            if isinstance(advanced_options, dict):
                whisper_opts.update(advanced_options)
            result = model.transcribe(wav, **whisper_opts)
            text = result.get("text", "").strip()
            log.append(f"Transcribed: {len(text)} chars")
            if enable_memory:
                text = memory_correct_text(text)
            text = postprocess_transcript(text)
            transcripts.append(f"FILE: {os.path.basename(p)}\n{text}\n")
            if enable_memory:
                # Learning from the result is best-effort; never abort the run.
                try:
                    update_memory_with_transcript(text)
                    log.append("Memory updated.")
                except Exception:
                    pass
            yield "\n\n".join(log), "\n\n".join(transcripts), None, int(10 + idx * 85 / max(1, total))
        except Exception as e:
            log.append(f"Transcription failed for {p}: {e}")
            transcripts.append(f"FILE: {os.path.basename(p)}\nERROR: Transcription failed: {e}")
            yield "\n\n".join(log), "\n\n".join(transcripts), None, int(10 + idx * 85 / max(1, total))
            continue
        finally:
            # Delete the temporary WAV, but only if it lives under the temp
            # dir and the source itself was not already a .wav file.
            try:
                if wav and os.path.exists(wav):
                    tmpdir = tempfile.gettempdir()
                    try:
                        common = os.path.commonpath([os.path.abspath(tmpdir), os.path.abspath(wav)])
                        if common == os.path.abspath(tmpdir) and not p.lower().endswith(".wav"):
                            os.unlink(wav)
                    except Exception:
                        # commonpath can raise (e.g. different drives); fall
                        # back to a substring containment check.
                        try:
                            if tmpdir in os.path.abspath(wav) and not p.lower().endswith(".wav"):
                                os.unlink(wav)
                        except Exception:
                            pass
            except Exception:
                pass
    # ---- Stage 4: optional merged .docx, final yield, cleanup ----
    if merge_checkbox:
        try:
            merged_text = "\n\n".join(transcripts)
            word_file_path = save_as_word(merged_text)
            log.append(f"Merged transcript saved: {word_file_path}")
        except Exception as e:
            log.append(f"Failed to save merged file: {e}")
            word_file_path = None
    yield "\n\n".join(log), "\n\n".join(transcripts), word_file_path, 100
    try:
        if os.path.exists(temp_extract_dir):
            shutil.rmtree(temp_extract_dir)
            log.append("Cleaned temporary extraction dir.")
    except Exception:
        pass
# ----------------------- Gradio wrapper (streaming) -----------------------
def run_transcription_wrapper(
    files,
    model_name,
    merge,
    zip_file,
    zip_password,
    use_default_zip_pass,
    default_zip_password,
    enable_memory,
    advanced_options_state,
):
    """Gradio click handler: normalize UI inputs and stream updates from
    transcribe_multiple(). Yields (logs, transcripts, docx_path, percent)."""
    try:
        # Gradio may deliver the ZIP as a path, a file-like, or a dict.
        zip_path = None
        if zip_file:
            if isinstance(zip_file, (str, os.PathLike)):
                zip_path = str(zip_file)
            elif hasattr(zip_file, "name"):
                zip_path = zip_file.name
            elif isinstance(zip_file, dict) and zip_file.get("name"):
                zip_path = zip_file["name"]
        # Fall back to the default password only when none was typed in.
        password = zip_password
        if use_default_zip_pass and not (zip_password and zip_password.strip()):
            password = default_zip_password
        yield from transcribe_multiple(
            files,
            model_name,
            {},
            merge_checkbox=merge,
            zip_file=zip_path,
            zip_password=password,
            enable_memory=enable_memory,
        )
    except Exception:
        tb = traceback.format_exc()
        yield (
            f"EXCEPTION in run_transcription_wrapper:\n{tb}",
            "ERROR: transcription did not start or failed unexpectedly.",
            None,
            100,
        )
print("DEBUG: building Gradio Blocks", flush=True)
# Two-column layout: inputs + memory management on the left, streaming
# transcript/progress/logs on the right.
with gr.Blocks(title="Whisper Transcriber") as demo:
    gr.Markdown(
        "## Whisper Transcriber\n"
        "Upload audio files or a ZIP on the left and click **Transcribe**.\n"
        "Transcript, progress, download, and logs appear on the right."
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            file_input = gr.File(
                label="Audio files",
                file_count="multiple",
                type="filepath",
                height=60,
            )
            zip_input = gr.File(
                label="ZIP with audio (optional)",
                file_count="single",
                type="filepath",
                height=60,
            )
            use_default_zip_pass = gr.Checkbox(
                label="Use default ZIP password",
                value=False,
            )
            # NOTE(review): default ZIP password is hard-coded in source and
            # visible in the UI — consider moving it to an environment secret.
            default_zip_password = gr.Textbox(
                label="Default ZIP password",
                value="dietcoke1",
                interactive=True,
            )
            zip_password = gr.Textbox(
                label="ZIP password (override)",
                placeholder="If empty, default password will be used",
            )
            model_select = gr.Dropdown(
                choices=["small", "medium", "large", "base"],
                value="small",
                label="Whisper model",
            )
            merge_checkbox = gr.Checkbox(
                label="Merge all transcripts into one .docx",
                value=True,
            )
            memory_checkbox = gr.Checkbox(
                label="Enable correction memory (use during transcription)",
                value=False,
            )
            submit = gr.Button("Transcribe", variant="primary")
            gr.Markdown("### Memory management")
            # NOTE(review): type="file" differs from the type="filepath" used
            # above; confirm it is accepted by the installed Gradio version.
            mem_upload = gr.File(label="Import memory file (JSON or text)", file_count="single", type="file")
            mem_import_btn = gr.Button("Import Memory File")
            mem_manual_entry = gr.Textbox(label="Add word/phrase to memory (manual)", placeholder="Type a word or phrase")
            mem_add_btn = gr.Button("Add to Memory")
            mem_clear_btn = gr.Button("Clear Memory")
            mem_view_btn = gr.Button("View Memory")
            mem_status = gr.Textbox(label="Memory status", interactive=False, lines=6)
        with gr.Column(scale=1):
            gr.Markdown("### Output")
            transcripts_out = gr.Textbox(
                label="Transcript",
                lines=18,
                interactive=False,
            )
            progress_num = gr.Slider(
                minimum=0,
                maximum=100,
                value=0,
                step=1,
                label="Progress (%)",
                interactive=False,
            )
            download_file = gr.File(
                label="Merged .docx (when available)"
            )
            logs = gr.Textbox(
                label="Logs",
                lines=10,
                interactive=False,
            )
    # Streaming handler: each yield updates logs/transcript/download/progress.
    submit.click(
        fn=run_transcription_wrapper,
        inputs=[
            file_input,
            model_select,
            merge_checkbox,
            zip_input,
            zip_password,
            use_default_zip_pass,
            default_zip_password,
            memory_checkbox,
            gr.State({}),
        ],
        outputs=[logs, transcripts_out, download_file, progress_num],
    )
    def _import_memory(uploaded):
        # Thin indirection so the click handler has a module-local callable.
        return import_memory_file(uploaded)
    mem_import_btn.click(fn=_import_memory, inputs=[mem_upload], outputs=[mem_status])
    mem_add_btn.click(fn=add_memory_entry, inputs=[mem_manual_entry], outputs=[mem_status])
    mem_clear_btn.click(fn=lambda: clear_memory(), inputs=[], outputs=[mem_status])
    mem_view_btn.click(fn=lambda: view_memory(), inputs=[], outputs=[mem_status])
if __name__ == "__main__":
    # Hugging Face Spaces provides PORT; fall back to Gradio's default 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    print("DEBUG: launching Gradio on port", listen_port, flush=True)
    try:
        demo.queue().launch(server_name="0.0.0.0", server_port=listen_port)
    except Exception as exc:
        print("FATAL: demo.launch failed:", exc, flush=True)
        traceback.print_exc()
        raise