Spaces:
Build error
Build error
import json
import logging
import os
import pathlib
import random
import string
import tempfile
import threading
import time

import torch
import whisperx
import librosa
import numpy as np
import requests
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
app = FastAPI(title="WhisperX API")

# -------------------------------
# Logging and Model Setup
# -------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("whisperx_api")

# All inference runs on CPU with int8 compute.
device = "cpu"
compute_type = "int8"

# os.cpu_count() returns None when the CPU count cannot be determined, and
# torch.set_num_threads() requires an int, so fall back to one thread.
torch.set_num_threads(os.cpu_count() or 1)

# Model sizes served by the API, in load order.
_MODEL_SIZES = ("tiny", "base", "small", "large", "large-v2", "large-v3")

# Pre-load models for different sizes (eagerly, at import time).
# NOTE(review): this keeps six checkpoints (three "large" variants) resident
# in memory; depending on the installed faster-whisper version "large" may
# resolve to the same weights as "large-v3". Confirm whether all are needed.
models = {
    size: whisperx.load_model(
        size, device, compute_type=compute_type, vad_method="silero"
    )
    for size in _MODEL_SIZES
}
def seconds_to_srt_time(seconds: float) -> str:
    """Convert a time offset in seconds into an SRT timestamp (HH:MM:SS,mmm).

    The value is rounded to the nearest millisecond *before* it is split into
    fields. Truncating the fractional part separately (``int((s - int(s)) *
    1000)``) loses a millisecond to floating-point error, e.g. 1.001 s would
    render as ``00:00:01,000``. Rounding the total also guarantees that the
    millisecond field can never reach 1000.

    Negative inputs are clamped to zero, since SRT has no negative times.

    Args:
        seconds: Offset from the start of the media, in seconds.

    Returns:
        The formatted timestamp, e.g. ``"01:01:01,500"`` for 3661.5.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"
| # ------------------------------- | |
| # Vocal Extraction Function | |
| # ------------------------------- | |
# Base URL of the public Gradio Space that performs the stem separation.
_SEPARATOR_BASE_URL = "https://politrees-audio-separator-uvr.hf.space"

# (connect, read) timeouts in seconds for the upload and queue-join requests.
_SEPARATOR_REQUEST_TIMEOUT = (10, 300)

# The result stream only gets a connect timeout: it stays open for the whole
# separation job, and a read timeout could abort a job that is still running.
_SEPARATOR_STREAM_TIMEOUT = (10, None)

# Browser-like headers sent with the queue-join request.
_SEPARATOR_HEADERS = {
    "accept": "*/*",
    "accept-language": "en-US,en;q=0.5",
    "content-type": "application/json",
    "origin": _SEPARATOR_BASE_URL,
    "priority": "u=1, i",
    "referer": f"{_SEPARATOR_BASE_URL}/?__theme=system",
    "sec-ch-ua": '"Not(A:Brand";v="99", "Brave";v="133", "Chromium";v="133"',
    "sec-ch-ua-mobile": "?0",
    "sec-ch-ua-platform": '"Windows"',
    "sec-fetch-dest": "empty",
    "sec-fetch-mode": "cors",
    "sec-fetch-site": "same-origin",
    "sec-fetch-storage-access": "none",
    "sec-gpc": "1",
    "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
}


def _random_token(length=11):
    """Return a random string of lowercase letters and digits."""
    alphabet = string.ascii_lowercase + string.digits
    return "".join(random.choice(alphabet) for _ in range(length))


def _upload_to_separator(input_file, file_id):
    """Upload *input_file* to the Space and return the server-side path it reports."""
    # The context manager closes the handle once the upload finishes.
    with open(input_file, "rb") as fh:
        response = requests.post(
            f"{_SEPARATOR_BASE_URL}/gradio_api/upload?upload_id={file_id}",
            files={"files": fh},
            timeout=_SEPARATOR_REQUEST_TIMEOUT,
        )
    response.raise_for_status()
    return response.json()[0]


def _join_separator_queue(server_path, input_file, session_hash):
    """Submit the separation job for an uploaded file to the Space's queue."""
    local_path = pathlib.Path(input_file)
    json_payload = {
        "data": [
            {
                "path": server_path,
                "url": f"{_SEPARATOR_BASE_URL}/gradio_api/file=" + server_path,
                "orig_name": local_path.name,
                # stat() gives the byte size without reading the whole file.
                "size": local_path.stat().st_size,
                "mime_type": "audio/wav",
                "meta": {"_type": "gradio.FileData"},
            },
            # Positional inputs of the Space's separation function; their
            # meaning is defined by the Space, not by this file.
            "MelBand Roformer | Vocals by Kimberley Jensen",
            256,
            False,
            5,
            0,
            "/tmp/audio-separator-models/",
            "output",
            "wav",
            0.9,
            0,
            1,
            *(["NAME_(STEM)_MODEL"] * 7),
        ],
        "event_data": None,
        # NOTE(review): fn_index/trigger_id are tied to the Space's current
        # UI layout and presumably break if the Space changes.
        "fn_index": 5,
        "trigger_id": 28,
        "session_hash": session_hash,
    }
    response = requests.post(
        f"{_SEPARATOR_BASE_URL}/gradio_api/queue/join",
        params={"__theme": "system"},
        headers=_SEPARATOR_HEADERS,
        json=json_payload,
        timeout=_SEPARATOR_REQUEST_TIMEOUT,
    )
    response.raise_for_status()


