Jonascaps1 committed on
Commit
e7f8285
·
verified ·
1 Parent(s): f12cc38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -198
app.py CHANGED
@@ -10,218 +10,135 @@ from pathlib import Path
10
  from tempfile import NamedTemporaryFile
11
  from datetime import timedelta
12
 
13
- # Setup logging
14
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
15
  logger = logging.getLogger(__name__)
16
 
17
- # Configuration
18
  MODEL_ID = "KBLab/kb-whisper-large"
19
  CHUNK_DURATION_MS = 10000
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
22
  SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}
23
 
24
# Check for ffmpeg availability
def check_ffmpeg():
    """Return True if the `ffmpeg` binary is runnable from PATH, else False.

    Probes by executing ``ffmpeg -version``; output is captured so nothing
    leaks to the console. Never raises.
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        logger.info("ffmpeg is installed and accessible.")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Non-zero exit or missing binary both mean "not usable".
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False
33
 
34
# Initialize model and pipeline
def initialize_pipeline():
    """Load the KB-Whisper model/processor and wrap them in an ASR pipeline.

    Returns:
        A transformers ``pipeline("automatic-speech-recognition", ...)``
        running on DEVICE with TORCH_DTYPE.

    Raises:
        RuntimeError: if the model or processor cannot be loaded
            (the original exception is logged, not chained).
    """
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            low_cpu_mem_usage=True
        ).to(DEVICE)
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=DEVICE,
            torch_dtype=TORCH_DTYPE,
            # Flash-attention is only requested when CUDA is available.
            model_kwargs={"use_flash_attention_2": torch.cuda.is_available()}
        )
    except Exception as e:
        logger.error(f"Failed to initialize pipeline: {str(e)}")
        raise RuntimeError("Unable to load transcription model. Please check your network connection or model ID.")
55
-
56
# Convert audio if needed
def convert_to_wav(audio_path: str) -> str:
    """Convert the input audio to WAV if needed and return the WAV path.

    Returns *audio_path* unchanged for ``.wav`` input; otherwise writes a
    sibling file with the ``.converted.wav`` suffix and returns its path.

    Raises:
        ValueError: for unsupported formats, decode failures, or OS errors.

    NOTE(review): the trailing ``except Exception`` also catches the
    RuntimeError/ValueError raised earlier in this same try block and
    re-wraps them with "unexpected error" wording — confirm intended.
    """
    try:
        # ffmpeg is required up front even for inputs that need no
        # conversion — pydub shells out to it for .mp3/.m4a decoding.
        if not check_ffmpeg():
            raise RuntimeError("ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")
        ext = str(Path(audio_path).suffix).lower()
        if ext not in SUPPORTED_FORMATS:
            raise ValueError(f"Unsupported audio format: {ext}. Supported formats: {', '.join(SUPPORTED_FORMATS)}")
        if ext != ".wav":
            logger.info(f"Converting {ext} file to WAV: {audio_path}")
            audio = AudioSegment.from_file(audio_path)
            wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
            audio.export(wav_path, format="wav")
            logger.info(f"Conversion successful: {wav_path}")
            return wav_path
        return audio_path
    except CouldntDecodeError:
        logger.error(f"Failed to decode .m4a file: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording and ffmpeg is installed.")
    except OSError as e:
        logger.error(f"OS error during audio conversion: {str(e)}")
        raise ValueError("Failed to process the .m4a file due to a system error. Check file permissions or disk space.")
    except Exception as e:
        logger.error(f"Unexpected error during .m4a conversion: {str(e)}")
        raise ValueError(f"An unexpected error occurred while converting the .m4a file: {str(e)}")
81
-
82
# Split audio into chunks
def split_audio(audio_path: str) -> list:
    """Load *audio_path* and return a list of CHUNK_DURATION_MS-long slices.

    The final slice may be shorter than CHUNK_DURATION_MS.

    Raises:
        ValueError: if the file is empty, undecodable, or loading fails.
    """
    try:
        audio = AudioSegment.from_file(audio_path)
        if len(audio) == 0:
            raise ValueError("The .m4a file is empty or invalid.")
        logger.info(f"Splitting audio into {CHUNK_DURATION_MS/1000}-second chunks: {audio_path}")
        return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]
    except CouldntDecodeError:
        logger.error(f"Failed to decode audio for splitting: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording.")
    except Exception as e:
        # NOTE(review): this also re-wraps the empty-file ValueError above.
        logger.error(f"Failed to split audio: {str(e)}")
        raise ValueError(f"Failed to process the .m4a file: {str(e)}")
96
-
97
# Helper to compute chunk start time
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Return the start offset of chunk *index* formatted as "H:MM:SS[.ffffff]"."""
    offset_ms = index * chunk_duration_ms
    return str(timedelta(milliseconds=offset_ms))
