# Hugging Face Space — runs on ZeroGPU ("Running on Zero") hardware.
| """ | |
| Parakeet TDT ZeroGPU Space for Cadayn/EagleEye | |
| Ultra-fast ASR using NVIDIA Parakeet TDT 0.6B v3. | |
| 3000x faster than Whisper - transcribes 1 hour in ~1 second on GPU. | |
| Speaker diarization runs in a separate ZeroGPU Space. | |
| """ | |
| import base64 | |
| import gc | |
| import os | |
| import tempfile | |
| import traceback | |
| import gradio as gr | |
| import requests | |
| import spaces | |
| import torch | |
| from nemo.collections.asr.models import ASRModel | |
| from pydub import AudioSegment | |
# Hugging Face model id for NVIDIA's Parakeet TDT 0.6B v3 ASR model.
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v3"

# NeMo Conformer attention needs ~1.9 GiB for 60s chunks. ZeroGPU only
# leaves ~1.5 GiB free after loading the model (~13 GiB on 14.5 GiB GPU).
# 30s chunks need ~0.5 GiB headroom which fits comfortably.
CHUNK_DURATION_S = 30
# Seconds of audio shared between consecutive chunks so words are not cut
# at a chunk boundary.
OVERLAP_S = 5

# Load the model on CPU at import time; it is moved to the GPU only for the
# duration of each request (ZeroGPU allocates GPUs per call, not per process).
print(f"Loading {MODEL_NAME} on CPU...")
model = ASRModel.from_pretrained(model_name=MODEL_NAME)
model.eval()
print("Parakeet loaded on CPU (GPU allocated per-request via ZeroGPU)")
| def _extract_words(result) -> dict: | |
| """Extract word timestamps from a NeMo transcription result.""" | |
| words = [] | |
| if hasattr(result, "timestamp") and result.timestamp: | |
| word_stamps = result.timestamp.get("word", []) | |
| for stamp in word_stamps: | |
| words.append({ | |
| "start": float(stamp["start"]), | |
| "end": float(stamp["end"]), | |
| "text": stamp["word"], | |
| }) | |
| elif hasattr(result, "words") and result.words: | |
| for word_info in result.words: | |
| words.append({ | |
| "start": float(word_info.start), | |
| "end": float(word_info.end), | |
| "text": word_info.word, | |
| }) | |
| else: | |
| text = result.text if hasattr(result, "text") else str(result) | |
| words.append({"start": 0.0, "end": 0.0, "text": text}) | |
| text = result.text if hasattr(result, "text") else " ".join(w["text"] for w in words) | |
| return {"text": text, "words": words} | |
| def _disable_cuda_graphs(m) -> None: | |
| """Disable CUDA Graphs on decoding computer to avoid NeMo TDT CUDA failure 35.""" | |
| try: | |
| m.decoding.decoding.decoding_computer.disable_cuda_graphs() | |
| except AttributeError: | |
| pass | |
| try: | |
| m.decoding.decoding.disable_cuda_graphs() | |
| except AttributeError: | |
| pass | |
@spaces.GPU
def _transcribe_on_gpu(audio_path: str) -> dict:
    """Transcribe one audio file on a ZeroGPU-allocated GPU.

    The ``@spaces.GPU`` decorator is required on ZeroGPU: it is what
    attaches a CUDA device to the process for the duration of this call.
    Without it, ``model.to("cuda")`` fails because no GPU is allocated.

    Args:
        audio_path: Path to an audio file (callers pass mono 16 kHz WAV).

    Returns:
        ``{"text": ..., "words": [...]}`` on success, or ``{"error": ...}``
        containing the traceback on failure (callers check the "error" key).
    """
    try:
        model.to("cuda")
        # Ensure float32 weights on GPU (NOTE(review): presumably to avoid
        # reduced-precision decoding issues — confirm against NeMo guidance).
        model.to(torch.float32)
        _disable_cuda_graphs(model)
        output = model.transcribe([audio_path], timestamps=True)
        return _extract_words(output[0])
    except Exception as e:
        # Surface the failure as data instead of raising across the ZeroGPU boundary.
        return {"error": f"{type(e).__name__}: {e}\n{traceback.format_exc()}"}
    finally:
        # Always release the GPU: move weights back to CPU and free cached blocks.
        model.cpu()
        gc.collect()
        torch.cuda.empty_cache()
