Jonascaps1 committed on
Commit
e6c1f00
·
verified ·
1 Parent(s): e7f8285

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -56
app.py CHANGED
@@ -10,26 +10,28 @@ from pathlib import Path
10
  from tempfile import NamedTemporaryFile
11
  from datetime import timedelta
12
 
13
- # ---------------- LOGGING ----------------
14
- logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
- # ---------------- CONFIG ----------------
18
  MODEL_ID = "KBLab/kb-whisper-large"
19
  CHUNK_DURATION_MS = 10000
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
22
  SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}
23
 
24
- # ---------------- FFMPEG CHECK ----------------
25
  def check_ffmpeg():
26
  try:
27
  subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
 
28
  return True
29
- except Exception:
 
30
  return False
31
 
32
- # ---------------- LOAD MODEL ----------------
33
  def initialize_pipeline():
34
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
35
  MODEL_ID,
@@ -48,66 +50,77 @@ def initialize_pipeline():
48
  torch_dtype=TORCH_DTYPE
49
  )
50
 
51
- PIPELINE = initialize_pipeline()
52
-
53
- # ---------------- AUDIO UTILS ----------------
54
  def convert_to_wav(audio_path: str) -> str:
55
  if not check_ffmpeg():
56
- raise RuntimeError("ffmpeg not available")
57
 
58
- ext = Path(audio_path).suffix.lower()
59
  if ext not in SUPPORTED_FORMATS:
60
- raise ValueError("Unsupported audio format")
61
 
62
  if ext != ".wav":
63
  audio = AudioSegment.from_file(audio_path)
64
- wav_path = str(Path(audio_path).with_suffix(".wav"))
65
  audio.export(wav_path, format="wav")
66
  return wav_path
67
 
68
  return audio_path
69
 
70
- def split_audio(audio_path: str):
 
71
  audio = AudioSegment.from_file(audio_path)
72
  return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]
73
 
74
- def get_chunk_time(index: int) -> str:
75
- return str(timedelta(milliseconds=index * CHUNK_DURATION_MS))
 
 
76
 
77
- # ---------------- TRANSCRIBE ----------------
78
  def transcribe(audio_path: str, include_timestamps: bool, progress=gr.Progress()):
79
  if not audio_path or not os.path.exists(audio_path):
80
- yield "Please upload an audio file.", None
81
  return
82
 
83
  wav_path = convert_to_wav(audio_path)
84
  chunks = split_audio(wav_path)
85
 
86
  transcript = []
87
- timestamped = []
88
 
89
  for i, chunk in enumerate(chunks):
90
- with NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
91
- chunk.export(tmp.name, format="wav")
92
-
93
- result = PIPELINE(
94
- tmp.name,
95
- generate_kwargs={"task": "transcribe", "language": "sv"}
96
- )
97
-
98
- os.remove(tmp.name)
99
-
100
- text = result["text"].strip()
101
- if text:
102
- transcript.append(text)
103
- if include_timestamps:
104
- ts = get_chunk_time(i)
105
- timestamped.append(f"[{ts}] {text}")
 
 
 
 
 
106
 
107
  progress((i + 1) / len(chunks))
108
- yield " ".join(transcript), None
109
 
110
- content = "\n".join(timestamped) if include_timestamps else " ".join(transcript)
 
 
 
 
 
111
 