101
-
102
# Transcribe audio with progress and timestamps
def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
    """Generator: transcribe *audio_path* chunk by chunk, yielding partial text.

    Yields ``(partial_transcript, None)`` after each chunk; the final
    ``(full_transcript, download_path)`` pair is produced via ``return``.

    NOTE(review): a ``return value`` inside a generator only sets
    StopIteration.value — confirm the installed Gradio version surfaces it;
    otherwise the final transcript/download pair is never displayed.
    """
    try:
        if not audio_path or not os.path.exists(audio_path):
            logger.warning("Invalid or missing audio file path.")
            return "Please upload a valid .m4a file.", None

        # Convert to WAV if needed
        wav_path = convert_to_wav(audio_path)

        # Split and process
        chunks = split_audio(wav_path)
        total_chunks = len(chunks)
        transcript = []              # plain text pieces, one per chunk
        timestamped_transcript = []  # "[H:MM:SS] text" lines for the download
        failed_chunks = 0

        for i, chunk in enumerate(chunks):
            temp_file_path = None
            try:
                # Export the chunk to a closed temp file, then run ASR on it.
                with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                    temp_file_path = temp_file.name
                    chunk.export(temp_file.name, format="wav")
                result = PIPELINE(temp_file.name,
                                  generate_kwargs={"task": "transcribe", "language": "sv"})
                text = result["text"].strip()
                if text:
                    transcript.append(text)
                    if include_timestamps:
                        timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                        timestamped_transcript.append(f"[{timestamp}] {text}")
            except RuntimeError as e:
                # Inference failure for this chunk only: mark it and move on.
                logger.warning(f"Failed to transcribe chunk {i+1}/{total_chunks}: {str(e)}")
                failed_chunks += 1
                transcript.append("[Transcription failed for this segment]")
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
            except Exception as e:
                logger.error(f"Unexpected error in chunk {i+1}/{total_chunks}: {str(e)}")
                failed_chunks += 1
                transcript.append("[Transcription failed for this segment]")
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
            finally:
                # Always remove the per-chunk temp WAV, success or failure.
                if temp_file_path and os.path.exists(temp_file_path):
                    try:
                        os.remove(temp_file_path)
                    except OSError as e:
                        logger.warning(f"Failed to delete temporary file {temp_file_path}: {str(e)}")

            progress((i + 1) / total_chunks)
            yield " ".join(transcript), None

        # Clean up converted file if created
        if wav_path != audio_path and os.path.exists(wav_path):
            try:
                os.remove(wav_path)
            except OSError as e:
                logger.warning(f"Failed to delete converted WAV file {wav_path}: {str(e)}")

        # Prepare final transcript and downloadable file
        final_transcript = " ".join(transcript)
        if failed_chunks > 0:
            final_transcript = f"Warning: {failed_chunks}/{total_chunks} chunks failed to transcribe.\n{final_transcript}"

        download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript
        download_path = None
        try:
            with NamedTemporaryFile(suffix=".txt", delete=False, mode='w', encoding='utf-8') as temp_file:
                temp_file.write(download_content)
                download_path = temp_file.name
        except OSError as e:
            # Degrade gracefully: keep the on-screen transcript, skip the file.
            logger.error(f"Failed to create downloadable transcript: {str(e)}")
            final_transcript = f"{final_transcript}\nNote: Could not generate downloadable transcript due to a file error."

        return final_transcript, download_path

    except ValueError as e:
        logger.error(f"Value error during transcription: {str(e)}")
        return str(e), None
    except Exception as e:
        logger.error(f"Unexpected error during transcription: {str(e)}")
        return f"An unexpected error occurred while processing the .m4a file: {str(e)}. Please ensure the file is a valid iPhone recording and try again.", None
