Spaces:
Running
Running
| import gradio as gr | |
| import whisperx | |
| import torch | |
| import librosa | |
| import logging | |
| import os | |
| import time | |
| import numpy as np | |
| import requests | |
| import random | |
| import string | |
| import json | |
| import pathlib | |
| import tempfile | |
| # ------------------------------- | |
| # Vocal Extraction Function | |
| # ------------------------------- | |
def get_vocals(input_file):
    """Send an audio file to the remote separator Space and return the result URL.

    The file is uploaded to the ``politrees-audio-separator-uvr`` Gradio
    Space, a job is queued with a fixed set of positional inputs, and the
    Space's event stream is read until a ``process_completed`` message
    arrives.

    Args:
        input_file: Path of the local audio file to upload.

    Returns:
        The ``url`` found at ``output.data[1]`` of the completion message, or
        ``None`` when the file cannot be read, the upload or queue request
        fails, or no completion message is received within ``max_retries``
        connection attempts.
    """
    base_url = 'https://politrees-audio-separator-uvr.hf.space'
    request_timeout = 60  # seconds, for the upload and queue-join requests
    max_retries = 5
    retry_delay = 5
    sse_prefix = 'data: '
    try:
        alphabet = string.ascii_lowercase + string.digits
        session_hash = ''.join(random.choice(alphabet) for _ in range(11))
        file_id = ''.join(random.choice(alphabet) for _ in range(11))

        # Read the file once and upload those bytes; this avoids leaving an
        # open file handle behind and reading the file from disk twice.
        source_path = pathlib.Path(input_file)
        file_content = source_path.read_bytes()
        file_len = len(file_content)

        upload = requests.post(
            f'{base_url}/gradio_api/upload?upload_id={file_id}',
            files={'files': (source_path.name, file_content)},
            timeout=request_timeout,
        )
        upload.raise_for_status()
        json_data = upload.json()
        uploaded_path = json_data[0]

        headers = {
            'accept': '*/*',
            'accept-language': 'en-US,en;q=0.5',
            'content-type': 'application/json',
            'origin': 'https://politrees-audio-separator-uvr.hf.space',
            'priority': 'u=1, i',
            'referer': 'https://politrees-audio-separator-uvr.hf.space/?__theme=system',
            'sec-ch-ua': '"Not(A:Brand";v="99", "Brave";v="133", "Chromium";v="133"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'sec-fetch-storage-access': 'none',
            'sec-gpc': '1',
            'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36',
        }
        params = {
            '__theme': 'system',
        }
        # NOTE(review): everything after the file descriptor is a positional
        # input of the remote function (fn_index 5); the meaning of each value
        # is defined by the remote Space and is not visible here.
        json_payload = {
            'data': [
                {
                    'path': uploaded_path,
                    'url': 'https://politrees-audio-separator-uvr.hf.space/gradio_api/file=' + uploaded_path,
                    'orig_name': source_path.name,
                    'size': file_len,
                    'mime_type': 'audio/wav',
                    'meta': {
                        '_type': 'gradio.FileData',
                    },
                },
                'MelBand Roformer | Vocals by Kimberley Jensen',
                256,
                False,
                5,
                0,
                '/tmp/audio-separator-models/',
                'output',
                'wav',
                0.9,
                0,
                1,
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
                'NAME_(STEM)_MODEL',
            ],
            'event_data': None,
            'fn_index': 5,
            'trigger_id': 28,
            'session_hash': session_hash,
        }
        join_response = requests.post(
            f'{base_url}/gradio_api/queue/join',
            params=params,
            headers=headers,
            json=json_payload,
            timeout=request_timeout,
        )
        # If the job was never queued there is nothing to wait for.
        join_response.raise_for_status()

        stream_url = f'{base_url}/gradio_api/queue/data?session_hash={session_hash}'
        for attempt in range(1, max_retries + 1):
            try:
                print(f"Connecting to stream... Attempt {attempt}")
                # Only the connection phase is bounded; the read side is left
                # unbounded because the job may legitimately take a long time.
                with requests.get(
                    stream_url,
                    stream=True,
                    timeout=(request_timeout, None),
                ) as stream:
                    if stream.status_code != 200:
                        raise RuntimeError(f"Failed to connect: HTTP {stream.status_code}")
                    print("Connected successfully.")
                    for line in stream.iter_lines():
                        if not line:
                            continue
                        text = line.decode('utf-8')
                        # Strip only the leading event-stream prefix; a global
                        # replace would also alter the JSON payload itself.
                        if text.startswith(sse_prefix):
                            text = text[len(sse_prefix):]
                        json_resp = json.loads(text)
                        print(json_resp)
                        if 'process_completed' in json_resp['msg']:
                            print("Process completed.")
                            output_url = json_resp['output']['data'][1]['url']
                            print(f"Output URL: {output_url}")
                            return output_url
                print("Stream ended prematurely. Reconnecting...")
            except Exception as e:
                # Best-effort polling: any failure counts as one used attempt.
                print(f"Error occurred: {e}. Retrying...")
            if attempt < max_retries:
                time.sleep(retry_delay)
        print("Max retries reached. Exiting.")
        return None
    except Exception as ex:
        print(f"Unexpected error in get_vocals: {ex}")
        return None