def _parse_sse_line(raw_line):
    """Decode one event-stream line into a dict, or return None if it carries no JSON object."""
    text = raw_line.decode("utf-8", errors="replace")
    # Strip only the leading field name; replacing every "data: " would also
    # corrupt payload text that happens to contain it.
    if text.startswith("data:"):
        text = text[len("data:"):].lstrip()
    try:
        event = json.loads(text)
    except json.JSONDecodeError:
        return None
    return event if isinstance(event, dict) else None


def _wait_for_separator_result(session_hash, max_retries=5, retry_delay=5):
    """Follow the queue's event stream until the job completes.

    Returns the URL of the job's second output file, or None when the job
    reports failure or every connection attempt fails.
    """
    stream_url = (
        f"{_SEPARATOR_BASE_URL}/gradio_api/queue/data?session_hash={session_hash}"
    )
    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"Connecting to stream... Attempt {attempt}")
            with requests.get(
                stream_url, stream=True, timeout=_SEPARATOR_STREAM_TIMEOUT
            ) as r:
                if r.status_code != 200:
                    raise requests.HTTPError(f"Failed to connect: HTTP {r.status_code}")
                logger.info("Connected successfully.")
                for line in r.iter_lines():
                    if not line:
                        continue
                    event = _parse_sse_line(line)
                    if event is None:
                        continue
                    logger.info(event)
                    if event.get("msg") != "process_completed":
                        continue
                    logger.info("Process completed.")
                    if not event.get("success", True):
                        # Reconnecting cannot rescue a job that already failed.
                        logger.error(f"Vocal separation failed: {event.get('output')}")
                        return None
                    output_url = event["output"]["data"][1]["url"]
                    logger.info(f"Output URL: {output_url}")
                    return output_url
            logger.info("Stream ended prematurely. Reconnecting...")
        except requests.RequestException as e:
            logger.error(f"Error occurred: {e}. Retrying...")
        # Every attempt that did not return counts, including streams that
        # closed cleanly, so a stream that keeps ending early cannot loop forever.
        if attempt < max_retries:
            time.sleep(retry_delay)
    logger.error("Max retries reached. Exiting.")
    return None


def get_vocals(input_file):
    """Send *input_file* to the remote audio-separator Space and return the result URL.

    The file is uploaded, a separation job is queued with the
    "MelBand Roformer | Vocals by Kimberley Jensen" model, and the job's event
    stream is followed until it completes.

    Args:
        input_file: Path to the local audio file.

    Returns:
        URL of the job's second output file (presumably the vocals stem; confirm
        against the Space), or None on any failure. Errors are logged, never raised.
    """
    try:
        session_hash = _random_token()
        file_id = _random_token()
        server_path = _upload_to_separator(input_file, file_id)
        _join_separator_queue(server_path, input_file, session_hash)
        return _wait_for_separator_result(session_hash)
    except Exception as ex:
        logger.error(f"Unexpected error in get_vocals: {ex}")
        return None
def split_audio_by_pause(audio, sr, pause_threshold, top_db=30, energy_threshold=0.03):
    """Split *audio* into voiced regions separated by pauses of at least *pause_threshold* s.

    Non-silent intervals come from ``librosa.effects.split``. Neighbouring
    intervals whose gap is shorter than *pause_threshold* seconds are merged
    into one region. Regions whose mean RMS energy is below *energy_threshold*
    are then dropped.

    Args:
        audio: Waveform samples, as accepted by ``librosa.effects.split``.
        sr: Sample rate of *audio* in Hz (used to convert gaps to seconds).
        pause_threshold: Minimum gap, in seconds, that separates two regions.
        top_db: Silence threshold (dB below peak) passed to librosa.
        energy_threshold: Minimum mean RMS a region needs to be kept.

    Returns:
        List of ``(start, end)`` sample-index pairs. Empty when librosa reports
        no non-silent interval.
    """
    intervals = librosa.effects.split(audio, top_db=top_db)
    if len(intervals) == 0:
        # Nothing to merge; guards the intervals[0] access below.
        return []

    merged_intervals = []
    current_start, current_end = intervals[0]
    for start, end in intervals[1:]:
        gap_duration = (start - current_end) / sr
        if gap_duration < pause_threshold:
            # Pause too short: extend the current region over the gap.
            current_end = end
        else:
            merged_intervals.append((current_start, current_end))
            current_start, current_end = start, end
    merged_intervals.append((current_start, current_end))

    # Filter out segments with low average RMS energy.
    return [
        (start, end)
        for start, end in merged_intervals
        if np.mean(librosa.feature.rms(y=audio[start:end])) >= energy_threshold
    ]