187
-
188
# Initialize pipeline globally: the model is loaded once at import time
# and shared by every transcription request.
try:
    PIPELINE = initialize_pipeline()
except RuntimeError as e:
    # Without a model the app cannot function; log and propagate.
    logger.critical(f"Pipeline initialization failed: {str(e)}")
    raise
194
-
195
# Gradio Interface with Blocks
def create_interface():
    """Build the two-column Gradio UI and return the Blocks app (not launched)."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload an .m4a file from your iPhone for real-time Swedish speech transcription.")

        with gr.Row():
            with gr.Column():
                # Left column: inputs and the trigger button.
                audio_input = gr.Audio(type="filepath", label="Upload .m4a Audio")
                timestamp_toggle = gr.Checkbox(label="Include Timestamps in Download", value=False)
                transcribe_btn = gr.Button("Transcribe")

            with gr.Column():
                # Right column: streaming text plus downloadable .txt.
                transcript_output = gr.Textbox(label="Live Transcription", lines=10)
                download_output = gr.File(label="Download Transcript")

        # `transcribe` is a generator, so the textbox updates per chunk.
        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output]
        )

    return demo
218
-
219
if __name__ == "__main__":
    try:
        # Refuse to start without ffmpeg: pydub needs it for .m4a decoding.
        if not check_ffmpeg():
            print("Error: ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")
            exit(1)
        create_interface().launch()
    except Exception as e:
        logger.critical(f"Failed to launch Gradio interface: {str(e)}")
        print(f"Error: Could not start the application. Please check the logs for details.")
 
10
  from tempfile import NamedTemporaryFile
11
  from datetime import timedelta
12
 
13
+ # ---------------- LOGGING ----------------
14
+ logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
+ # ---------------- CONFIG ----------------
18
  MODEL_ID = "KBLab/kb-whisper-large"
19
  CHUNK_DURATION_MS = 10000
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
22
  SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}
23
 
24
# ---------------- FFMPEG CHECK ----------------
def check_ffmpeg() -> bool:
    """Return True if the `ffmpeg` binary can be executed, else False.

    Probes by running ``ffmpeg -version`` with output captured. Never raises.
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        return True
    # Narrowed from a bare `except Exception:` — only swallow the failures
    # that actually mean "ffmpeg missing or broken": a non-zero exit status
    # or an OS-level exec error (FileNotFoundError is an OSError subclass).
    except (subprocess.CalledProcessError, OSError):
        return False
31
 
32
# ---------------- LOAD MODEL ----------------
def initialize_pipeline():
    """Load the KB-Whisper model/processor and return an ASR pipeline on DEVICE."""
    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True,
    )
    asr_model = asr_model.to(DEVICE)

    asr_processor = AutoProcessor.from_pretrained(MODEL_ID)

    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=asr_processor.tokenizer,
        feature_extractor=asr_processor.feature_extractor,
        device=DEVICE,
        torch_dtype=TORCH_DTYPE,
    )
    return asr_pipeline

# Loaded once at import time; reused by every transcription request.
PIPELINE = initialize_pipeline()
52
+
53
+ # ---------------- AUDIO UTILS ----------------
 
54
def convert_to_wav(audio_path: str) -> str:
    """Return a path to a WAV version of *audio_path*.

    A ``.wav`` input is returned unchanged. ``.mp3``/``.m4a`` inputs are
    decoded with pydub (which shells out to ffmpeg) and written next to
    the input as ``<name>.wav``.

    Raises:
        ValueError: if the extension is not in SUPPORTED_FORMATS.
        RuntimeError: if a conversion is needed but ffmpeg is unavailable.
    """
    ext = Path(audio_path).suffix.lower()

    # Fix: a .wav file needs no conversion, so return it before demanding
    # ffmpeg — previously a plain .wav upload failed when ffmpeg was absent.
    if ext == ".wav":
        return audio_path

    if ext not in SUPPORTED_FORMATS:
        raise ValueError("Unsupported audio format")

    # ffmpeg is only required on the actual conversion path.
    if not check_ffmpeg():
        raise RuntimeError("ffmpeg not available")

    audio = AudioSegment.from_file(audio_path)
    # NOTE(review): this silently overwrites any pre-existing "<name>.wav"
    # sibling of the input — confirm that is acceptable.
    wav_path = str(Path(audio_path).with_suffix(".wav"))
    audio.export(wav_path, format="wav")
    return wav_path