| # ------------------------------- | |
| # Logging and Model Setup | |
| # ------------------------------- | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whisperx_app")

# Inference settings shared by every model loaded below.
device = "cpu"
compute_type = "int8"

# os.cpu_count() may return None when the count cannot be determined;
# fall back to one thread instead of handing None to torch.
torch.set_num_threads(os.cpu_count() or 1)

# All selectable models are loaded up front when the module is imported.
# The key order is preserved and is what the UI dropdown lists.
models = {
    name: whisperx.load_model(name, device, compute_type=compute_type, vad_method='silero')
    for name in ("tiny", "base", "small", "large", "large-v2", "large-v3")
}
def split_audio_by_pause(audio, sr, pause_threshold, top_db=30, energy_threshold=0.03):
    """Split audio into segments separated by sufficiently long pauses.

    Intervals returned by ``librosa.effects.split`` are merged whenever the
    gap between two consecutive intervals is shorter than ``pause_threshold``
    seconds. Merged segments whose mean RMS energy is below
    ``energy_threshold`` are then discarded.

    Args:
        audio: Audio samples; sliced by sample index and passed to librosa.
        sr: Sample rate, used to convert sample gaps to seconds.
        pause_threshold: Minimum gap, in seconds, that keeps two intervals
            apart.
        top_db: Forwarded to ``librosa.effects.split``.
        energy_threshold: Minimum mean RMS a segment needs to be kept.

    Returns:
        A list of ``(start, end)`` sample-index pairs. The list is empty when
        librosa finds no interval at all or every segment is filtered out.
    """
    intervals = librosa.effects.split(audio, top_db=top_db)
    # Fully silent (or empty) input yields no intervals; indexing the first
    # one would raise IndexError.
    if len(intervals) == 0:
        return []

    merged_intervals = []
    current_start, current_end = intervals[0]
    for start, end in intervals[1:]:
        gap_duration = (start - current_end) / sr
        if gap_duration < pause_threshold:
            # Pause too short to split on: extend the running segment.
            current_end = end
        else:
            merged_intervals.append((current_start, current_end))
            current_start, current_end = start, end
    merged_intervals.append((current_start, current_end))

    # Filter out segments with low average RMS energy.
    return [
        (start, end)
        for start, end in merged_intervals
        if np.mean(librosa.feature.rms(y=audio[start:end])) >= energy_threshold
    ]