@spaces.GPU
def _transcribe_chunks_on_gpu(chunk_paths: list[str]) -> list[dict]:
    """Transcribe several chunk files sequentially within one ZeroGPU allocation.

    The ``@spaces.GPU`` decorator is required on ZeroGPU to attach a CUDA
    device for this call; without it ``model.to("cuda")`` fails. Doing all
    chunks in one decorated call also avoids paying the GPU-allocation
    latency once per chunk.

    Args:
        chunk_paths: Paths to mono 16 kHz WAV chunk files, in order.

    Returns:
        One ``_extract_words``-shaped dict per chunk, in the same order.
        Unlike ``_transcribe_on_gpu`` this propagates exceptions to the caller.
    """
    try:
        model.to("cuda")
        model.to(torch.float32)
        _disable_cuda_graphs(model)
        results = []
        for path in chunk_paths:
            output = model.transcribe([path], timestamps=True)
            results.append(_extract_words(output[0]))
            # Free attention buffers between chunks so each 30s chunk fits
            # in the small post-model-load headroom (see CHUNK_DURATION_S note).
            torch.cuda.empty_cache()
        return results
    finally:
        model.cpu()
        gc.collect()
        torch.cuda.empty_cache()
def _transcribe_single(audio_path: str) -> dict:
    """Transcribe one short audio file in a single GPU call.

    Raises:
        RuntimeError: if the GPU-side transcription reported an error.
    """
    file_size = os.path.getsize(audio_path)
    print(f"[_transcribe_single] path={audio_path}, size={file_size}")
    outcome = _transcribe_on_gpu(audio_path)
    if outcome.get("error") is not None:
        raise RuntimeError(outcome["error"])
    return outcome
def _transcribe_chunked(audio: AudioSegment, audio_path: str) -> dict:
    """Transcribe long audio by splitting it into overlapping chunks.

    Chunks overlap by OVERLAP_S seconds so no word is clipped at a chunk
    boundary. Words inside an overlap are transcribed twice, so when
    merging we keep, from each chunk, only the words on its side of the
    overlap midpoint — otherwise every overlap word (and its text) would
    appear twice in the output. The final text is rebuilt from the
    de-duplicated word list so it stays consistent with the word timings.

    Args:
        audio: Decoded audio (mono 16 kHz, prepared by _ensure_mono_wav).
        audio_path: Source file path (unused; kept for interface compatibility).

    Returns:
        ``{"text": ..., "words": [...]}`` with word times in whole-file seconds.
    """
    chunk_ms = CHUNK_DURATION_S * 1000
    overlap_ms = OVERLAP_S * 1000
    chunk_paths: list[str] = []
    chunk_offsets: list[float] = []
    chunk_ends: list[float] = []
    try:
        for start_ms in range(0, len(audio), chunk_ms - overlap_ms):
            end_ms = min(start_ms + chunk_ms, len(audio))
            chunk = audio[start_ms:end_ms]
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                chunk_path = f.name
            chunk.export(chunk_path, format="wav")
            chunk_paths.append(chunk_path)
            chunk_offsets.append(start_ms / 1000)
            chunk_ends.append(end_ms / 1000)

        gpu_results = _transcribe_chunks_on_gpu(chunk_paths)

        half_overlap = OVERLAP_S / 2
        last = len(gpu_results) - 1
        all_words: list[dict] = []
        for i, chunk_result in enumerate(gpu_results):
            offset_s = chunk_offsets[i]
            chunk_words = chunk_result["words"]
            # A chunk with no usable timestamps yields a single 0.0-0.0
            # fallback word; midpoint filtering would wrongly drop it.
            has_timing = not (
                len(chunk_words) == 1
                and chunk_words[0]["start"] == chunk_words[0]["end"] == 0.0
            )
            # Overlap-midpoint cut: keep words starting in [lo, hi); the
            # neighbouring chunk covers the other half of each overlap.
            lo = offset_s + half_overlap if (i > 0 and has_timing) else float("-inf")
            hi = chunk_ends[i] - half_overlap if (i < last and has_timing) else float("inf")
            for word in chunk_words:
                global_start = word["start"] + offset_s
                if lo <= global_start < hi:
                    word["start"] = global_start
                    word["end"] += offset_s
                    all_words.append(word)

        return {"text": " ".join(w["text"] for w in all_words), "words": all_words}
    finally:
        # Remove every temp chunk file even if transcription failed.
        for path in chunk_paths:
            if os.path.exists(path):
                os.unlink(path)
