Spaces:
Runtime error
Runtime error
# Runtime dependencies. These were IPython `!pip install` cells, which are a
# SyntaxError when this file runs as a plain Python script (e.g. as a Space's
# app.py). Install them beforehand instead: list them in requirements.txt, or
# run `pip install -U <package>` (prefixed with `!` in a notebook cell).
#   git+https://github.com/openai/whisper.git
#   gradio
#   torch
#   numpy
#   ffmpeg-python
#   tqdm
#   jiwer
import asyncio
import concurrent.futures
import logging
import os
import re
import shutil
import tempfile
import threading
from pathlib import Path

import ffmpeg
import gradio as gr
import numpy as np
import torch
import whisper
from huggingface_hub import InferenceClient
from tqdm.notebook import tqdm
# --- File Handling ---
# Paths and constants shared by the conversion and transcription helpers.

# Scratch directory for converted WAVs and per-chunk audio. The Colab location
# is kept when '/content' exists; elsewhere (e.g. a hosted Space, where
# creating '/content' may not be permitted) the system temp directory is used.
# Both variants end with a path separator.
TEMP_FOLDER = (
    '/content/temp/'
    if os.path.isdir('/content')
    else os.path.join(tempfile.gettempdir(), 'whisper_temp', '')
)
SUPPORTED_AUDIO_FORMATS = ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.amr', '.wma']
SUPPORTED_VIDEO_FORMATS = ['.mp4', '.avi', '.mov', '.wmv', '.mkv', '.webm', '.3gp']
# Lower-case extensions accepted by is_supported_format (audio first, then video).
SUPPORTED_FORMATS = SUPPORTED_AUDIO_FORMATS + SUPPORTED_VIDEO_FORMATS
def create_folders():
    """Make sure TEMP_FOLDER exists, creating missing parent directories too.

    Calling it again when the directory is already present is a no-op.
    """
    temp_dir = Path(TEMP_FOLDER)
    temp_dir.mkdir(exist_ok=True, parents=True)
def is_supported_format(file):
    """Tell whether the path ``file`` ends with one of SUPPORTED_FORMATS.

    The extension check is case-insensitive. ``None`` counts as unsupported.
    """
    if file is None:
        return False
    lowered = file.lower()
    return any(lowered.endswith(suffix) for suffix in SUPPORTED_FORMATS)