112
  with NamedTemporaryFile(
113
  suffix=".txt",
@@ -118,27 +131,33 @@ def transcribe(audio_path: str, include_timestamps: bool, progress=gr.Progress()
118
  f.write(content)
119
  download_path = f.name
120
 
121
- yield " ".join(transcript), download_path
122
 
123
- # ---------------- UI ----------------
124
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
125
- gr.Markdown("# Swedish Whisper Transcriber")
126
- gr.Markdown("Upload an .m4a file and download the transcript with timestamps.")
127
 
128
- with gr.Row():
129
- with gr.Column():
130
- audio_input = gr.Audio(type="filepath", label="Upload Audio (.m4a)")
131
- timestamp_toggle = gr.Checkbox(label="Include timestamps in download")
132
- transcribe_btn = gr.Button("Transcribe")
133
 
134
- with gr.Column():
135
- transcript_output = gr.Textbox(label="Live Transcription", lines=12)
136
- download_output = gr.File(label="Download Transcript")
 
 
137
 
138
- transcribe_btn.click(
139
- fn=transcribe,
140
- inputs=[audio_input, timestamp_toggle],
141
- outputs=[transcript_output, download_output]
142
- )
 
 
 
 
 
 
143
 
144
- demo.launch()
 
 
10
  from tempfile import NamedTemporaryFile
11
  from datetime import timedelta
12
 
13
# Setup logging — basicConfig is a module-level side effect; it configures the
# root logger once, at import time.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"  # Swedish Whisper checkpoint used by the pipeline
CHUNK_DURATION_MS = 10000  # audio is processed in 10-second chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 only on GPU; CPU inference stays in fp32
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# extensions accepted by convert_to_wav (lower-cased comparison)
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}
23
 
24
# Check for ffmpeg availability
def check_ffmpeg():
    """Return True when the ffmpeg binary can be invoked, False otherwise."""
    probe_cmd = ["ffmpeg", "-version"]
    try:
        subprocess.run(probe_cmd, capture_output=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False
    logger.info("ffmpeg is installed and accessible.")
    return True
33
 
34
+ # Initialize model and pipeline
35
  def initialize_pipeline():
36
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
37
  MODEL_ID,
 
50
  torch_dtype=TORCH_DTYPE
51
  )
52
 
53
# Convert audio if needed
def convert_to_wav(audio_path: str) -> str:
    """Ensure the given audio file is a .wav, converting via pydub when needed.

    Returns the path to a .wav file — the original path when it is already
    a .wav, otherwise the path of a newly exported copy next to the source.
    Raises RuntimeError when ffmpeg is unavailable and ValueError for an
    extension outside SUPPORTED_FORMATS.
    """
    if not check_ffmpeg():
        raise RuntimeError("ffmpeg is required")

    source = Path(audio_path)
    ext = str(source.suffix).lower()
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported format: {ext}")

    # Already wav: nothing to do.
    if ext == ".wav":
        return audio_path

    segment = AudioSegment.from_file(audio_path)
    # e.g. "talk.m4a" -> "talk.converted.wav"
    target_path = str(source.with_suffix(".converted.wav"))
    segment.export(target_path, format="wav")
    return target_path
69
 
70
# Split audio into chunks
def split_audio(audio_path: str) -> list:
    """Load the file at *audio_path* and slice it into CHUNK_DURATION_MS-long segments."""
    segment = AudioSegment.from_file(audio_path)
    chunks = []
    for start in range(0, len(segment), CHUNK_DURATION_MS):
        # pydub slicing is by milliseconds; the final chunk may be shorter.
        chunks.append(segment[start:start + CHUNK_DURATION_MS])
    return chunks
74
 
75
# Helper to compute chunk start time
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Return the start offset of chunk *index* formatted as H:MM:SS(.ffffff)."""
    return str(timedelta(milliseconds=index * chunk_duration_ms))
79
 
80
+ # Transcribe audio with streaming + working download
81
  def transcribe(audio_path: str, include_timestamps: bool, progress=gr.Progress()):
82
  if not audio_path or not os.path.exists(audio_path):
83
+ yield "Please upload a valid audio file.", None
84
  return
85
 
86
  wav_path = convert_to_wav(audio_path)
87
  chunks = split_audio(wav_path)
88
 
89
  transcript = []
90
+ timestamped_transcript = []
91
 
92
  for i, chunk in enumerate(chunks):
93
+ temp_file_path = None
94
+ try:
95
+ with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
96
+ temp_file_path = temp_file.name
97
+ chunk.export(temp_file.name, format="wav")
98
+
99
+ result = PIPELINE(
100
+ temp_file.name,
101
+ generate_kwargs={"task": "transcribe", "language": "sv"}
102
+ )
103
+
104
+ text = result["text"].strip()
105
+ if text:
106
+ transcript.append(text)
107
+ if include_timestamps:
108
+ timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
109
+ timestamped_transcript.append(f"[{timestamp}] {text}")
110
+
111
+ finally:
112
+ if temp_file_path and os.path.exists(temp_file_path):
113
+ os.remove(temp_file_path)
114
 
115
  progress((i + 1) / len(chunks))
116
+ yield " ".join(transcript), None # STREAM TEXT ONLY
117
 
118
+ # Create downloadable file ONLY ONCE (fix)
119
+ content = (
120
+ "\n".join(timestamped_transcript)
121
+ if include_timestamps
122
+ else " ".join(transcript)
123
+ )
124
 
125
  with NamedTemporaryFile(
126
  suffix=".txt",
 
131
  f.write(content)
132
  download_path = f.name
133
 
134
+ yield " ".join(transcript), download_path # FINAL OUTPUT
135
 
136
+ # Initialize pipeline globally
137
+ PIPELINE = initialize_pipeline()
 
 
138
 
139
# Gradio Interface
def create_interface():
    """Build and return the Gradio Blocks app wiring the transcribe() generator to the UI."""
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# Swedish Whisper Transcriber")

        with gr.Row():
            # Left column: inputs and the action button.
            with gr.Column():
                uploaded_audio = gr.Audio(type="filepath", label="Upload .m4a Audio")
                include_timestamps = gr.Checkbox(label="Include Timestamps in Download")
                start_button = gr.Button("Transcribe")

            # Right column: streaming transcript plus the downloadable file.
            with gr.Column():
                live_text = gr.Textbox(label="Live Transcription", lines=10)
                transcript_file = gr.File(label="Download Transcript")

        # transcribe is a generator, so the textbox updates as chunks finish.
        start_button.click(
            fn=transcribe,
            inputs=[uploaded_audio, include_timestamps],
            outputs=[live_text, transcript_file],
        )

    return app
161
 
162
+ if __name__ == "__main__":
163
+ create_interface().launch()