| # ------------------------------- | |
| # Main Transcription Function | |
| # ------------------------------- | |
def _word_cues(aligned, offset=0.0):
    """Extract ``(start, end, text)`` cues from a whisperx alignment result.

    Word times are shifted by *offset* seconds. Words without usable
    timestamps are skipped.

    Returns:
        ``(cues, skipped)``: the list of cues and the number of skipped words.
    """
    cues = []
    skipped = 0
    for segment in aligned["segments"]:
        for word in segment.get("words", []):
            start, end = word.get("start"), word.get("end")
            # whisperx may leave "start"/"end" off words it could not align
            # (reportedly e.g. numerals); skip those rather than letting a
            # KeyError abort the whole transcription.
            if start is None or end is None:
                skipped += 1
                continue
            cues.append((start + offset, end + offset, word["word"]))
    return cues, skipped


def _format_srt(cues):
    """Render ``(start, end, text)`` cues as consecutive SRT entries numbered from 1."""
    return "".join(
        f"{index}\n{seconds_to_srt_time(start)} --> {seconds_to_srt_time(end)}\n{text}\n\n"
        for index, (start, end, text) in enumerate(cues, start=1)
    )


def _download_to_tempfile(url):
    """Download *url* into a new temp file and return its path.

    Returns None on a network error or a non-200 response. The caller owns,
    and must delete, the returned file.
    """
    from urllib.parse import urlparse

    try:
        response = requests.get(url, timeout=(10, 300))
    except requests.RequestException as err:
        logger.error(f"Failed to download extracted audio: {err}")
        return None
    if response.status_code != 200:
        return None
    # Keep the served file's extension when the URL has one.
    suffix = pathlib.PurePosixPath(urlparse(url).path).suffix or ".mp3"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(response.content)
    return tmp.name


