staraks committed on
Commit
97cd142
·
verified ·
1 Parent(s): 4987752

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -51
app.py CHANGED
@@ -35,9 +35,7 @@ print("DEBUG: imports OK", flush=True)
35
  # ---------- Config ----------
36
  MEMORY_FILE = "memory.json"
37
  MEMORY_LOCK = threading.Lock()
38
- MIN_WAV_SIZE = 200 # bytes
39
-
40
- # Small ffmpeg fallback grid (hybrid conversion)
41
  FFMPEG_CANDIDATES = [
42
  ("s16le", 16000, 1),
43
  ("s16le", 44100, 2),
@@ -52,7 +50,13 @@ def load_memory():
52
  try:
53
  if os.path.exists(MEMORY_FILE):
54
  with open(MEMORY_FILE, "r", encoding="utf-8") as fh:
55
- return json.load(fh)
 
 
 
 
 
 
56
  except Exception:
57
  pass
58
  mem = {"words": {}, "phrases": {}}
@@ -66,8 +70,11 @@ def load_memory():
66
 
67
  def save_memory(mem):
68
  with MEMORY_LOCK:
69
- with open(MEMORY_FILE, "w", encoding="utf-8") as fh:
70
- json.dump(mem, fh, ensure_ascii=False, indent=2)
 
 
 
71
 
72
 
73
  memory = load_memory()
@@ -77,6 +84,7 @@ print(
77
  flush=True,
78
  )
79
 
 
80
  # ---------- Postprocessing ----------
81
  MEDICAL_ABBREVIATIONS = {
82
  "pt": "patient",
@@ -235,6 +243,129 @@ def memory_correct_text(text, min_ratio=0.85):
235
  return corrected
236
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  # ---------- File utilities ----------
239
  def save_as_word(text, filename=None):
240
  if filename is None:
@@ -247,39 +378,33 @@ def save_as_word(text, filename=None):
247
  return filename
248
 
249
 
250
- # ---------- Hybrid conversion: pydub + small ffmpeg fallback ----------
251
  def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
252
- cmd = [
253
- "ffmpeg",
254
- "-hide_banner",
255
- "-loglevel",
256
- "error",
257
- "-y",
258
- "-f",
259
- fmt,
260
- "-ar",
261
- str(sr),
262
- "-ac",
263
- str(ch),
264
- "-i",
265
- input_path,
266
- out_path,
267
- ]
268
  try:
269
- proc = subprocess.run(cmd, capture_output=True, timeout=30, text=True)
270
- if (
271
- proc.returncode == 0
272
- and os.path.exists(out_path)
273
- and os.path.getsize(out_path) > MIN_WAV_SIZE
274
- ):
275
- return True, proc.stderr + proc.stdout
 
 
 
 
 
 
276
  else:
277
  try:
278
  if os.path.exists(out_path):
279
  os.unlink(out_path)
280
  except Exception:
281
  pass
282
- return False, proc.stderr + proc.stdout
283
  except Exception as e:
284
  try:
285
  if os.path.exists(out_path):
@@ -324,9 +449,7 @@ def convert_to_wav_if_needed(input_path):
324
  out_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
325
  out_wav.close()
326
  success, debug = _ffmpeg_convert(input_path, out_wav.name, fmt, sr, ch)
327
- diagnostics.append(
328
- f"TRY fmt={fmt} sr={sr} ch={ch} success={success}\n{debug}\n"
329
- )
330
  if success:
331
  try:
332
  with open(diag_log, "w", encoding="utf-8") as fh:
@@ -373,9 +496,7 @@ def convert_to_wav_if_needed(input_path):
373
  except Exception as e:
374
  raise Exception(f"Conversion failed; diagnostics write error: {e}")
375
 
376
- raise Exception(
377
- f"Could not convert file to WAV. Diagnostics saved to: {diag_log}"
378
- )
379
 
380
 
381
  # ---------- Whisper model cache ----------
@@ -385,6 +506,7 @@ MODEL_CACHE = {}
385
  def get_whisper_model(name):
386
  if name not in MODEL_CACHE:
387
  print(f"DEBUG: loading whisper model '{name}'", flush=True)
 
388
  MODEL_CACHE[name] = whisper.load_model(name)
389
  return MODEL_CACHE[name]
390
 
@@ -430,9 +552,8 @@ def transcribe_multiple(
430
  try:
431
  zf.setpassword(zip_password.encode())
432
  except Exception:
433
- log.append("Incorrect zip password")
434
- yield "\n\n".join(log), "\n\n".join(transcripts), None, 100
435
- return
436
  exts = [
437
  ".mp3",
438
  ".wav",
@@ -451,12 +572,16 @@ def transcribe_multiple(
451
  if ext.lower() in exts:
452
  try:
453
  zf.extract(info, path=temp_extract_dir)
 
 
 
 
 
 
454
  except Exception as e:
455
  log.append(f"Error extracting {info.filename}: {e}")
456
  continue
457
- p = os.path.normpath(
458
- os.path.join(temp_extract_dir, info.filename)
459
- )
460
  if os.path.exists(p):
461
  extracted_audio_paths.append(p)
462
  count += 1
@@ -575,12 +700,16 @@ def transcribe_multiple(
575
  try:
576
  if wav and os.path.exists(wav):
577
  tmpdir = tempfile.gettempdir()
578
- if (
579
- os.path.commonpath([tmpdir, os.path.abspath(wav)])
580
- == tmpdir
581
- and not p.lower().endswith(".wav")
582
- ):
583
- os.unlink(wav)
 
 
 
 
584
  except Exception:
585
  pass
586
 
@@ -699,7 +828,7 @@ with gr.Blocks(title="Whisper Transcriber") as demo:
699
 
700
  default_zip_password = gr.Textbox(
701
  label="Default ZIP password",
702
- value="dietcoke1", # you can change this
703
  interactive=True,
704
  )
705
 
@@ -720,12 +849,22 @@ with gr.Blocks(title="Whisper Transcriber") as demo:
720
  )
721
 
722
  memory_checkbox = gr.Checkbox(
723
- label="Enable correction memory",
724
  value=False,
725
  )
726
 
727
  submit = gr.Button("Transcribe", variant="primary")
728
 
 
 
 
 
 
 
 
 
 
 
729
  # RIGHT: Outputs (Transcript → Progress → Download → Logs)
730
  with gr.Column(scale=1):
731
  gr.Markdown("### Output")
@@ -755,6 +894,7 @@ with gr.Blocks(title="Whisper Transcriber") as demo:
755
  interactive=False,
756
  )
757
 
 
758
  submit.click(
759
  fn=run_transcription_wrapper,
760
  inputs=[
@@ -772,6 +912,18 @@ with gr.Blocks(title="Whisper Transcriber") as demo:
772
  outputs=[logs, transcripts_out, download_file, progress_num],
773
  )
774
 
 
 
 
 
 
 
 
 
 
 
 
 
775
  # ---------- Launch ----------
776
  if __name__ == "__main__":
777
  port = int(os.environ.get("PORT", 7860))
 
35
  # ---------- Config ----------
36
  MEMORY_FILE = "memory.json"
37
  MEMORY_LOCK = threading.Lock()
38
+ MIN_WAV_SIZE = 1024 # raised slightly from 200 for safety
 
 
39
  FFMPEG_CANDIDATES = [
40
  ("s16le", 16000, 1),
41
  ("s16le", 44100, 2),
 
50
  try:
51
  if os.path.exists(MEMORY_FILE):
52
  with open(MEMORY_FILE, "r", encoding="utf-8") as fh:
53
+ data = json.load(fh)
54
+ # validate minimal structure
55
+ if not isinstance(data, dict):
56
+ raise ValueError("memory.json root not dict")
57
+ data.setdefault("words", {})
58
+ data.setdefault("phrases", {})
59
+ return data
60
  except Exception:
61
  pass
62
  mem = {"words": {}, "phrases": {}}
 
70
 
71
  def save_memory(mem):
72
  with MEMORY_LOCK:
73
+ try:
74
+ with open(MEMORY_FILE, "w", encoding="utf-8") as fh:
75
+ json.dump(mem, fh, ensure_ascii=False, indent=2)
76
+ except Exception:
77
+ traceback.print_exc()
78
 
79
 
80
  memory = load_memory()
 
84
  flush=True,
85
  )
86
 
87
+
88
  # ---------- Postprocessing ----------
89
  MEDICAL_ABBREVIATIONS = {
90
  "pt": "patient",
 
243
  return corrected
244
 
245
 
246
+ # ---------- Memory management UI helpers ----------
247
def import_memory_file(uploaded):
    """Merge the contents of an uploaded file into the in-process memory.

    *uploaded* may be a path (str / os.PathLike), an object exposing a
    ``name`` attribute, or a dict with a ``"name"`` key.

    Accepted file contents:
      - a JSON object of shape {"words": {...}, "phrases": {...}} whose
        values are counts; word keys are lower-cased before merging;
      - otherwise plain text, one entry per non-empty line:
          * "text,count" -> stored under words, keyed by the lower-cased
            text before the first comma (a non-integer count becomes 1);
          * a line of at most three whitespace-separated tokens -> word;
          * any longer line -> phrase (kept verbatim).

    Counts are added to whatever is already stored. Everything is parsed and
    validated before the shared memory is touched, so a malformed file leaves
    memory unchanged.

    Returns a human-readable status string; failures are reported as
    "Import failed: ..." instead of being raised.
    """
    if not uploaded:
        return "No file provided."

    try:
        if isinstance(uploaded, (str, os.PathLike)):
            path = str(uploaded)
        elif hasattr(uploaded, "name"):
            path = uploaded.name
        elif isinstance(uploaded, dict) and uploaded.get("name"):
            path = uploaded["name"]
        else:
            return "Unable to determine uploaded file path."

        with open(path, "r", encoding="utf-8") as fh:
            raw = fh.read()

        # Only the decode step decides "JSON or not"; errors in the JSON
        # *content* must not silently fall through to the line parser.
        try:
            parsed = json.loads(raw)
        except ValueError:
            parsed = None

        new_words = {}
        new_phrases = {}

        if isinstance(parsed, dict):
            parsed_words = parsed.get("words", {})
            parsed_phrases = parsed.get("phrases", {})
            for k, v in parsed_words.items():
                key = k.lower()
                new_words[key] = new_words.get(key, 0) + int(v)
            for k, v in parsed_phrases.items():
                new_phrases[k] = new_phrases.get(k, 0) + int(v)
            status = (
                f"Imported JSON memory (words={len(parsed_words)}, "
                f"phrases={len(parsed_phrases)})."
            )
        else:
            added_words = 0
            added_phrases = 0
            for line in (ln.strip() for ln in raw.splitlines()):
                if not line:
                    continue
                if "," in line:
                    head, tail = (part.strip() for part in line.split(",", 1))
                    try:
                        cnt = int(tail)
                    except ValueError:
                        cnt = 1
                    key = head.lower()
                    new_words[key] = new_words.get(key, 0) + cnt
                    added_words += 1
                elif len(line.split()) <= 3:
                    key = line.lower()
                    new_words[key] = new_words.get(key, 0) + 1
                    added_words += 1
                else:
                    new_phrases[line] = new_phrases.get(line, 0) + 1
                    added_phrases += 1
            status = (
                f"Imported {added_words} words and {added_phrases} phrases from file."
            )

        with MEMORY_LOCK:
            for key, cnt in new_words.items():
                memory["words"][key] = memory["words"].get(key, 0) + cnt
            for key, cnt in new_phrases.items():
                memory["phrases"][key] = memory["phrases"].get(key, 0) + cnt

        # save_memory() acquires MEMORY_LOCK itself; calling it while the lock
        # is still held would block forever on a non-reentrant lock.
        save_memory(memory)
        return status
    except Exception as e:
        traceback.print_exc()
        return f"Import failed: {e}"
320
+
321
+
322
def add_memory_entry(entry):
    """Add one manually typed word or phrase to memory and persist it.

    After stripping surrounding whitespace, an entry of at most three
    whitespace-separated tokens is stored as a word (lower-cased key);
    anything longer is stored verbatim as a phrase. The entry's count is
    incremented by one.

    Returns a status string describing what was stored, or
    "No entry provided." when *entry* is empty or only whitespace.
    """
    if not entry or not entry.strip():
        return "No entry provided."

    text = entry.strip()
    is_word = len(text.split()) <= 3
    bucket = "words" if is_word else "phrases"
    key = text.lower() if is_word else text

    with MEMORY_LOCK:
        memory[bucket][key] = memory[bucket].get(key, 0) + 1

    # save_memory() acquires MEMORY_LOCK itself, so it must run only after the
    # lock above has been released; nesting the call would deadlock.
    save_memory(memory)

    if is_word:
        return f"Added/updated word: '{key}'."
    return f"Added/updated phrase: '{text}'."
340
+
341
def clear_memory():
    """Reset the in-process memory to empty and persist the empty state.

    Rebinds the module-level ``memory`` to a fresh
    {"words": {}, "phrases": {}} dict while holding MEMORY_LOCK, then writes
    it out with save_memory().

    Returns the status string "Memory cleared.".
    """
    global memory
    with MEMORY_LOCK:
        memory = {"words": {}, "phrases": {}}
        cleared = memory

    # save_memory() acquires MEMORY_LOCK itself, so it is called only after
    # the lock has been released to avoid blocking on a non-reentrant lock.
    save_memory(cleared)
    return "Memory cleared."
347
+
348
def view_memory(limit=2000):
    """Build a plain-text overview of the stored memory for the UI.

    Lists up to 50 words and then up to 50 phrases, each section ordered by
    descending count, one "key: count" pair per line. If the resulting text
    is longer than *limit* characters it is cut at *limit* and a
    "...truncated..." marker line is appended.
    """
    word_counts = memory.get("words", {})
    phrase_counts = memory.get("phrases", {})

    def _top_entries(counts):
        # Highest counts first; only the first 50 entries are shown.
        ranked = sorted(counts.items(), key=lambda item: -item[1])
        return [f"{name}: {count}" for name, count in ranked[:50]]

    report_lines = ["WORDS (top 50):"]
    report_lines.extend(_top_entries(word_counts))
    report_lines.append("")
    report_lines.append("PHRASES (top 50):")
    report_lines.extend(_top_entries(phrase_counts))

    report = "\n".join(report_lines)
    if len(report) <= limit:
        return report
    return report[:limit] + "\n...truncated..."
367
+
368
+
369
  # ---------- File utilities ----------
370
  def save_as_word(text, filename=None):
371
  if filename is None:
 
378
  return filename
379
 
380
 
381
+ # ---------- improved ffmpeg convert ----------
382
  def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
383
+ """
384
+ Use ffmpeg to convert input_path -> out_path.
385
+ Let ffmpeg autodetect input unless fmt signals raw PCM.
386
+ """
 
 
 
 
 
 
 
 
 
 
 
 
387
  try:
388
+ cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
389
+
390
+ if fmt in ("s16le", "pcm_s16le", "mulaw"):
391
+ # raw input: specify input format and sample params before -i
392
+ cmd += ["-f", fmt, "-ar", str(sr), "-ac", str(ch), "-i", input_path, out_path]
393
+ else:
394
+ # autodetect input, request output sample rate/channels
395
+ cmd += ["-i", input_path, "-ar", str(sr), "-ac", str(ch), out_path]
396
+
397
+ proc = subprocess.run(cmd, capture_output=True, timeout=60, text=True)
398
+ stdout_stderr = (proc.stdout or "") + (proc.stderr or "")
399
+ if proc.returncode == 0 and os.path.exists(out_path) and os.path.getsize(out_path) > MIN_WAV_SIZE:
400
+ return True, stdout_stderr
401
  else:
402
  try:
403
  if os.path.exists(out_path):
404
  os.unlink(out_path)
405
  except Exception:
406
  pass
407
+ return False, stdout_stderr
408
  except Exception as e:
409
  try:
410
  if os.path.exists(out_path):
 
449
  out_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
450
  out_wav.close()
451
  success, debug = _ffmpeg_convert(input_path, out_wav.name, fmt, sr, ch)
452
+ diagnostics.append(f"TRY fmt={fmt} sr={sr} ch={ch} success={success}\n{debug}\n")
 
 
453
  if success:
454
  try:
455
  with open(diag_log, "w", encoding="utf-8") as fh:
 
496
  except Exception as e:
497
  raise Exception(f"Conversion failed; diagnostics write error: {e}")
498
 
499
+ raise Exception(f"Could not convert file to WAV. Diagnostics saved to: {diag_log}")
 
 
500
 
501
 
502
  # ---------- Whisper model cache ----------
 
506
  def get_whisper_model(name):
507
  if name not in MODEL_CACHE:
508
  print(f"DEBUG: loading whisper model '{name}'", flush=True)
509
+ # You can set device by changing whisper.load_model(name, device="cpu") if needed.
510
  MODEL_CACHE[name] = whisper.load_model(name)
511
  return MODEL_CACHE[name]
512
 
 
552
  try:
553
  zf.setpassword(zip_password.encode())
554
  except Exception:
555
+ log.append("Failed to set zip password (unexpected).")
556
+
 
557
  exts = [
558
  ".mp3",
559
  ".wav",
 
572
  if ext.lower() in exts:
573
  try:
574
  zf.extract(info, path=temp_extract_dir)
575
+ except RuntimeError as e:
576
+ log.append(f"Password required or incorrect for {info.filename}: {e}")
577
+ continue
578
+ except pyzipper.BadZipFile:
579
+ log.append(f"Bad zip entry: {info.filename}")
580
+ continue
581
  except Exception as e:
582
  log.append(f"Error extracting {info.filename}: {e}")
583
  continue
584
+ p = os.path.normpath(os.path.join(temp_extract_dir, info.filename))
 
 
585
  if os.path.exists(p):
586
  extracted_audio_paths.append(p)
587
  count += 1
 
700
  try:
701
  if wav and os.path.exists(wav):
702
  tmpdir = tempfile.gettempdir()
703
+ try:
704
+ common = os.path.commonpath([os.path.abspath(tmpdir), os.path.abspath(wav)])
705
+ if common == os.path.abspath(tmpdir) and not p.lower().endswith(".wav"):
706
+ os.unlink(wav)
707
+ except Exception:
708
+ try:
709
+ if tmpdir in os.path.abspath(wav) and not p.lower().endswith(".wav"):
710
+ os.unlink(wav)
711
+ except Exception:
712
+ pass
713
  except Exception:
714
  pass
715
 
 
828
 
829
  default_zip_password = gr.Textbox(
830
  label="Default ZIP password",
831
+ value="dietcoke1",
832
  interactive=True,
833
  )
834
 
 
849
  )
850
 
851
  memory_checkbox = gr.Checkbox(
852
+ label="Enable correction memory (use during transcription)",
853
  value=False,
854
  )
855
 
856
  submit = gr.Button("Transcribe", variant="primary")
857
 
858
+ # Memory management UI
859
+ gr.Markdown("### Memory management")
860
+ mem_upload = gr.File(label="Import memory file (JSON or text)", file_count="single", type="file")
861
+ mem_import_btn = gr.Button("Import Memory File")
862
+ mem_manual_entry = gr.Textbox(label="Add word/phrase to memory (manual)", placeholder="Type a word or phrase")
863
+ mem_add_btn = gr.Button("Add to Memory")
864
+ mem_clear_btn = gr.Button("Clear Memory")
865
+ mem_view_btn = gr.Button("View Memory")
866
+ mem_status = gr.Textbox(label="Memory status", interactive=False, lines=4)
867
+
868
  # RIGHT: Outputs (Transcript → Progress → Download → Logs)
869
  with gr.Column(scale=1):
870
  gr.Markdown("### Output")
 
894
  interactive=False,
895
  )
896
 
897
+ # Transcription click binding
898
  submit.click(
899
  fn=run_transcription_wrapper,
900
  inputs=[
 
912
  outputs=[logs, transcripts_out, download_file, progress_num],
913
  )
914
 
915
+ # Memory button bindings
916
+ def _import_memory(uploaded):
917
+ return import_memory_file(uploaded)
918
+
919
+ mem_import_btn.click(fn=_import_memory, inputs=[mem_upload], outputs=[mem_status])
920
+
921
+ mem_add_btn.click(fn=add_memory_entry, inputs=[mem_manual_entry], outputs=[mem_status])
922
+
923
+ mem_clear_btn.click(fn=lambda: clear_memory(), inputs=[], outputs=[mem_status])
924
+
925
+ mem_view_btn.click(fn=lambda: view_memory(), inputs=[], outputs=[mem_status])
926
+
927
  # ---------- Launch ----------
928
  if __name__ == "__main__":
929
  port = int(os.environ.get("PORT", 7860))