def seconds_to_srt_time(seconds):
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond before being broken down
    into hours, minutes, seconds and milliseconds.
    """
    total_ms = int(round(seconds * 1000))
    # Peel units off from the smallest upwards.
    whole_seconds, millis = divmod(total_ms, 1000)
    whole_minutes, secs = divmod(whole_seconds, 60)
    hrs, mins = divmod(whole_minutes, 60)
    return "{:02d}:{:02d}:{:02d},{:03d}".format(hrs, mins, secs, millis)
| # ------------------------------- | |
| # Main Transcription Function | |
| # ------------------------------- | |
def _fetch_vocals_to_tempfile(audio_file, debug_log):
    """Run remote vocal extraction and return the downloaded file path, or None on failure."""
    debug_log.append("Vocal extraction enabled; processing input file for vocals...")
    extracted_url = get_vocals(audio_file)
    if extracted_url is None:
        debug_log.append("Vocal extraction failed; proceeding with original audio.")
        return None

    debug_log.append("Vocal extraction succeeded; downloading extracted audio...")
    try:
        response = requests.get(extracted_url, timeout=120)
    except requests.RequestException:
        # Extraction is optional: a network error falls back to the original.
        debug_log.append("Failed to download extracted audio; proceeding with original file.")
        return None
    if response.status_code != 200:
        debug_log.append("Failed to download extracted audio; proceeding with original file.")
        return None

    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
        tmp.write(response.content)
    debug_log.append("Extracted audio downloaded and saved for transcription.")
    return tmp.name


def _collect_word_entries(aligned_segments, offset=0.0):
    """Flatten aligned segments into word entries, shifting times by *offset* seconds."""
    entries = []
    for segment in aligned_segments:
        for word in segment.get("words", []):
            # A word may come back from alignment without timestamps; use the
            # enclosing segment's bounds so the word is still emitted.
            start = word.get("start", segment.get("start"))
            end = word.get("end", segment.get("end"))
            if start is None or end is None:
                continue
            entries.append({
                'start': start + offset,
                'end': end + offset,
                'word': word['word'].strip()
            })
    return entries


def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0, vocal_extraction=False, language="en"):
    """Transcribe an audio file into word-level SRT text.

    Args:
        audio_file: Path of the audio file to transcribe.
        model_size: Key into the module-level ``models`` dict.
        debug: When true, the second return value carries the debug log.
        pause_threshold: If greater than 0, the audio is split with
            ``split_audio_by_pause`` and each segment is transcribed and
            aligned on its own; otherwise the whole file is processed at once.
        vocal_extraction: When true, ``get_vocals`` is tried first and its
            result replaces the input; on failure the original file is used.
        language: Language code. When empty, the language reported by the
            model's own transcription is used.

    Returns:
        A ``(text, debug_text)`` tuple. ``text`` is the SRT content with one
        cue per word, or ``"Error occurred during transcription"`` if any
        step raised. ``debug_text`` is the newline-joined debug log when
        ``debug`` is true and ``""`` otherwise.
    """
    start_time = time.time()
    final_result = ""
    debug_log = []
    srt_entries = []
    temp_audio_path = None
    try:
        if not audio_file:
            raise ValueError("No audio file provided")

        # If vocal extraction is enabled, process the file first.
        if vocal_extraction:
            temp_audio_path = _fetch_vocals_to_tempfile(audio_file, debug_log)
            if temp_audio_path is not None:
                audio_file = temp_audio_path

        # Load audio file at 16kHz.
        audio, sr = librosa.load(audio_file, sr=16000)
        debug_log.append(f"Audio loaded: {len(audio)/sr:.2f} seconds long at {sr} Hz")

        # Select the model and set batch size.
        model = models[model_size]
        batch_size = 8 if model_size == "tiny" else 4

        # Only run a whole-file pass up front when the language must be
        # detected; its transcript is reused below instead of repeating it.
        language = (language or "").strip()
        transcript = None
        if not language:
            transcript = model.transcribe(audio, batch_size=batch_size)
            language = transcript.get("language", "unknown")

        # Load alignment model using the resolved language.
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)

        if pause_threshold > 0:
            # Split on pauses and process each segment individually.
            segments = split_audio_by_pause(audio, sr, pause_threshold)
            debug_log.append(f"Audio split into {len(segments)} segment(s) using a pause threshold of {pause_threshold}s")
            for seg_idx, (seg_start, seg_end) in enumerate(segments):
                audio_segment = audio[seg_start:seg_end]
                seg_offset = seg_start / sr
                seg_duration = (seg_end - seg_start) / sr
                debug_log.append(f"Segment {seg_idx+1}: start={seg_offset:.2f}s, duration={seg_duration:.2f}s")
                seg_transcript = model.transcribe(audio_segment, batch_size=batch_size, language=language)
                seg_aligned = whisperx.align(
                    seg_transcript["segments"], model_a, metadata, audio_segment, device
                )
                # Word times are relative to the segment; shift to file time.
                srt_entries.extend(_collect_word_entries(seg_aligned["segments"], seg_offset))
        else:
            # Process the entire audio without splitting.
            if transcript is None:
                transcript = model.transcribe(audio, batch_size=batch_size, language=language)
            aligned = whisperx.align(
                transcript["segments"], model_a, metadata, audio, device
            )
            srt_entries.extend(_collect_word_entries(aligned["segments"]))

        srt_content = [
            f"{idx}\n"
            f"{seconds_to_srt_time(entry['start'])} --> {seconds_to_srt_time(entry['end'])}\n"
            f"{entry['word']}\n"
            for idx, entry in enumerate(srt_entries, start=1)
        ]
        final_result = "\n".join(srt_content)
        debug_log.append(f"Language used: {language}")
        debug_log.append(f"Batch size: {batch_size}")
        debug_log.append(f"Processed in {time.time()-start_time:.2f}s")
    except Exception as e:
        # Top-level boundary for the UI: log the traceback, return a message.
        logger.error("Error during transcription:", exc_info=True)
        final_result = "Error occurred during transcription"
        debug_log.append(f"ERROR: {str(e)}")
    finally:
        # The extracted audio is written with delete=False, so remove it here.
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
            except OSError:
                logger.warning("Could not remove temporary file %s", temp_audio_path)

    if debug:
        return final_result, "\n".join(debug_log)
    return final_result, ""
| # ------------------------------- | |
| # Gradio Interface | |
| # ------------------------------- | |
with gr.Blocks(title="WhisperX CPU Transcription") as demo:
    gr.Markdown("# WhisperX CPU Transcription with Vocal Extraction Option")

    with gr.Row():
        # Left column: input file and transcription options.
        with gr.Column():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Upload Audio File",
                interactive=True,
            )
            model_selector = gr.Dropdown(
                label="Model Size",
                choices=list(models.keys()),
                value="base",
                interactive=True,
            )
            pause_threshold_slider = gr.Slider(
                label="Pause Threshold (seconds)",
                info="Set a pause duration threshold. Audio pauses longer than this will be used to split the audio into segments.",
                minimum=0,
                maximum=5,
                step=0.1,
                value=0,
                interactive=True,
            )
            vocal_extraction_checkbox = gr.Checkbox(
                value=False,
                label="Extract Vocals (improves accuracy on noisy audio)",
            )
            language_input = gr.Textbox(
                value="en",
                label="Language Code (e.g., en, es, fr)",
                placeholder="Enter language code",
            )
            debug_checkbox = gr.Checkbox(value=False, label="Enable Debug Mode")
            transcribe_btn = gr.Button("Transcribe", variant="primary")

        # Right column: transcription result and the optional debug log.
        with gr.Column():
            output_text = gr.Textbox(
                lines=20,
                label="Transcription Output",
                placeholder="Transcription will appear here...",
            )
            debug_output = gr.Textbox(
                lines=10,
                label="Debug Information",
                placeholder="Debug logs will appear here...",
                visible=False,
            )

    def toggle_debug(debug_enabled):
        """Show the debug textbox only while the debug checkbox is ticked."""
        return gr.update(visible=debug_enabled)

    debug_checkbox.change(
        toggle_debug,
        inputs=[debug_checkbox],
        outputs=[debug_output],
    )

    # Input order must match the positional parameters of transcribe().
    transcribe_btn.click(
        transcribe,
        inputs=[
            audio_input,
            model_selector,
            debug_checkbox,
            pause_threshold_slider,
            vocal_extraction_checkbox,
            language_input,
        ],
        outputs=[output_text, debug_output],
    )
| # ------------------------------- | |
| # Launch the App | |
| # ------------------------------- | |
if __name__ == "__main__":
    # Enable the request queue (max 4 pending requests), then start the server.
    queued_demo = demo.queue(max_size=4)
    queued_demo.launch()