69
+
70
def split_audio(audio_path: str):
    """Load *audio_path* and return its CHUNK_DURATION_MS-long slices in order."""
    track = AudioSegment.from_file(audio_path)
    step = CHUNK_DURATION_MS
    return [track[start:start + step] for start in range(0, len(track), step)]
73
+
74
+ def get_chunk_time(index: int) -> str:
75
+ return str(timedelta(milliseconds=index * CHUNK_DURATION_MS))
76
+
77
# ---------------- TRANSCRIBE ----------------
def transcribe(audio_path: str, include_timestamps: bool, progress=gr.Progress()):
    """Stream a Swedish transcription of *audio_path* chunk by chunk.

    Yields ``(partial_transcript, None)`` after each chunk and finally
    ``(full_transcript, download_path)``, where *download_path* is a UTF-8
    ``.txt`` file (one "[H:MM:SS] text" line per chunk when
    *include_timestamps* is set, otherwise the plain transcript).
    """
    if not audio_path or not os.path.exists(audio_path):
        yield "Please upload an audio file.", None
        return

    wav_path = convert_to_wav(audio_path)
    chunks = split_audio(wav_path)

    transcript = []
    timestamped = []

    try:
        for i, chunk in enumerate(chunks):
            # Export to a closed temp file so the ASR pipeline can reopen
            # it (required on Windows, harmless elsewhere).
            with NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                chunk.export(tmp.name, format="wav")

            # Fix: try/finally guarantees the per-chunk temp WAV is removed
            # even when inference raises (previously it leaked on error).
            try:
                result = PIPELINE(
                    tmp.name,
                    generate_kwargs={"task": "transcribe", "language": "sv"}
                )
            finally:
                os.remove(tmp.name)

            text = result["text"].strip()
            if text:
                transcript.append(text)
                if include_timestamps:
                    ts = get_chunk_time(i)
                    timestamped.append(f"[{ts}] {text}")

            progress((i + 1) / len(chunks))
            yield " ".join(transcript), None
    finally:
        # Fix: delete the intermediate WAV written by convert_to_wav
        # (previously it accumulated on disk after every conversion).
        if wav_path != audio_path and os.path.exists(wav_path):
            try:
                os.remove(wav_path)
            except OSError as e:
                logger.warning(f"Failed to delete converted WAV file {wav_path}: {e}")

    content = "\n".join(timestamped) if include_timestamps else " ".join(transcript)

    with NamedTemporaryFile(
        suffix=".txt",
        delete=False,
        mode="w",
        encoding="utf-8"
    ) as f:
        f.write(content)
        download_path = f.name

    yield " ".join(transcript), download_path
122
+
123
# ---------------- UI ----------------
def create_interface():
    """Build the two-column Gradio Blocks UI and return it (not launched)."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload an .m4a file and download the transcript with timestamps.")

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(type="filepath", label="Upload Audio (.m4a)")
                timestamp_toggle = gr.Checkbox(label="Include timestamps in download")
                transcribe_btn = gr.Button("Transcribe")

            with gr.Column():
                transcript_output = gr.Textbox(label="Live Transcription", lines=12)
                download_output = gr.File(label="Download Transcript")

        # `transcribe` is a generator, so the textbox streams per chunk.
        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output]
        )

    return demo

# Keep `demo` at module level (hosting platforms can pick it up on import).
demo = create_interface()

# Fix: only auto-launch when run as a script, so importing this module
# (tests, a server wrapper) no longer starts the web app as a side effect.
if __name__ == "__main__":
    demo.launch()