def convert_to_wav(original_file_path):
    """Convert an audio/video file to mono 16 kHz 16-bit PCM WAV in TEMP_FOLDER.

    The output is named after the input's base name with a ``.wav``
    extension; an existing file of that name is overwritten.

    Args:
        original_file_path: Path of the file to convert.

    Returns:
        The path of the written WAV file, or ``None`` when ffmpeg reports an
        error (the error is logged).
    """
    stem = os.path.splitext(os.path.basename(original_file_path))[0]
    output_path = os.path.join(TEMP_FOLDER, stem + '.wav')
    try:
        (
            ffmpeg
            .input(original_file_path)
            .output(output_path, acodec='pcm_s16le', ac=1, ar='16k')
            .overwrite_output()
            .run(capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        # ffmpeg's stderr is raw bytes that need not be valid UTF-8; a strict
        # decode could raise inside this handler and hide the real error.
        stderr = e.stderr.decode(errors='replace') if e.stderr else ''
        logging.error(f'Error converting {original_file_path}: {stderr}')
        return None
    return output_path
def delete_temp_file(file_path):
    """Remove ``file_path`` from disk; do nothing if it does not exist."""
    if not os.path.exists(file_path):
        return
    os.remove(file_path)
| # --- Transcription --- | |
class WhisperModelCache:
    """Shared, lazily loaded Whisper model.

    Get the shared object with ``WhisperModelCache.get_instance()`` and the
    model with ``load_model()``. The model is loaded on the first call and
    reused until ``unload_model()`` is called. There is no locking, so
    concurrent first calls are not guarded against.
    """

    # Shared instance handed out by get_instance().
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use.

        As a classmethod, this works when called on the class
        (``WhisperModelCache.get_instance()``) and also on an instance.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.model = None   # Whisper model; set by load_model(), cleared by unload_model()
        self.device = None  # torch.device chosen by load_model()

    def load_model(self):
        """Load the Whisper model on first use and return the cached model.

        On CUDA, "large-v2" is loaded; on CPU, "medium". If "large-v2" fails
        with an out-of-memory RuntimeError, the CUDA cache is cleared and
        "medium" is loaded instead. Any other RuntimeError, or an OOM while
        loading "medium", propagates.

        Returns:
            The loaded model. The same object is returned on every call
            until ``unload_model`` is called.
        """
        if self.model is not None:
            return self.model

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")
        model_size = "large-v2" if self.device.type == "cuda" else "medium"
        logging.info(f"Loading Whisper model: {model_size}")
        try:
            self.model = whisper.load_model(model_size, device=self.device)
        except RuntimeError as e:
            # Retrying only helps when a smaller model is available.
            if "out of memory" not in str(e) or model_size == "medium":
                raise
            logging.error(f"Error: {e}")
            logging.warning("Falling back to 'medium' model size due to memory constraints.")

        if self.model is None:
            # The retry runs after the except block has exited, so the failed
            # attempt's traceback is already released before memory is reclaimed.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            self.model = whisper.load_model("medium", device=self.device)
        return self.model

    def unload_model(self):
        """Drop the cached model and, when CUDA is available, empty its cache."""
        if self.model is not None:
            self.model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            logging.info("Model unloaded and CUDA cache cleared.")
async def transcribe_audio(audio_path, language, progress_bar,
                           task='transcribe', initial_prompt=None,
                           temperature=0.5, chunk_duration=30):
    """Transcribe ``audio_path`` with Whisper in fixed-length chunks.

    ffmpeg cuts the audio into consecutive ``chunk_duration``-second WAV
    pieces. Each piece is transcribed in a worker thread
    (``asyncio.to_thread``) so the event loop stays responsive, and the
    texts are concatenated in order.

    Args:
        audio_path: Path of the audio file to transcribe.
        language: Whisper language code. ``None``, ``""`` or ``"auto"``
            leave language detection to Whisper.
        progress_bar: tqdm-style object. It is advanced by 20 units in
            total, split evenly across the chunks.
        task: ``'transcribe'`` or ``'translate'``, passed to Whisper.
        initial_prompt: Optional prompt passed to Whisper for every chunk.
        temperature: Sampling temperature passed to Whisper.
        chunk_duration: Chunk length in seconds.

    Returns:
        The transcription text. Failures are not raised; an error message
        starting with ``"Error"`` is returned instead.
    """
    try:
        cache = WhisperModelCache.get_instance()
        # Loading may download weights and take a long time; run it off the
        # event loop, like the transcription calls below.
        model = await asyncio.to_thread(cache.load_model)

        # "auto" is not a Whisper language code; language=None asks Whisper
        # to detect the language itself.
        if not language or language == "auto":
            language = None

        probe = ffmpeg.probe(audio_path)
        total_duration = float(probe['format']['duration'])
        num_chunks = int(total_duration // chunk_duration) + (total_duration % chunk_duration > 0)
        if num_chunks <= 0:
            # Empty audio: nothing to transcribe, but keep the bar's share consistent.
            progress_bar.update(20)
            return ""
        progress_per_chunk = 20 / num_chunks

        texts = []
        for chunk_idx in range(num_chunks):
            start_time = chunk_idx * chunk_duration
            duration = min(chunk_duration, total_duration - start_time)
            temp_chunk_path = os.path.join(TEMP_FOLDER, f"temp_chunk_{chunk_idx}.wav")
            try:
                try:
                    (
                        ffmpeg
                        # Seeking on the input (-ss/-t before -i) jumps straight
                        # to this chunk; an atrim filter has to read the file
                        # from the beginning for every chunk.
                        .input(audio_path, ss=start_time, t=duration)
                        .output(temp_chunk_path, acodec='pcm_s16le', ac=1, ar='16k')
                        .overwrite_output()
                        .run(capture_stdout=True, capture_stderr=True)
                    )
                except ffmpeg.Error as e:
                    stderr = e.stderr.decode(errors='replace') if e.stderr else ''
                    logging.error(f"Error extracting audio chunk: {stderr}")
                    return "Error: Could not extract audio chunk for transcription"
                result = await asyncio.to_thread(
                    model.transcribe, temp_chunk_path,
                    language=language,
                    task=task,
                    initial_prompt=initial_prompt,
                    temperature=temperature,
                )
            finally:
                # Remove the chunk file even when extraction or transcription fails.
                delete_temp_file(temp_chunk_path)
            texts.append(result['text'])
            progress_bar.update(progress_per_chunk)
        return "".join(texts)
    except Exception as e:
        logging.error(f"Error transcribing {audio_path}: {str(e)}")
        return f"Error during transcription: {str(e)}"
| # --- Anonymization --- | |
# One pass over the text; the alternative that matched selects the placeholder.
_PII_PATTERN = re.compile(
    r'(?P<name>\b[A-Z][a-z]+ [A-Z][a-z]+\b)'
    r'|(?P<email>\S+@\S+)'
    r'|(?P<phone>\d{3}[-.]?\d{3}[-.]?\d{4})'
)
_PII_PLACEHOLDERS = {'name': '[NAME]', 'email': '[EMAIL]', 'phone': '[PHONE]'}


def anonymize_text(text):
    """Replace likely personal data in ``text`` with placeholders.

    - Two consecutive capitalized words become ``[NAME]``.
    - A non-whitespace run with at least one character on each side of an
      ``@`` becomes ``[EMAIL]``.
    - Ten digits, optionally grouped 3-3-4 by ``-`` or ``.``, become
      ``[PHONE]``.

    These are plain regex heuristics, so false positives are possible (for
    example a sentence starting with two capitalized words), as are misses.
    """
    return _PII_PATTERN.sub(lambda match: _PII_PLACEHOLDERS[match.lastgroup], text)
| # --- Gradio UI --- | |
async def process_audio(file, language, anonymize):
    """Gradio handler: validate, convert, transcribe and optionally anonymize a file.

    Args:
        file: Path of the uploaded file (``gr.Audio(type="filepath")``), or
            ``None`` when nothing was uploaded.
        language: Language code forwarded to ``transcribe_audio``.
        anonymize: When truthy, the text is passed through ``anonymize_text``.

    Returns:
        The transcription text, or a message starting with ``"Error"``.
        Ordinary exceptions are turned into that message rather than raised.
    """
    if file is None:
        return "Error: Please upload an audio or video file."

    progress_bar = None
    temp_audio_path = None
    try:
        if not is_supported_format(file):
            raise ValueError(f"Unsupported file format: {file}")

        # The module-level tqdm.notebook bar raises when ipywidgets/IPython are
        # unavailable, which is common outside Jupyter. tqdm.auto uses the
        # notebook bar inside Jupyter and the console bar elsewhere.
        from tqdm.auto import tqdm

        progress_bar = tqdm(total=100, desc="Overall Process", unit="%", position=0, leave=True)
        progress_bar.update(10)

        temp_audio_path = convert_to_wav(file)
        if not temp_audio_path:
            raise ValueError(f"Failed to convert {file} to WAV format.")
        progress_bar.update(30)

        # transcribe_audio advances the bar by its own share while it runs.
        transcription = await transcribe_audio(temp_audio_path, language, progress_bar)
        progress_bar.update(20)

        if anonymize:
            transcription = anonymize_text(transcription)

        # Finish at exactly 100%, whichever optional steps ran; fixed
        # increments here would overshoot.
        progress_bar.update(max(0, progress_bar.total - progress_bar.n))
        return transcription
    except Exception as e:
        logging.error(f"Error processing audio: {e}")
        return f"Error: {str(e)}"
    finally:
        # Clean up on every path, including failures.
        if temp_audio_path:
            try:
                delete_temp_file(temp_audio_path)
            except OSError as e:
                logging.warning(f"Could not delete {temp_audio_path}: {e}")
        if progress_bar is not None:
            progress_bar.close()
def create_ui():
    """Build and return the Gradio Blocks interface (not launched).

    The left column holds the upload and option controls; the right column
    holds the transcription output. The Transcribe button calls
    ``process_audio`` with the uploaded file path, the selected language and
    the anonymize flag.
    """
    # The selected value ends up as Whisper's ``language`` argument, so every
    # entry must be a code Whisper accepts, listed once. Hebrew is "he"; the
    # legacy "iw" is not a Whisper code. "auto" stands for automatic
    # detection (see the dropdown's info text).
    language_choices = ["en", "es", "fr", "de", "it", "pt", "nl", "ru", "zh", "ja", "ko", "ar", "he", "auto"]
    output_format_choices = ["txt", "srt", "vtt", "tsv", "json"]

    with gr.Blocks() as interface:
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(label="Upload Audio/Video", type="filepath")
                # NOTE(review): the task, output-format, initial-prompt,
                # temperature and word-timestamp controls below are displayed
                # but not passed to process_audio, so they have no effect yet.
                task_dropdown = gr.Dropdown(
                    choices=["Transcribe", "Translate"],
                    label="Task",
                    value="Transcribe",
                )
                language_dropdown = gr.Dropdown(
                    choices=language_choices,
                    label="Language",
                    value="en",  # Default to English
                    info="Select 'auto' for automatic language detection.",
                )
                output_format_checkbox_group = gr.CheckboxGroup(
                    choices=output_format_choices,
                    label="Output Formats",
                    value=["txt"],
                )
                anonymize_checkbox = gr.Checkbox(label="Anonymize Transcription")
                prompt_input = gr.Textbox(
                    label="Initial Prompt",
                    lines=2,
                    placeholder="Optional prompt to guide transcription",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    label="Temperature",
                )
                timestamps_checkbox = gr.Checkbox(label="Include Word Timestamps")
                transcribe_button = gr.Button(value="Transcribe")
            with gr.Column():
                transcription_output = gr.Textbox(label="Transcription", lines=10)

        transcribe_button.click(
            fn=process_audio,
            inputs=[audio_input, language_dropdown, anonymize_checkbox],
            outputs=transcription_output,
        )
    return interface
| # --- Main Execution --- | |
if __name__ == "__main__":
    # Configure logging first so later logging.info/error calls are formatted.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    create_folders()  # the scratch directory must exist before conversions write to it
    iface = create_ui()
    iface.launch(share=True, debug=True)