def transcribe(audio_file, model_size="base", debug=False, pause_threshold=0.0, vocal_extraction=False, language="en"):
    """Transcribe an audio file into word-level SRT subtitles.

    Args:
        audio_file: Path to the input audio file.
        model_size: Key into the module-level ``models`` dict.
        debug: When true, also return a newline-joined processing log.
        pause_threshold: If > 0, split the audio at pauses at least this long
            (seconds) and transcribe each region separately; word times are
            shifted back onto the original timeline.
        vocal_extraction: If true, first try to isolate vocals with
            ``get_vocals``. On any failure the original audio is used.
        language: Language code to transcribe in; a falsy value auto-detects.

    Returns:
        The SRT text, or ``(srt_text, debug_log)`` when *debug* is true. If
        anything fails, the SRT text is "Error occurred during transcription";
        the error is logged, not raised.
    """
    start_time = time.time()
    srt_output = ""
    debug_log = []
    extracted_path = None  # temp file holding downloaded vocals; removed in finally
    try:
        # Optionally extract vocals first.
        if vocal_extraction:
            debug_log.append("Vocal extraction enabled; processing input file for vocals...")
            extracted_url = get_vocals(audio_file)
            if extracted_url is not None:
                debug_log.append("Vocal extraction succeeded; downloading extracted audio...")
                extracted_path = _download_to_tempfile(extracted_url)
                if extracted_path is not None:
                    audio_file = extracted_path
                    debug_log.append("Extracted audio downloaded and saved for transcription.")
                else:
                    debug_log.append("Failed to download extracted audio; proceeding with original file.")
            else:
                debug_log.append("Vocal extraction failed; proceeding with original audio.")

        # Load audio file (resampled to 16kHz).
        audio, sr = librosa.load(audio_file, sr=16000)
        debug_log.append(f"Audio loaded: {len(audio)/sr:.2f} seconds at {sr} Hz")

        # Select model and set batch size.
        model = models[model_size]
        batch_size = 8 if model_size == "tiny" else 4

        # A full-audio pass is needed to auto-detect the language and/or to
        # produce the unsplit transcript. Run it at most once, and skip it
        # entirely when splitting with a known language.
        full_transcript = None
        if not language or pause_threshold <= 0:
            if language:
                full_transcript = model.transcribe(audio, batch_size=batch_size, language=language)
            else:
                full_transcript = model.transcribe(audio, batch_size=batch_size)
                language = full_transcript.get("language", "unknown")

        # Load alignment model for the given language.
        model_a, metadata = whisperx.load_align_model(language_code=language, device=device)

        cues = []
        skipped = 0
        if pause_threshold > 0:
            segments = split_audio_by_pause(audio, sr, pause_threshold)
            debug_log.append(f"Audio split into {len(segments)} segment(s) using pause threshold of {pause_threshold}s")
            for seg_idx, (seg_start, seg_end) in enumerate(segments):
                audio_segment = audio[seg_start:seg_end]
                seg_duration = (seg_end - seg_start) / sr
                debug_log.append(f"Segment {seg_idx+1}: start={seg_start/sr:.2f}s, duration={seg_duration:.2f}s")
                seg_transcript = model.transcribe(audio_segment, batch_size=batch_size, language=language)
                seg_aligned = whisperx.align(
                    seg_transcript["segments"], model_a, metadata, audio_segment, device
                )
                seg_cues, seg_skipped = _word_cues(seg_aligned, offset=seg_start / sr)
                cues.extend(seg_cues)
                skipped += seg_skipped
        else:
            # Reuse the full-audio pass instead of transcribing a second time.
            aligned = whisperx.align(
                full_transcript["segments"], model_a, metadata, audio, device
            )
            cues, skipped = _word_cues(aligned)

        srt_output = _format_srt(cues)
        if skipped:
            debug_log.append(f"Skipped {skipped} word(s) without alignment timestamps")
        debug_log.append(f"Language used: {language}")
        debug_log.append(f"Batch size: {batch_size}")
        debug_log.append(f"Processed in {time.time()-start_time:.2f}s")
    except Exception as e:
        logger.error("Error during transcription:", exc_info=True)
        srt_output = "Error occurred during transcription"
        debug_log.append(f"ERROR: {str(e)}")
    finally:
        # The downloaded vocals were written with delete=False; clean them up.
        if extracted_path is not None:
            try:
                os.remove(extracted_path)
            except OSError as err:
                logger.warning(f"Could not remove temporary file {extracted_path}: {err}")
    if debug:
        return srt_output, "\n".join(debug_log)
    return srt_output
| # ------------------------------- | |
| # FastAPI Endpoints | |
| # ------------------------------- | |
# Transcription is CPU-bound and shares the module-level ``models`` (whose
# per-call state is not known to be thread-safe), so jobs run one at a time.
_transcribe_lock = threading.Lock()


def _transcribe_serialized(*args, **kwargs):
    """Run ``transcribe`` while holding ``_transcribe_lock``."""
    with _transcribe_lock:
        return transcribe(*args, **kwargs)


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler, so as shown it is never registered with ``app``; confirm it exists.
async def transcribe_endpoint(
    audio_file: UploadFile = File(...),
    model_size: str = Form("base"),
    debug: bool = Form(False),
    pause_threshold: float = Form(0.0),
    vocal_extraction: bool = Form(False),
    language: str = Form("en")
):
    """Transcribe an uploaded audio file and return word-level SRT as JSON.

    Form fields are passed straight through to ``transcribe``. The response
    body is ``{"srt": ...}``, plus ``"debug"`` when *debug* is true.

    Raises:
        HTTPException: 500 if saving the upload or transcribing raises.
    """
    tmp_path = None
    try:
        # Save the upload to a temp file, keeping its extension. filename can
        # be None, which pathlib.Path() would reject.
        suffix = pathlib.Path(audio_file.filename or "").suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name  # record first so a failed write is still cleaned up
            tmp.write(await audio_file.read())

        # Run the blocking transcription in a worker thread so the event loop
        # keeps serving other requests while it runs.
        result = await run_in_threadpool(
            _transcribe_serialized,
            tmp_path,
            model_size=model_size,
            debug=debug,
            pause_threshold=pause_threshold,
            vocal_extraction=vocal_extraction,
            language=language,
        )
        if debug:
            srt_text, debug_info = result
            return JSONResponse(content={"srt": srt_text, "debug": debug_info})
        return JSONResponse(content={"srt": result})
    except Exception as e:
        logger.error(f"Error in transcribe_endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error")
    finally:
        # delete=False means the temp file persists; remove it on every path.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError as err:
                logger.warning(f"Could not remove temporary file {tmp_path}: {err}")
async def root():
    """Health-check handler: report that the WhisperX API is running."""
    # NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible on
    # this handler, so as shown it is never registered with ``app``; confirm.
    status = "WhisperX API is running."
    return dict(message=status)