# Source: video_editing_mcp_server / deploy_whisper_on_modal.py
# Commit 1b544c3 ("Initial"), committed by MalikIbrar
import modal
from fastapi import Form, HTTPException
# Coordinates of the NVIDIA CUDA base image; joined below into the registry
# tag "<cuda_version>-<flavor>-<operating_sys>".
cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = "-".join((cuda_version, flavor, operating_sys))
# Container image, built one layer at a time on top of the NVIDIA CUDA base
# image selected by `tag`, with Python 3.11 added.
image = modal.Image.from_registry("nvidia/cuda:" + tag, add_python="3.11")

# System packages: git (pip needs it for the git+https install below),
# ffmpeg, and the cuDNN 8 runtime and development packages.
image = image.apt_install("git", "ffmpeg", "libcudnn8", "libcudnn8-dev")

# First Python layer: web stack plus pinned torch/torchaudio and numpy < 2.0,
# with the PyTorch cu118 wheel index added as an extra package index.
image = image.pip_install(
    "fastapi[standard]",
    "httpx",
    "torch==2.0.0",
    "torchaudio==2.0.0",
    "numpy<2.0",
    extra_index_url="https://download.pytorch.org/whl/cu118",
)

# Second Python layer: WhisperX pinned to the v3.2.0 tag on GitHub, the
# ffmpeg Python bindings, and a pinned ctranslate2.
image = image.pip_install(
    "git+https://github.com/m-bain/whisperx.git@v3.2.0",
    "ffmpeg-python",
    "ctranslate2==4.4.0",
)
# GPU type requested for the transcription class.
GPU_CONFIG = "L4"

# Mount point of the persistent volume inside the container; the Whisper
# model is downloaded under this directory.
CACHE_DIR = "/cache"

# Named Modal volume backing CACHE_DIR; created on first use if it does not
# exist yet.
cache_vol = modal.Volume.from_name("whisper-cache", create_if_missing=True)

# The Modal application; its functions and classes run in `image` by default.
app = modal.App("whisperx-api", image=image)
@app.cls(
    gpu=GPU_CONFIG,
    volumes={CACHE_DIR: cache_vol},
    scaledown_window=60 * 10,
    timeout=60 * 60,
)
@modal.concurrent(max_inputs=15)
class Model:
    """WhisperX transcription service running on a Modal GPU container.

    The Whisper model is loaded once in `setup` and reused by every
    `transcribe` call. Up to 15 inputs may be processed concurrently by a
    single container (see the `modal.concurrent` decorator), so `transcribe`
    must not rely on any shared on-disk state.
    """

    @modal.enter()
    def setup(self):
        """Load the Whisper "large-v2" model onto the GPU.

        Registered as a Modal `enter` hook. Model weights are stored under
        CACHE_DIR, which is backed by the `cache_vol` volume.
        """
        import whisperx

        device = "cuda"
        compute_type = (
            "float16"  # change to "int8" if low on GPU mem (may reduce accuracy)
        )
        # Batched Whisper model used for the transcription pass.
        self.model = whisperx.load_model(
            "large-v2", device, compute_type=compute_type, download_root=CACHE_DIR
        )

    @modal.method()
    def transcribe(self, audio_url: str):
        """Download audio from `audio_url` and return word-level timestamps.

        Args:
            audio_url: HTTP(S) URL of the audio file to transcribe.

        Returns:
            A dict with keys:
              - "language": language code reported by the transcription pass.
              - "language_probability": value reported by the transcription
                pass, or 1.0 when it does not provide one.
              - "words": list of {"start", "end", "word"} dicts, one per
                aligned word. "start"/"end" are None for words the alignment
                model could not time.

        Raises:
            requests.HTTPError: if the download returns an error status.
            requests.RequestException: on connection problems or timeouts.
        """
        import os
        import tempfile

        import requests
        import whisperx

        batch_size = 16
        device = "cuda"
        # (connect, read) timeouts in seconds. The read timeout bounds the gap
        # between received chunks, not the total download time.
        download_timeout = (10, 60)

        # Each call downloads into its own scratch directory: several inputs
        # can run in the same container at once, and a shared fixed path
        # would let them overwrite each other's audio. The directory (and the
        # file in it) is removed as soon as the audio has been decoded.
        with tempfile.TemporaryDirectory() as scratch_dir:
            audio_path = os.path.join(scratch_dir, "downloaded_audio.wav")
            with requests.get(
                audio_url, stream=True, timeout=download_timeout
            ) as response:
                # Fail early on 4xx/5xx instead of feeding an error page to
                # the audio decoder.
                response.raise_for_status()
                with open(audio_path, "wb") as audio_file:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        audio_file.write(chunk)
            audio = whisperx.load_audio(audio_path)

        # Transcription pass (batched).
        result = self.model.transcribe(audio, batch_size=batch_size)

        # Alignment pass for word-level timestamps. The already-decoded audio
        # array is passed so the file does not have to be decoded again.
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device=device
        )
        aligned_result = whisperx.align(
            result["segments"], model_a, metadata, audio, device=device
        )

        # Words that cannot be aligned carry no "start"/"end" keys, so read
        # them with .get() rather than failing the whole request.
        words = [
            {
                "start": word.get("start"),
                "end": word.get("end"),
                "word": word["word"],
            }
            for word in aligned_result["word_segments"]
        ]

        return {
            "language": result["language"],
            "language_probability": result.get("language_probability", 1.0),
            "words": words,
        }
@app.function()
@modal.fastapi_endpoint(docs=True, method="POST")
async def transcribe_endpoint(url: str = Form(...)):
    """HTTP POST endpoint that transcribes the audio file at `url`.

    Args:
        url: Form field holding the http:// or https:// URL of the audio.

    Returns:
        The result of `Model.transcribe` for that URL.

    Raises:
        HTTPException: status 400 when `url` does not start with "http://"
            or "https://".
    """
    if not url.startswith(("http://", "https://")):
        raise HTTPException(
            status_code=400,
            detail="URL must start with http:// or https://",
        )
    # NOTE(review): `url` is caller-supplied and is fetched server-side by
    # Model.transcribe; consider restricting allowed hosts if this endpoint
    # is exposed publicly.
    #
    # Use the async variant of the remote call: a plain `.remote()` would
    # block the event loop of this async handler until transcription ends.
    return await Model().transcribe.remote.aio(audio_url=url)
# ## Run the model
@app.local_entrypoint()
def main():
    """Transcribe a hosted sample recording remotely and print the result."""
    sample_audio_url = (
        "https://pub-ebe9e51393584bf5b5bea84a67b343c2.r2.dev/examples_english_english.wav"
    )
    transcript = Model().transcribe.remote(sample_audio_url)
    print(transcript)