def _ensure_mono_wav(audio_path: str) -> tuple[str, float]:
    """Re-encode any input as mono 16 kHz WAV, the format NeMo expects.

    The converted file is written next to the input as ``<input>.mono.wav``;
    the caller is responsible for deleting it.

    Returns:
        Tuple of (path to the converted WAV, duration in seconds).
    """
    converted = (
        AudioSegment.from_file(audio_path)
        .set_channels(1)
        .set_frame_rate(16000)
    )
    out_path = audio_path + ".mono.wav"
    converted.export(out_path, format="wav")
    return out_path, len(converted) / 1000
def transcribe_audio(audio_file) -> str:
    """Gradio UI handler: transcribe an uploaded file to timestamped lines.

    Args:
        audio_file: Either a filesystem path (str) or a Gradio file object
            exposing a ``.name`` attribute.

    Returns:
        One "[start - end] word" line per word, or an error message string
        (the UI textbox displays whatever is returned).
    """
    mono_path = None
    try:
        if audio_file is None:
            return "Please upload an audio file."
        if isinstance(audio_file, str):
            audio_path = audio_file
        elif hasattr(audio_file, "name"):
            audio_path = audio_file.name
        else:
            return f"Error: Unexpected file type: {type(audio_file)}"
        file_size = os.path.getsize(audio_path)
        print(f"[transcribe_audio] input type={type(audio_file)}, path={audio_path}, size={file_size}")
        mono_path, duration_s = _ensure_mono_wav(audio_path)
        print(f"[transcribe_audio] duration_s={duration_s:.1f}, CHUNK_DURATION_S={CHUNK_DURATION_S}, will_chunk={duration_s > CHUNK_DURATION_S}")
        if duration_s <= CHUNK_DURATION_S:
            result = _transcribe_single(mono_path)
        else:
            # Decode the WAV only when chunking actually needs to slice it;
            # the original decoded it unconditionally, doubling work for
            # short clips that never chunk.
            audio = AudioSegment.from_file(mono_path)
            result = _transcribe_chunked(audio, mono_path)
        output_lines = [
            f"[{word['start']:.2f}s - {word['end']:.2f}s] {word['text']}"
            for word in result["words"]
        ]
        return "\n".join(output_lines) if output_lines else result["text"]
    except Exception as e:
        return f"Error: {str(e)}\n{traceback.format_exc()}"
    finally:
        # Always clean up the temporary mono re-encode.
        if mono_path and os.path.exists(mono_path):
            os.unlink(mono_path)
def api_transcribe(audio_url: str, audio_base64: str) -> dict:
    """API endpoint: transcribe audio fetched from a URL or sent as base64.

    Exactly one of ``audio_url`` / ``audio_base64`` should be non-empty;
    the URL takes precedence when both are given.

    Returns:
        On success: ``{"success": True, "text", "words", "speakers",
        "segments", "duration_s", "model"}``. On failure: ``{"success":
        False, "error", "traceback"}``. Speaker fields are always empty —
        diarization runs in a separate Space (see module docstring).
    """
    temp_file = None
    mono_path = None
    try:
        audio_path = None
        if audio_url and audio_url.strip():
            response = requests.get(audio_url, timeout=120, stream=True)
            response.raise_for_status()
            # Pick a container suffix so pydub/ffmpeg can sniff the format.
            suffix = ".mp3"
            content_type = response.headers.get("content-type", "")
            if "video" in content_type:
                suffix = ".mp4"
            elif "wav" in audio_url or "wav" in content_type:
                suffix = ".wav"
            temp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
            for chunk in response.iter_content(chunk_size=8192):
                temp_file.write(chunk)
            temp_file.close()
            audio_path = temp_file.name
        elif audio_base64 and audio_base64.strip():
            audio_bytes = base64.b64decode(audio_base64)
            temp_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
            temp_file.write(audio_bytes)
            temp_file.close()
            audio_path = temp_file.name
        else:
            return {"error": "No audio provided", "success": False}
        mono_path, duration_s = _ensure_mono_wav(audio_path)
        if duration_s <= CHUNK_DURATION_S:
            result = _transcribe_single(mono_path)
        else:
            # Decode only when the audio must be sliced into chunks.
            audio = AudioSegment.from_file(mono_path)
            result = _transcribe_chunked(audio, mono_path)
        api_words = [
            {
                "start_s": w["start"],
                "end_s": w["end"],
                "text": w["text"],
                "speaker": None,
            }
            for w in result["words"]
        ]
        return {
            "success": True,
            "text": result["text"],
            "words": api_words,
            "speakers": [],
            "segments": [],
            "duration_s": duration_s,
            "model": MODEL_NAME,
        }
    except Exception as e:
        return {"error": str(e), "success": False, "traceback": traceback.format_exc()}
    finally:
        if temp_file is not None:
            # Close before unlinking: a failure mid-download/mid-write would
            # otherwise leave the handle open (close() is a no-op if already closed).
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
        if mono_path and os.path.exists(mono_path):
            os.unlink(mono_path)
