Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
# app.py
|
| 2 |
-
# Whisper transcription app -
|
| 3 |
-
#
|
|
|
|
| 4 |
|
| 5 |
import os
|
| 6 |
import sys
|
|
@@ -12,13 +13,14 @@ import traceback
|
|
| 12 |
import threading
|
| 13 |
import re
|
| 14 |
from difflib import get_close_matches
|
|
|
|
| 15 |
|
| 16 |
-
# Force unbuffered output
|
| 17 |
os.environ["PYTHONUNBUFFERED"] = "1"
|
| 18 |
|
| 19 |
print("DEBUG: app.py bootstrap starting", flush=True)
|
| 20 |
|
| 21 |
-
# Third-party imports
|
| 22 |
try:
|
| 23 |
from docx import Document
|
| 24 |
import whisper
|
|
@@ -43,9 +45,15 @@ FFMPEG_CANDIDATES = [
|
|
| 43 |
("pcm_s16le", 44100, 2),
|
| 44 |
("mulaw", 8000, 1),
|
| 45 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
# ----------------------------
|
| 47 |
|
| 48 |
-
# ---------- Memory
|
| 49 |
def load_memory():
|
| 50 |
try:
|
| 51 |
if os.path.exists(MEMORY_FILE):
|
|
@@ -77,9 +85,8 @@ def save_memory(mem):
|
|
| 77 |
|
| 78 |
|
| 79 |
memory = load_memory()
|
| 80 |
-
print("DEBUG: memory loaded (words=%d phrases=%d)" % (len(memory.get("words", {})), len(memory.get("phrases", {}))), flush=True)
|
| 81 |
|
| 82 |
-
|
| 83 |
MEDICAL_ABBREVIATIONS = {
|
| 84 |
"pt": "patient",
|
| 85 |
"dx": "diagnosis",
|
|
@@ -160,7 +167,6 @@ def postprocess_transcript(text, format_soap=False):
|
|
| 160 |
return t
|
| 161 |
|
| 162 |
|
| 163 |
-
# ---------- Memory utilities (same as before) ----------
|
| 164 |
def extract_words_and_phrases(text):
|
| 165 |
words = re.findall(r"[A-Za-z0-9\-']+", text)
|
| 166 |
sentences = [s.strip() for s in re.split(r"(?<=[.?!])\s+", text) if s.strip()]
|
|
@@ -228,7 +234,7 @@ def memory_correct_text(text, min_ratio=0.85):
|
|
| 228 |
return corrected
|
| 229 |
|
| 230 |
|
| 231 |
-
# ---------- Memory management
|
| 232 |
def import_memory_file(uploaded):
|
| 233 |
global memory
|
| 234 |
if not uploaded:
|
|
@@ -340,7 +346,7 @@ def save_as_word(text, filename=None):
|
|
| 340 |
return filename
|
| 341 |
|
| 342 |
|
| 343 |
-
# ----------
|
| 344 |
def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
|
| 345 |
try:
|
| 346 |
cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
|
|
@@ -465,15 +471,11 @@ def get_whisper_model(name, device=None):
|
|
| 465 |
return MODEL_CACHE[name]
|
| 466 |
|
| 467 |
|
| 468 |
-
# ---------- ZIP extraction
|
| 469 |
def extract_zip_list(zip_file, zip_password):
|
| 470 |
-
"""
|
| 471 |
-
Extract zip to a temp dir and return (list_of_paths, diagnostics_text)
|
| 472 |
-
"""
|
| 473 |
temp_extract_dir = os.path.join(tempfile.gettempdir(), "extracted_audio")
|
| 474 |
try:
|
| 475 |
if os.path.exists(temp_extract_dir):
|
| 476 |
-
# clear existing
|
| 477 |
try:
|
| 478 |
shutil.rmtree(temp_extract_dir)
|
| 479 |
except Exception:
|
|
@@ -511,14 +513,13 @@ def extract_zip_list(zip_file, zip_password):
|
|
| 511 |
if not extracted:
|
| 512 |
logs.append("No supported audio files found in zip.")
|
| 513 |
return [], "\n".join(logs)
|
| 514 |
-
# Return list and logs
|
| 515 |
return extracted, "\n".join(logs)
|
| 516 |
except Exception as e:
|
| 517 |
traceback.print_exc()
|
| 518 |
return [], f"Extraction failed: {e}"
|
| 519 |
|
| 520 |
|
| 521 |
-
# ----------
|
| 522 |
def transcribe_multiple(
|
| 523 |
selected_paths,
|
| 524 |
model_name,
|
|
@@ -527,10 +528,6 @@ def transcribe_multiple(
|
|
| 527 |
enable_memory=False,
|
| 528 |
device=None,
|
| 529 |
):
|
| 530 |
-
"""
|
| 531 |
-
Generator yields (log_text, transcripts_text, merged_file_path_or_None, percent_int)
|
| 532 |
-
selected_paths: list of absolute file paths to process
|
| 533 |
-
"""
|
| 534 |
log = []
|
| 535 |
transcripts = []
|
| 536 |
word_file_path = None
|
|
@@ -542,7 +539,6 @@ def transcribe_multiple(
|
|
| 542 |
|
| 543 |
yield "", "", None, 0
|
| 544 |
|
| 545 |
-
# load model
|
| 546 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, 5
|
| 547 |
try:
|
| 548 |
model = get_whisper_model(model_name, device=device)
|
|
@@ -554,16 +550,16 @@ def transcribe_multiple(
|
|
| 554 |
|
| 555 |
total = len(selected_paths)
|
| 556 |
for idx, p in enumerate(selected_paths, start=1):
|
| 557 |
-
log.append(f"Processing file ({idx}/{total}): {p}")
|
| 558 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(5 + (idx - 1) * 80 / max(1, total))
|
| 559 |
|
| 560 |
wav = None
|
| 561 |
try:
|
| 562 |
wav = convert_to_wav_if_needed(p)
|
| 563 |
-
log.append(f"Converted to WAV: {wav}")
|
| 564 |
except Exception as e:
|
| 565 |
-
log.append(f"Conversion failed for {p}: {e}")
|
| 566 |
-
transcripts.append(f"FILE: {os.path.basename(p)}\nERROR: Conversion failed: {e}")
|
| 567 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(5 + idx * 80 / max(1, total))
|
| 568 |
continue
|
| 569 |
|
|
@@ -579,7 +575,7 @@ def transcribe_multiple(
|
|
| 579 |
if enable_memory:
|
| 580 |
text = memory_correct_text(text)
|
| 581 |
text = postprocess_transcript(text)
|
| 582 |
-
transcripts.append(f"FILE: {os.path.basename(p)}\n{text}\n")
|
| 583 |
|
| 584 |
if enable_memory:
|
| 585 |
try:
|
|
@@ -590,8 +586,8 @@ def transcribe_multiple(
|
|
| 590 |
|
| 591 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(10 + idx * 85 / max(1, total))
|
| 592 |
except Exception as e:
|
| 593 |
-
log.append(f"Transcription failed for {p}: {e}")
|
| 594 |
-
transcripts.append(f"FILE: {os.path.basename(p)}\nERROR: Transcription failed: {e}")
|
| 595 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(10 + idx * 85 / max(1, total))
|
| 596 |
continue
|
| 597 |
finally:
|
|
@@ -600,11 +596,11 @@ def transcribe_multiple(
|
|
| 600 |
tmpdir = tempfile.gettempdir()
|
| 601 |
try:
|
| 602 |
common = os.path.commonpath([os.path.abspath(tmpdir), os.path.abspath(wav)])
|
| 603 |
-
if common == os.path.abspath(tmpdir) and not p.lower().endswith(".wav"):
|
| 604 |
os.unlink(wav)
|
| 605 |
except Exception:
|
| 606 |
try:
|
| 607 |
-
if tmpdir in os.path.abspath(wav) and not p.lower().endswith(".wav"):
|
| 608 |
os.unlink(wav)
|
| 609 |
except Exception:
|
| 610 |
pass
|
|
@@ -623,177 +619,330 @@ def transcribe_multiple(
|
|
| 623 |
yield "\n\n".join(log), "\n\n".join(transcripts), word_file_path, 100
|
| 624 |
|
| 625 |
|
| 626 |
-
#
|
| 627 |
-
def
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
else:
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 704 |
try:
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
yield logs_text, transcripts_text, word_path, percent
|
| 714 |
-
except Exception:
|
| 715 |
-
tb = traceback.format_exc()
|
| 716 |
-
logs_text = f"EXCEPTION in run_transcription_ui:\n{tb}"
|
| 717 |
-
transcripts_text = "ERROR: transcription did not start or failed unexpectedly."
|
| 718 |
-
yield logs_text, transcripts_text, None, 100
|
| 719 |
|
| 720 |
|
| 721 |
-
#
|
| 722 |
print("DEBUG: building Gradio Blocks", flush=True)
|
| 723 |
-
with gr.Blocks(title="Whisper Transcriber —
|
| 724 |
-
gr.Markdown(
|
| 725 |
-
|
| 726 |
-
"<p>Upload audio files or a ZIP, extract and choose files, then transcribe.</p>",
|
| 727 |
-
)
|
| 728 |
|
| 729 |
with gr.Tabs():
|
| 730 |
-
# ---------------- Transcribe
|
| 731 |
-
with gr.TabItem("Transcribe"):
|
| 732 |
with gr.Row():
|
| 733 |
with gr.Column(scale=1):
|
| 734 |
-
gr.Markdown("###
|
| 735 |
-
|
| 736 |
-
file_input = gr.File(label="Audio files (optional)", file_count="multiple", type="filepath", height=80)
|
| 737 |
-
zip_input = gr.File(label="ZIP with audio (optional)", file_count="single", type="filepath", height=80)
|
| 738 |
-
|
| 739 |
with gr.Row():
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 749 |
|
| 750 |
-
|
| 751 |
-
extract_btn = gr.Button("Extract ZIP & List Files")
|
| 752 |
-
extracted_files_check = gr.CheckboxGroup(choices=[], label="Select extracted files to transcribe (optional)", interactive=True)
|
| 753 |
-
extract_logs = gr.Textbox(label="Extraction logs", interactive=False, lines=6)
|
| 754 |
|
| 755 |
-
|
| 756 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
with gr.Column(scale=1):
|
| 758 |
gr.Markdown("### Output")
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 775 |
inputs=[
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
device_choice,
|
| 788 |
],
|
| 789 |
-
outputs=[
|
| 790 |
)
|
| 791 |
|
| 792 |
-
# ---------------- Memory
|
| 793 |
with gr.TabItem("Memory"):
|
| 794 |
with gr.Row():
|
| 795 |
with gr.Column(scale=1):
|
| 796 |
-
gr.Markdown("### Memory
|
| 797 |
mem_upload = gr.File(label="Import memory file (JSON or text)", file_count="single", type="filepath")
|
| 798 |
mem_import_btn = gr.Button("Import Memory File")
|
| 799 |
mem_manual_entry = gr.Textbox(label="Add word/phrase to memory (manual)", placeholder="Type a word or phrase")
|
|
@@ -802,7 +951,6 @@ with gr.Blocks(title="Whisper Transcriber — Multi-tab") as demo:
|
|
| 802 |
mem_view_btn = gr.Button("View Memory")
|
| 803 |
mem_status = gr.Textbox(label="Memory status", interactive=False, lines=12)
|
| 804 |
|
| 805 |
-
# memory bindings
|
| 806 |
def _import_mem(uploaded):
|
| 807 |
return import_memory_file(uploaded)
|
| 808 |
|
|
@@ -811,22 +959,66 @@ with gr.Blocks(title="Whisper Transcriber — Multi-tab") as demo:
|
|
| 811 |
mem_clear_btn.click(fn=lambda: clear_memory(), inputs=[], outputs=[mem_status])
|
| 812 |
mem_view_btn.click(fn=lambda: view_memory(), inputs=[], outputs=[mem_status])
|
| 813 |
|
| 814 |
-
# ----------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 815 |
with gr.TabItem("Settings"):
|
| 816 |
with gr.Row():
|
| 817 |
with gr.Column():
|
| 818 |
-
gr.Markdown("###
|
| 819 |
-
gr.Markdown("-
|
| 820 |
-
gr.Markdown("-
|
| 821 |
-
gr.Markdown("-
|
| 822 |
with gr.Column():
|
| 823 |
gr.Markdown("### Diagnostics")
|
| 824 |
diag_btn = gr.Button("Show memory summary")
|
| 825 |
diag_out = gr.Textbox(label="Diagnostics output", interactive=False, lines=12)
|
| 826 |
-
|
| 827 |
diag_btn.click(fn=lambda: view_memory(), inputs=[], outputs=[diag_out])
|
| 828 |
|
| 829 |
-
#
|
| 830 |
|
| 831 |
# ---------- Launch ----------
|
| 832 |
if __name__ == "__main__":
|
|
|
|
| 1 |
# app.py
|
| 2 |
+
# Whisper transcription app - Redesigned UI: Tabs for different works
|
| 3 |
+
# Features: Audio Transcribe, Batch Transcribe (ZIP extraction + selection), Memory, Fine-tune, Settings
|
| 4 |
+
# Drop-in replacement. Requires dependencies: gradio, whisper, pydub, pyzipper, python-docx, ffmpeg.
|
| 5 |
|
| 6 |
import os
|
| 7 |
import sys
|
|
|
|
| 13 |
import threading
|
| 14 |
import re
|
| 15 |
from difflib import get_close_matches
|
| 16 |
+
from pathlib import Path
|
| 17 |
|
| 18 |
+
# Force unbuffered output so container logs show prints immediately
|
| 19 |
os.environ["PYTHONUNBUFFERED"] = "1"
|
| 20 |
|
| 21 |
print("DEBUG: app.py bootstrap starting", flush=True)
|
| 22 |
|
| 23 |
+
# Third-party imports (must be installed in the environment)
|
| 24 |
try:
|
| 25 |
from docx import Document
|
| 26 |
import whisper
|
|
|
|
| 45 |
("pcm_s16le", 44100, 2),
|
| 46 |
("mulaw", 8000, 1),
|
| 47 |
]
|
| 48 |
+
# Fine-tune globals
|
| 49 |
+
FINETUNE_PROC = None
|
| 50 |
+
FINETUNE_LOCK = threading.Lock()
|
| 51 |
+
FINETUNE_LOG = os.path.join(tempfile.gettempdir(), "finetune_logs.txt")
|
| 52 |
+
FINETUNE_WORKDIR = os.path.join(tempfile.gettempdir(), "finetune_workdir")
|
| 53 |
+
os.makedirs(FINETUNE_WORKDIR, exist_ok=True)
|
| 54 |
# ----------------------------
|
| 55 |
|
| 56 |
+
# ---------- Utilities / Memory / Postprocessing ----------
|
| 57 |
def load_memory():
|
| 58 |
try:
|
| 59 |
if os.path.exists(MEMORY_FILE):
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
memory = load_memory()
|
|
|
|
| 88 |
|
| 89 |
+
|
| 90 |
MEDICAL_ABBREVIATIONS = {
|
| 91 |
"pt": "patient",
|
| 92 |
"dx": "diagnosis",
|
|
|
|
| 167 |
return t
|
| 168 |
|
| 169 |
|
|
|
|
| 170 |
def extract_words_and_phrases(text):
|
| 171 |
words = re.findall(r"[A-Za-z0-9\-']+", text)
|
| 172 |
sentences = [s.strip() for s in re.split(r"(?<=[.?!])\s+", text) if s.strip()]
|
|
|
|
| 234 |
return corrected
|
| 235 |
|
| 236 |
|
| 237 |
+
# ---------- Memory management helpers ----------
|
| 238 |
def import_memory_file(uploaded):
|
| 239 |
global memory
|
| 240 |
if not uploaded:
|
|
|
|
| 346 |
return filename
|
| 347 |
|
| 348 |
|
| 349 |
+
# ---------- Conversion helpers (pydub + ffmpeg fallback) ----------
|
| 350 |
def _ffmpeg_convert(input_path, out_path, fmt, sr, ch):
|
| 351 |
try:
|
| 352 |
cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error", "-y"]
|
|
|
|
| 471 |
return MODEL_CACHE[name]
|
| 472 |
|
| 473 |
|
| 474 |
+
# ---------- ZIP extraction helpers ----------
|
| 475 |
def extract_zip_list(zip_file, zip_password):
|
|
|
|
|
|
|
|
|
|
| 476 |
temp_extract_dir = os.path.join(tempfile.gettempdir(), "extracted_audio")
|
| 477 |
try:
|
| 478 |
if os.path.exists(temp_extract_dir):
|
|
|
|
| 479 |
try:
|
| 480 |
shutil.rmtree(temp_extract_dir)
|
| 481 |
except Exception:
|
|
|
|
| 513 |
if not extracted:
|
| 514 |
logs.append("No supported audio files found in zip.")
|
| 515 |
return [], "\n".join(logs)
|
|
|
|
| 516 |
return extracted, "\n".join(logs)
|
| 517 |
except Exception as e:
|
| 518 |
traceback.print_exc()
|
| 519 |
return [], f"Extraction failed: {e}"
|
| 520 |
|
| 521 |
|
| 522 |
+
# ---------- Transcription generator used by both Audio and Batch workflows ----------
|
| 523 |
def transcribe_multiple(
|
| 524 |
selected_paths,
|
| 525 |
model_name,
|
|
|
|
| 528 |
enable_memory=False,
|
| 529 |
device=None,
|
| 530 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
log = []
|
| 532 |
transcripts = []
|
| 533 |
word_file_path = None
|
|
|
|
| 539 |
|
| 540 |
yield "", "", None, 0
|
| 541 |
|
|
|
|
| 542 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, 5
|
| 543 |
try:
|
| 544 |
model = get_whisper_model(model_name, device=device)
|
|
|
|
| 550 |
|
| 551 |
total = len(selected_paths)
|
| 552 |
for idx, p in enumerate(selected_paths, start=1):
|
| 553 |
+
log.append(f"Processing file ({idx}/{total}): {os.path.basename(str(p))}")
|
| 554 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(5 + (idx - 1) * 80 / max(1, total))
|
| 555 |
|
| 556 |
wav = None
|
| 557 |
try:
|
| 558 |
wav = convert_to_wav_if_needed(p)
|
| 559 |
+
log.append(f"Converted to WAV: {os.path.basename(str(wav))}")
|
| 560 |
except Exception as e:
|
| 561 |
+
log.append(f"Conversion failed for {os.path.basename(str(p))}: {e}")
|
| 562 |
+
transcripts.append(f"FILE: {os.path.basename(str(p))}\nERROR: Conversion failed: {e}")
|
| 563 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(5 + idx * 80 / max(1, total))
|
| 564 |
continue
|
| 565 |
|
|
|
|
| 575 |
if enable_memory:
|
| 576 |
text = memory_correct_text(text)
|
| 577 |
text = postprocess_transcript(text)
|
| 578 |
+
transcripts.append(f"FILE: {os.path.basename(str(p))}\n{text}\n")
|
| 579 |
|
| 580 |
if enable_memory:
|
| 581 |
try:
|
|
|
|
| 586 |
|
| 587 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(10 + idx * 85 / max(1, total))
|
| 588 |
except Exception as e:
|
| 589 |
+
log.append(f"Transcription failed for {os.path.basename(str(p))}: {e}")
|
| 590 |
+
transcripts.append(f"FILE: {os.path.basename(str(p))}\nERROR: Transcription failed: {e}")
|
| 591 |
yield "\n\n".join(log), "\n\n".join(transcripts), None, int(10 + idx * 85 / max(1, total))
|
| 592 |
continue
|
| 593 |
finally:
|
|
|
|
| 596 |
tmpdir = tempfile.gettempdir()
|
| 597 |
try:
|
| 598 |
common = os.path.commonpath([os.path.abspath(tmpdir), os.path.abspath(wav)])
|
| 599 |
+
if common == os.path.abspath(tmpdir) and not str(p).lower().endswith(".wav"):
|
| 600 |
os.unlink(wav)
|
| 601 |
except Exception:
|
| 602 |
try:
|
| 603 |
+
if tmpdir in os.path.abspath(wav) and not str(p).lower().endswith(".wav"):
|
| 604 |
os.unlink(wav)
|
| 605 |
except Exception:
|
| 606 |
pass
|
|
|
|
| 619 |
yield "\n\n".join(log), "\n\n".join(transcripts), word_file_path, 100
|
| 620 |
|
| 621 |
|
| 622 |
+
# ---------- Fine-tune helpers (same as earlier) ----------
|
| 623 |
+
def _safe_write_log(msg):
    """Append one line to the fine-tune log file, best-effort.

    Args:
        msg: Message to record. It is formatted with ``str()`` semantics, so
            non-string values (for example an exception object) are logged
            instead of being dropped by a ``TypeError`` on concatenation.

    Returns:
        None. Any failure while writing is swallowed on purpose: logging must
        never break the UI handler that called it.
    """
    try:
        with open(FINETUNE_LOG, "a", encoding="utf-8") as fh:
            fh.write(f"{msg}\n")
    except Exception:
        # Deliberate best-effort: an unwritable log is not worth an error.
        pass
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def prepare_finetune_dataset(uploaded_zip_or_dir):
    """Stage an uploaded dataset under FINETUNE_WORKDIR and produce a manifest.

    The staging directory ``<FINETUNE_WORKDIR>/data`` is wiped and recreated,
    then filled either by extracting a ``.zip`` or by copying a directory.
    If the dataset ships one of the known transcript files it is copied to
    ``<FINETUNE_WORKDIR>/manifest.tsv``; otherwise a manifest is built with one
    ``<audio path>\\t<transcript>`` line per audio file, where the transcript
    is read from a sibling ``.txt`` file when one exists (empty otherwise).

    Args:
        uploaded_zip_or_dir: A path (``str`` / ``os.PathLike``), a dict with a
            ``"name"`` key, or an object exposing a ``.name`` attribute.

    Returns:
        tuple[str, str]: ``(status_message, manifest_path)``. ``manifest_path``
        is ``""`` whenever preparation failed.
    """
    dst = os.path.join(FINETUNE_WORKDIR, "data")
    try:
        if os.path.exists(dst):
            shutil.rmtree(dst)
        os.makedirs(dst, exist_ok=True)
    except Exception as e:
        return f"Failed to prepare workdir: {e}", ""

    if not uploaded_zip_or_dir:
        return "No dataset file or dir provided.", ""

    # Resolve the upload to a filesystem path.
    if isinstance(uploaded_zip_or_dir, (str, os.PathLike)):
        path = str(uploaded_zip_or_dir)
    elif isinstance(uploaded_zip_or_dir, dict):
        path = uploaded_zip_or_dir.get("name")
    else:
        path = getattr(uploaded_zip_or_dir, "name", None)
    if not path:
        # Previously an unrecognised upload type left path=None and crashed in
        # os.path.isfile(None) with an uncaught TypeError.
        return "Unable to determine uploaded path.", ""
    path = str(path)

    if os.path.isfile(path) and path.lower().endswith(".zip"):
        try:
            with pyzipper.ZipFile(path, "r") as zf:
                zf.extractall(dst)
        except Exception as e:
            return f"Failed to extract ZIP: {e}", ""
    elif os.path.isdir(path):
        try:
            for item in os.listdir(path):
                src_item = os.path.join(path, item)
                dst_item = os.path.join(dst, item)
                if os.path.isdir(src_item):
                    shutil.copytree(src_item, dst_item)
                else:
                    shutil.copy2(src_item, dst_item)
        except Exception as e:
            return f"Failed to copy dataset dir: {e}", ""
    else:
        return "Uploaded file is not zip or directory.", ""

    manifest_path = os.path.join(FINETUNE_WORKDIR, "manifest.tsv")
    prepared = f"Dataset prepared. Manifest: {manifest_path}", manifest_path

    # Prefer a transcript/manifest file shipped with the dataset.
    candidate_names = (
        "transcripts.tsv",
        "metadata.tsv",
        "manifest.tsv",
        "transcripts.txt",
        "metadata.txt",
    )
    for name in candidate_names:
        candidate = os.path.join(dst, name)
        if not os.path.exists(candidate):
            continue
        try:
            shutil.copy2(candidate, manifest_path)
        except OSError:
            # Unreadable candidate: fall through to the next one.
            continue
        return prepared

    # No shipped manifest: build one from audio files and sibling .txt files.
    audio_exts = (".wav", ".mp3", ".flac", ".m4a", ".ogg")
    audio_files = sorted(
        os.path.join(root, fname)
        for root, _, files in os.walk(dst)
        for fname in files
        if fname.lower().endswith(audio_exts)
    )  # sorted so the manifest is reproducible across runs
    if not audio_files:
        return "No audio files found in dataset.", ""

    entries = []
    for audio in audio_files:
        transcript = ""
        txt_path = os.path.splitext(audio)[0] + ".txt"
        if os.path.exists(txt_path):
            try:
                with open(txt_path, "r", encoding="utf-8") as fh:
                    # Collapse every whitespace run (newlines, CRs, tabs) to a
                    # single space so a transcript can never break the
                    # two-column TSV layout.
                    transcript = " ".join(fh.read().split())
            except (OSError, ValueError):
                # ValueError covers UnicodeDecodeError on non-UTF-8 files.
                transcript = ""
        entries.append(f"{audio}\t{transcript}")

    try:
        with open(manifest_path, "w", encoding="utf-8") as fh:
            fh.write("\n".join(entries))
    except Exception as e:
        return f"Failed to write manifest: {e}", ""

    return prepared
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def start_finetune(manifest_path, base_model, epochs, batch_size, lr, output_dir):
    """Launch ``fine_tune.py`` as a background subprocess.

    Only one fine-tune process is tracked at a time (in ``FINETUNE_PROC``);
    the check and the launch happen under ``FINETUNE_LOCK``. The child's
    stdout and stderr are appended to ``FINETUNE_LOG``, which is reset first.

    Args:
        manifest_path: Path of the dataset manifest passed as ``--manifest``.
        base_model: Value passed as ``--base_model``.
        epochs: Value passed as ``--epochs``.
        batch_size: Value passed as ``--batch_size``.
        lr: Value passed as ``--lr``.
        output_dir: Output directory; defaults to ``<FINETUNE_WORKDIR>/output``
            when empty.

    Returns:
        str: A human-readable status message (started, already running, or the
        reason the launch failed). Never raises for launch failures.
    """
    global FINETUNE_PROC
    script = "fine_tune.py"
    with FINETUNE_LOCK:
        if FINETUNE_PROC and FINETUNE_PROC.poll() is None:
            return "Fine-tune already running."

        if not manifest_path:
            return "No manifest provided. Prepare a dataset first."

        # Popen succeeds as long as the interpreter exists, so a missing
        # script would never surface as FileNotFoundError; check explicitly.
        if not os.path.isfile(script):
            return (
                f"Training script not found: {os.path.abspath(script)}. "
                "Put your training script 'fine_tune.py' in project root or change START_CMD."
            )

        outdir = output_dir or os.path.join(FINETUNE_WORKDIR, "output")

        # Start from an empty log; a missing or undeletable file is fine.
        try:
            if os.path.exists(FINETUNE_LOG):
                os.remove(FINETUNE_LOG)
        except OSError:
            pass

        # Every argument is stringified up front so that neither Popen nor the
        # log line below can fail on a None / numeric value.
        cmd = [
            sys.executable,
            script,
            "--manifest",
            str(manifest_path),
            "--base_model",
            str(base_model),
            "--epochs",
            str(epochs),
            "--batch_size",
            str(batch_size),
            "--lr",
            str(lr),
            "--output_dir",
            str(outdir),
        ]
        try:
            os.makedirs(outdir, exist_ok=True)
            # The child inherits its own copy of the descriptor, so the parent
            # handle can (and should) be closed right after the spawn.
            with open(FINETUNE_LOG, "a", encoding="utf-8") as logfile:
                proc = subprocess.Popen(
                    cmd,
                    stdout=logfile,
                    stderr=subprocess.STDOUT,
                    cwd=os.getcwd(),
                )
        except FileNotFoundError as e:
            return f"Training script not found: {e}. Put your training script 'fine_tune.py' in project root or change START_CMD."
        except Exception as e:
            return f"Failed to start fine-tune: {e}"

        FINETUNE_PROC = proc
        _safe_write_log(f"Started fine-tune: PID={proc.pid}, cmd={' '.join(cmd)}")
        return f"Fine-tune started (PID={proc.pid}). Logs: {FINETUNE_LOG}"
|
| 766 |
|
| 767 |
+
|
| 768 |
+
def stop_finetune():
    """Stop the tracked fine-tune subprocess, if any.

    Sends ``terminate()`` and waits up to 10 seconds. If that fails or times
    out the process is killed and then waited on so it is reaped rather than
    left behind as a zombie. ``FINETUNE_PROC`` is cleared in every case.

    Returns:
        str: A status message describing what happened.
    """
    global FINETUNE_PROC
    with FINETUNE_LOCK:
        proc = FINETUNE_PROC
        if not proc:
            return "No running fine-tune process."

        pid = proc.pid
        try:
            proc.terminate()
            proc.wait(timeout=10)
        except Exception as e:
            # Graceful stop failed (typically subprocess.TimeoutExpired):
            # escalate to kill and reap the child.
            try:
                proc.kill()
                proc.wait(timeout=5)
            except Exception:
                pass
            FINETUNE_PROC = None
            _safe_write_log(f"Force killed fine-tune PID={pid}: {e}")
            return f"Force killed fine-tune process: {e}"

        FINETUNE_PROC = None
        _safe_write_log(f"Terminated fine-tune PID={pid}")
        return f"Terminated fine-tune PID={pid}"
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
def tail_finetune_logs(lines=50):
    """Return the last ``lines`` lines of the fine-tune log.

    The file is streamed through a bounded deque, so memory use depends on
    ``lines`` rather than on the size of the log.

    Args:
        lines: Number of trailing lines to return. Values that are not
            integers (e.g. a float coming from a UI widget) are converted with
            ``int()``; unconvertible values fall back to 50. A value of zero
            or less yields an empty string.

    Returns:
        str: The requested lines joined with newlines, ``"No logs yet."`` when
        the log file does not exist, or a ``"Failed to read logs: ..."``
        message on any other error.
    """
    from collections import deque

    try:
        count = int(lines)
    except (TypeError, ValueError):
        count = 50
    if count <= 0:
        # A slice of [-0:] used to return the whole file here.
        return ""

    try:
        with open(FINETUNE_LOG, "r", encoding="utf-8", errors="ignore") as fh:
            last = deque(fh, maxlen=count)
    except FileNotFoundError:
        return "No logs yet."
    except Exception as e:
        return f"Failed to read logs: {e}"

    return "\n".join(line.rstrip("\n") for line in last)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 799 |
|
| 800 |
|
| 801 |
+
# ----------------------- Gradio UI -----------------------
|
| 802 |
print("DEBUG: building Gradio Blocks", flush=True)
|
| 803 |
+
with gr.Blocks(title="Whisper Transcriber — Redesigned UI") as demo:
|
| 804 |
+
gr.Markdown("<h1 style='margin-bottom:0.25rem;'>Whisper Transcriber</h1>")
|
| 805 |
+
gr.Markdown("<p style='margin-top:0.1rem;color:#666;'>Organize work by tabs. Quick single-file transcription, batch workflows, memory and fine-tune tools.</p>")
|
|
|
|
|
|
|
| 806 |
|
| 807 |
with gr.Tabs():
|
| 808 |
+
# ---------------- Audio Transcribe (single-file focused) ----------------
|
| 809 |
+
with gr.TabItem("Audio Transcribe"):
|
| 810 |
with gr.Row():
|
| 811 |
with gr.Column(scale=1):
|
| 812 |
+
gr.Markdown("### Quick audio transcription")
|
| 813 |
+
single_audio = gr.Audio(label="Upload or record an audio file", type="filepath", interactive=True)
|
|
|
|
|
|
|
|
|
|
| 814 |
with gr.Row():
|
| 815 |
+
audio_model = gr.Dropdown(choices=["small", "medium", "large", "base"], value="small", label="Model", interactive=True)
|
| 816 |
+
audio_device = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device", interactive=True)
|
| 817 |
+
audio_enable_memory = gr.Checkbox(label="Enable correction memory", value=False)
|
| 818 |
+
audio_transcribe_btn = gr.Button("Transcribe Audio", variant="primary")
|
| 819 |
+
audio_help = gr.Markdown("<small>Upload a single audio file (wav/mp3/m4a/ogg). Use memory to apply learned corrections.</small>")
|
| 820 |
+
with gr.Column(scale=1):
|
| 821 |
+
gr.Markdown("### Player & Transcript")
|
| 822 |
+
audio_player_out = gr.Audio(label="Player", interactive=False)
|
| 823 |
+
audio_transcript_out = gr.Textbox(label="Transcript", lines=16, interactive=False)
|
| 824 |
+
audio_logs = gr.Textbox(label="Logs", lines=10, interactive=False)
|
| 825 |
+
|
| 826 |
+
def _single_transcribe(audio_path, model_name, enable_memory, device_choice):
|
| 827 |
+
logs = []
|
| 828 |
+
transcripts = []
|
| 829 |
+
if not audio_path:
|
| 830 |
+
return None, "No audio uploaded.", "No file provided."
|
| 831 |
+
# Normalize single path (gr.Audio returns path)
|
| 832 |
+
path = str(audio_path)
|
| 833 |
+
try:
|
| 834 |
+
model = get_whisper_model(model_name, device=(None if device_choice == "auto" else device_choice))
|
| 835 |
+
logs.append(f"Loaded model: {model_name}")
|
| 836 |
+
except Exception as e:
|
| 837 |
+
tb = traceback.format_exc()
|
| 838 |
+
return None, "", f"Failed to load model: {e}\n{tb}"
|
| 839 |
+
try:
|
| 840 |
+
wav = convert_to_wav_if_needed(path)
|
| 841 |
+
logs.append(f"Converted to WAV: {os.path.basename(wav)}")
|
| 842 |
+
except Exception as e:
|
| 843 |
+
return None, "", f"Conversion failed: {e}"
|
| 844 |
+
try:
|
| 845 |
+
result = model.transcribe(wav)
|
| 846 |
+
text = result.get("text", "").strip()
|
| 847 |
+
if enable_memory:
|
| 848 |
+
text = memory_correct_text(text)
|
| 849 |
+
text = postprocess_transcript(text)
|
| 850 |
+
transcripts = text
|
| 851 |
+
# update memory optionally
|
| 852 |
+
if enable_memory:
|
| 853 |
+
try:
|
| 854 |
+
update_memory_with_transcript(text)
|
| 855 |
+
logs.append("Memory updated.")
|
| 856 |
+
except Exception:
|
| 857 |
+
pass
|
| 858 |
+
except Exception as e:
|
| 859 |
+
return None, "", f"Transcription failed: {e}"
|
| 860 |
+
finally:
|
| 861 |
+
try:
|
| 862 |
+
if wav and os.path.exists(wav) and wav != path:
|
| 863 |
+
# remove tmp wav produced by conversion
|
| 864 |
+
try:
|
| 865 |
+
os.unlink(wav)
|
| 866 |
+
except Exception:
|
| 867 |
+
pass
|
| 868 |
+
except Exception:
|
| 869 |
+
pass
|
| 870 |
+
# audio_player_out accepts filepath
|
| 871 |
+
return path, transcripts, "\n".join(logs)
|
| 872 |
|
| 873 |
+
audio_transcribe_btn.click(fn=_single_transcribe, inputs=[single_audio, audio_model, audio_enable_memory, audio_device], outputs=[audio_player_out, audio_transcript_out, audio_logs])
|
|
|
|
|
|
|
|
|
|
| 874 |
|
| 875 |
+
# ---------------- Batch Transcribe ----------------
|
| 876 |
+
with gr.TabItem("Batch Transcribe"):
|
| 877 |
+
with gr.Row():
|
| 878 |
+
with gr.Column(scale=1):
|
| 879 |
+
gr.Markdown("### Batch / ZIP workflow")
|
| 880 |
+
batch_files = gr.File(label="Upload multiple audio files (optional)", file_count="multiple", type="filepath")
|
| 881 |
+
batch_zip = gr.File(label="Or upload ZIP with audio", file_count="single", type="filepath")
|
| 882 |
+
with gr.Row():
|
| 883 |
+
batch_zip_password = gr.Textbox(label="ZIP password (override)", placeholder="Optional")
|
| 884 |
+
batch_use_default_zip_pass = gr.Checkbox(label="Use default ZIP password", value=False)
|
| 885 |
+
batch_default_zip_password = gr.Textbox(label="Default ZIP password", value="", interactive=True)
|
| 886 |
+
batch_model = gr.Dropdown(choices=["small", "medium", "large", "base"], value="small", label="Model")
|
| 887 |
+
batch_device = gr.Dropdown(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
|
| 888 |
+
batch_merge = gr.Checkbox(label="Merge all transcripts into one .docx", value=True)
|
| 889 |
+
batch_enable_memory = gr.Checkbox(label="Enable correction memory", value=False)
|
| 890 |
+
gr.Markdown("### Extraction")
|
| 891 |
+
batch_extract_btn = gr.Button("Extract ZIP & List Files")
|
| 892 |
+
batch_extracted_check = gr.CheckboxGroup(choices=[], label="Select extracted files to transcribe (optional)", interactive=True)
|
| 893 |
+
batch_extract_logs = gr.Textbox(label="Extraction logs", interactive=False, lines=6)
|
| 894 |
+
batch_transcribe_btn = gr.Button("Transcribe Selected / Uploaded", variant="primary")
|
| 895 |
with gr.Column(scale=1):
|
| 896 |
gr.Markdown("### Output")
|
| 897 |
+
batch_transcripts_out = gr.Textbox(label="Transcript (cumulative)", lines=20, interactive=False)
|
| 898 |
+
batch_progress = gr.Slider(minimum=0, maximum=100, value=0, step=1, label="Progress (%)", interactive=False)
|
| 899 |
+
batch_download_file = gr.File(label="Merged .docx (when available)")
|
| 900 |
+
batch_logs = gr.Textbox(label="Logs", lines=12, interactive=False)
|
| 901 |
+
|
| 902 |
+
def _batch_extract(zip_file, zip_password, use_default_zip_pass, default_zip_password):
    """Extract an uploaded ZIP and offer its files in the checkbox group.

    Args:
        zip_file: The uploaded ZIP, either a path (``str`` / ``os.PathLike``)
            or an object exposing the path through ``.name``. Falsy when
            nothing was uploaded.
        zip_password: Password typed into the override box (may be empty).
        use_default_zip_pass: When true and the override box is blank,
            ``default_zip_password`` is used instead.
        default_zip_password: Fallback password.

    Returns:
        A ``(checkbox_update, log_text)`` pair. ``checkbox_update`` is a
        ``gr.update`` that sets the extracted paths as the group's choices and
        as its selected value; ``log_text`` is the extraction log followed by
        the basenames of the extracted files.
    """

    def _nothing_extracted(message):
        # Reset the choices too, so entries from an earlier extraction do not
        # linger next to an error message.
        return gr.update(choices=[], value=[]), message

    if not zip_file:
        return _nothing_extracted("No ZIP file provided.")

    if isinstance(zip_file, (str, os.PathLike)):
        zip_path = str(zip_file)
    elif hasattr(zip_file, "name"):
        zip_path = zip_file.name
    else:
        return _nothing_extracted("Unable to determine uploaded zip path.")

    # A typed password always wins; the default applies only when the override
    # box is blank and the "use default" checkbox is ticked.
    override_is_blank = not zip_password or not zip_password.strip()
    if use_default_zip_pass and override_is_blank:
        final_zip_password = default_zip_password
    else:
        final_zip_password = zip_password

    extracted, logs = extract_zip_list(zip_path, final_zip_password)

    # For nicer UI, present basenames in the extract logs.
    short_logs = logs + "\n\nFiles:\n" + "\n".join(os.path.basename(p) for p in extracted)

    # Returning a bare list would only set the group's *value*; the component is
    # created with choices=[] so nothing would be selectable. Update the choices
    # explicitly and keep every extracted file selected, as the bare list did.
    return gr.update(choices=extracted, value=extracted), short_logs
|
| 920 |
+
|
| 921 |
+
# Extract button: run _batch_extract on the ZIP inputs and send its two results
# to the extracted-files checkbox group and the extraction log box.
batch_extract_btn.click(
    fn=_batch_extract,
    inputs=[
        batch_zip,
        batch_zip_password,
        batch_use_default_zip_pass,
        batch_default_zip_password,
    ],
    outputs=[batch_extracted_check, batch_extract_logs],
)

# Transcribe button: uses run_transcription_ui when that name is a module
# global, otherwise registers the event with fn=None.
# NOTE(review): batch_extracted_check appears twice in the inputs (1st and 5th
# position) — confirm this matches run_transcription_ui's parameter order.
batch_transcribe_btn.click(
    fn=None if 'run_transcription_ui' not in globals() else run_transcription_ui,
    inputs=[
        batch_extracted_check,
        batch_files,
        batch_model,
        batch_merge,
        batch_extracted_check,
        batch_zip,
        batch_zip_password,
        batch_use_default_zip_pass,
        batch_default_zip_password,
        batch_enable_memory,
        batch_device,
    ],
    outputs=[
        batch_logs,
        batch_transcripts_out,
        batch_download_file,
        batch_progress,
    ],
)
|
| 940 |
|
| 941 |
+
# ---------------- Memory ----------------
|
| 942 |
# Memory tab: widgets for importing a memory file, adding an entry manually,
# and viewing the stored memory in a read-only status box.
# NOTE(review): `mem_clear_btn`, which is wired to a click handler further down,
# is not created in the lines visible here — presumably it is defined in the
# elided lines between `mem_manual_entry` and `mem_view_btn`; confirm.
with gr.TabItem("Memory"):
|
| 943 |
with gr.Row():
|
| 944 |
with gr.Column(scale=1):
|
| 945 |
+
gr.Markdown("### Memory management")
|
| 946 |
mem_upload = gr.File(label="Import memory file (JSON or text)", file_count="single", type="filepath")
|
| 947 |
mem_import_btn = gr.Button("Import Memory File")
|
| 948 |
mem_manual_entry = gr.Textbox(label="Add word/phrase to memory (manual)", placeholder="Type a word or phrase")
|
|
|
|
| 951 |
mem_view_btn = gr.Button("View Memory")
|
| 952 |
mem_status = gr.Textbox(label="Memory status", interactive=False, lines=12)
|
| 953 |
|
|
|
|
| 954 |
def _import_mem(uploaded):
    """Click-handler shim: hand the uploaded file to ``import_memory_file``.

    Returns whatever ``import_memory_file`` returns, unchanged.
    """
    import_result = import_memory_file(uploaded)
    return import_result
|
| 956 |
|
|
|
|
| 959 |
# Both buttons take no inputs and route the handler's result to the
# memory status box.
mem_clear_btn.click(
    fn=lambda: clear_memory(),
    inputs=[],
    outputs=[mem_status],
)
mem_view_btn.click(
    fn=lambda: view_memory(),
    inputs=[],
    outputs=[mem_status],
)
|
| 961 |
|
| 962 |
+
# ---------------- Fine-tune ----------------
|
| 963 |
+
# Fine-tune tab: controls in the first column, usage notes in the second.
# NOTE(review): nesting reconstructed from statement order — confirm against
# the original indentation.
with gr.TabItem("Fine-tune"):
    with gr.Row():
        with gr.Column(scale=1):
            # Dataset upload and preparation.
            gr.Markdown("### Prepare dataset & start fine-tuning")
            ft_upload = gr.File(
                label="Upload training ZIP or folder (zip)",
                file_count="single",
                type="filepath",
            )
            ft_prepare_btn = gr.Button("Prepare dataset")
            ft_prepare_status = gr.Textbox(
                label="Prepare status / manifest",
                interactive=False,
                lines=4,
            )

            # Hyper-parameters passed to the training run.
            gr.Markdown("### Training parameters")
            ft_base_model = gr.Dropdown(
                choices=["small", "base", "medium", "large"],
                value="small",
                label="Base model",
            )
            ft_epochs = gr.Slider(
                minimum=1,
                maximum=100,
                value=3,
                step=1,
                label="Epochs",
            )
            ft_batch = gr.Number(label="Batch size", value=8)
            ft_lr = gr.Number(label="Learning rate", value=1e-5, precision=8)
            ft_output_dir = gr.Textbox(
                label="Output dir (optional)",
                value="",
                placeholder="Leave blank to use temp output",
            )

            # Start / stop controls and their shared status box.
            ft_start_btn = gr.Button("Start Fine-tune")
            ft_stop_btn = gr.Button("Stop Fine-tune")
            ft_start_status = gr.Textbox(
                label="Start/Stop status",
                interactive=False,
                lines=4,
            )

            # On-demand view of the training log tail.
            ft_tail_btn = gr.Button("Tail training logs")
            ft_logs = gr.Textbox(
                label="Training logs (tail)",
                interactive=False,
                lines=12,
            )
        with gr.Column(scale=1):
            gr.Markdown("### Fine-tune notes")
            gr.Markdown(
                "- The app calls `python fine_tune.py --manifest <manifest> ...` by default; provide your training script or change START_CMD."
            )
|
| 989 |
+
|
| 990 |
+
def _prepare_action(ft_upload):
    """Prepare the uploaded training data and return the status for display.

    ``prepare_finetune_dataset`` yields a ``(status, manifest)`` pair; only the
    status is returned, the manifest value is discarded here.
    """
    status_text, _manifest = prepare_finetune_dataset(ft_upload)
    return status_text
|
| 993 |
+
|
| 994 |
+
# Prepare button: run _prepare_action on the upload and show its status text.
ft_prepare_btn.click(
    fn=_prepare_action,
    inputs=[ft_upload],
    outputs=[ft_prepare_status],
)
|
| 995 |
+
|
| 996 |
+
def _start_action(ft_prepare_status_txt, ft_base_model, ft_epochs, ft_batch, ft_lr, ft_output_dir):
    """Validate the training parameters and launch fine-tuning.

    Args:
        ft_prepare_status_txt: Text of the prepare-status box. Accepted because
            the click wiring passes it, but not used here.
        ft_base_model: Name of the base model to fine-tune.
        ft_epochs: Number of epochs (converted with ``int``).
        ft_batch: Batch size (converted with ``int``).
        ft_lr: Learning rate (converted with ``float``).
        ft_output_dir: Output directory, forwarded unchanged.

    Returns:
        The status string from ``start_finetune``, or an explanatory message
        when the manifest is missing or a numeric parameter is unusable.
    """
    manifest_guess = os.path.join(FINETUNE_WORKDIR, "manifest.tsv")
    if not os.path.exists(manifest_guess):
        return "Manifest not found. Prepare dataset first or manually provide manifest."

    # A cleared numeric field reaches the handler as None, which int()/float()
    # reject with TypeError; report that as a status message instead of letting
    # the handler crash.
    try:
        epochs = int(ft_epochs)
        batch_size = int(ft_batch)
        learning_rate = float(ft_lr)
    except (TypeError, ValueError, OverflowError):
        return "Invalid training parameters: epochs, batch size and learning rate must be numbers."

    # Zero or negative values cannot describe a training run; stop before
    # handing them to the trainer.
    if epochs < 1 or batch_size < 1 or learning_rate <= 0:
        return "Invalid training parameters: epochs and batch size must be at least 1 and learning rate must be positive."

    return start_finetune(manifest_guess, ft_base_model, epochs, batch_size, learning_rate, ft_output_dir)
|
| 1002 |
+
|
| 1003 |
+
# Start button: pass the prepare status plus all training parameters to
# _start_action and show the returned status.
ft_start_btn.click(
    fn=_start_action,
    inputs=[
        ft_prepare_status,
        ft_base_model,
        ft_epochs,
        ft_batch,
        ft_lr,
        ft_output_dir,
    ],
    outputs=[ft_start_status],
)

# Stop button shares the start/stop status box.
ft_stop_btn.click(
    fn=lambda: stop_finetune(),
    inputs=[],
    outputs=[ft_start_status],
)

# Tail button writes the handler's result into the training-log box.
ft_tail_btn.click(
    fn=lambda: tail_finetune_logs(),
    inputs=[],
    outputs=[ft_logs],
)
|
| 1006 |
+
|
| 1007 |
+
# ---------------- Settings ----------------
|
| 1008 |
# Settings tab: static usage tips in one column, a memory diagnostics button
# and read-only output box in the other.
# NOTE(review): nesting reconstructed from statement order — confirm against
# the original indentation.
with gr.TabItem("Settings"):
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Runtime & Tips")
            gr.Markdown(
                "- Device: choose CPU or CUDA in workflows. If CUDA isn't available, leave `auto` or `cpu`."
            )
            gr.Markdown("- Keep default ZIP password empty for safety.")
            gr.Markdown(
                "- Extraction writes to system temp dir (extracted_audio). Re-extracting overwrites it."
            )
        with gr.Column():
            gr.Markdown("### Diagnostics")
            diag_btn = gr.Button("Show memory summary")
            diag_out = gr.Textbox(
                label="Diagnostics output",
                interactive=False,
                lines=12,
            )
            # The diagnostics button reuses view_memory and shows its result here.
            diag_btn.click(
                fn=lambda: view_memory(),
                inputs=[],
                outputs=[diag_out],
            )
|
| 1020 |
|
| 1021 |
+
# End tabs
|
| 1022 |
|
| 1023 |
# ---------- Launch ----------
|
| 1024 |
if __name__ == "__main__":
|