def api_transcribe_segment(audio_base64: str, start_s: float = 0.0, end_s: float | None = None) -> dict:
    """API endpoint: transcribe a short base64-encoded audio segment.

    Args:
        audio_base64: Base64-encoded audio bytes (written with a .wav suffix
            so pydub/ffmpeg can sniff the container).
        start_s: Offset added to every word timestamp so results line up
            with the segment's position in the full recording.
        end_s: Accepted for interface compatibility but currently unused —
            the entire decoded segment is always transcribed.

    Returns:
        Same shape as ``api_transcribe`` plus ``offset_applied``; on failure
        ``{"success": False, "error", "traceback"}``.
    """
    temp_file = None
    mono_path = None
    try:
        audio_bytes = base64.b64decode(audio_base64)
        temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        temp_file.write(audio_bytes)
        temp_file.close()
        mono_path, _ = _ensure_mono_wav(temp_file.name)
        result = _transcribe_single(mono_path)
        segments = [
            {
                "start_s": word["start"] + start_s,
                "end_s": word["end"] + start_s,
                "text": word["text"],
                "speaker": None,
            }
            for word in result["words"]
        ]
        return {
            "success": True,
            "text": result["text"],
            "words": segments,
            "speakers": [],
            "segments": [],
            "offset_applied": start_s,
            "model": MODEL_NAME,
        }
    except Exception as e:
        return {"error": str(e), "success": False, "traceback": traceback.format_exc()}
    finally:
        if temp_file is not None:
            # Close before unlinking: a failure mid-write would otherwise
            # leave the handle open (close() is a no-op if already closed).
            temp_file.close()
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
        if mono_path and os.path.exists(mono_path):
            os.unlink(mono_path)
def health_check() -> dict:
    """Report service liveness and configuration; never touches the GPU."""
    return dict(
        status="ok",
        model_name=MODEL_NAME,
        model_loaded=model is not None,
        diarization_available=False,
        torch_version=torch.__version__,
    )
# Main UI
demo_ui = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.File(
        label="Upload Audio/Video",
        file_types=[".mp3", ".wav", ".m4a", ".mp4", ".webm", ".ogg", ".flac"],
    ),
    outputs=gr.Textbox(label="Transcription with Word Timestamps", lines=20),
    title="Parakeet TDT ASR",
    description="NVIDIA Parakeet TDT 0.6B — ultra-fast speech-to-text with word timestamps.",
)
# API interfaces
# Each gr.Interface below exists primarily so remote clients (e.g. EagleEye)
# can invoke the function via the Gradio API at its api_name; the rendered
# tab in the UI is secondary.
api_transcribe_interface = gr.Interface(
    fn=api_transcribe,
    inputs=[
        gr.Textbox(label="Audio URL"),
        gr.Textbox(label="Audio Base64"),
    ],
    outputs=gr.JSON(label="Response"),
    api_name="api_transcribe",
    title="API: Transcribe",
)
api_segment_interface = gr.Interface(
    fn=api_transcribe_segment,
    inputs=[
        gr.Textbox(label="Audio Base64"),
        gr.Number(label="Start Offset (s)", value=0.0),
        gr.Number(label="End Time (s)"),
    ],
    outputs=gr.JSON(label="Response"),
    api_name="api_transcribe_segment",
    title="API: Segment",
)
health_interface = gr.Interface(
    fn=health_check,
    inputs=[],
    outputs=gr.JSON(label="Health"),
    api_name="health",
    title="Health",
)
# Combine the UI and API endpoints into one tabbed app.
demo = gr.TabbedInterface(
    [demo_ui, api_transcribe_interface, api_segment_interface, health_interface],
    ["Transcribe", "API: Full", "API: Segment", "Health"],
    title="Parakeet ASR for Cadayn",
)
if __name__ == "__main__":
    demo.launch()