| | import os |
| | import torch |
| | import gradio as gr |
| | from fastapi import FastAPI, UploadFile, File |
| | from fastapi.responses import JSONResponse |
| | import uvicorn |
| | from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline |
| | import soundfile as sf |
| | import numpy as np |
| | import tempfile |
| |
|
| | |
# FastAPI application that serves the REST endpoint defined in this file.
app = FastAPI()

# Use the first CUDA GPU in half precision when one is available;
# otherwise run on the CPU in full precision.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

# Hugging Face Hub identifier of the checkpoint to load.
model_id = "nyrahealth/CrisperWhisper"

# Load the speech-to-text model weights and move them to the chosen device.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)

# The processor bundles the tokenizer and the feature extractor for the model.
processor = AutoProcessor.from_pretrained(model_id)

# Speech-recognition pipeline: audio is handled in 30-second chunks,
# 16 chunks per batch, and word-level timestamps are requested.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=16,
    return_timestamps="word",
    torch_dtype=torch_dtype,
    device=device,
)
| |
|
def adjust_pauses_for_hf_pipeline_output(pipeline_output, split_threshold=0.12):
    """
    Adjust pause timings by distributing pauses up to the threshold evenly between adjacent words.

    For each pair of consecutive chunks, the gap between the end of the first
    chunk and the start of the second is measured. When the gap is positive,
    half of it (capped at ``split_threshold / 2``) is added to the end of the
    first chunk and subtracted from the start of the second. Overlapping or
    touching chunks are left unchanged.

    Args:
        pipeline_output: Mapping with a ``"chunks"`` list; every chunk is a
            mapping whose ``"timestamp"`` entry is a ``(start, end)`` pair.
        split_threshold: Largest pause (same unit as the timestamps) that is
            fully absorbed; longer pauses only give up this much in total.

    Returns:
        The same ``pipeline_output`` object. Its ``"chunks"`` entry is rebound
        to a new list, but the chunk mappings themselves are updated in place.

    Raises:
        KeyError: If ``"chunks"`` or a chunk's ``"timestamp"`` is missing.
    """
    # Shallow copy: a new list object, but the chunk dicts are shared.
    adjusted_chunks = list(pipeline_output["chunks"])

    for current_chunk, next_chunk in zip(adjusted_chunks, adjusted_chunks[1:]):
        # Read inside the loop: the start of current_chunk may have been moved
        # while it was the "next" chunk of the previous pair.
        current_start, current_end = current_chunk["timestamp"]
        next_start, next_end = next_chunk["timestamp"]

        # A missing boundary (None) makes the pause undefined; skip the pair
        # instead of failing on the subtraction below.
        if current_end is None or next_start is None:
            continue

        pause_duration = next_start - current_end
        if pause_duration <= 0:
            continue

        distribute = min(pause_duration, split_threshold) / 2
        current_chunk["timestamp"] = (current_start, current_end + distribute)
        next_chunk["timestamp"] = (next_start - distribute, next_end)

    pipeline_output["chunks"] = adjusted_chunks
    return pipeline_output
| |
|
def process_audio(audio_path):
    """Transcribe the audio file at ``audio_path`` and return the result with timestamps.

    The file is read with ``soundfile``, multi-channel audio is averaged down
    to a single channel, the samples are run through the module-level ASR
    pipeline, and the pauses between the returned chunks are adjusted.

    Returns:
        The adjusted pipeline output on success, or ``{"error": <message>}``
        when any step raises.
    """
    try:
        waveform, sr = sf.read(audio_path)

        # More than one dimension means several channels: collapse to mono.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)

        transcription = pipe({"array": waveform, "sampling_rate": sr})
        return adjust_pauses_for_hf_pipeline_output(transcription)
    except Exception as exc:
        # Failures are reported to the caller as data rather than raised.
        return {"error": str(exc)}
| |
|
| | |
@app.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return the result as JSON.

    The upload is written to a temporary file (keeping the original file
    extension, if any), passed to ``process_audio``, and the temporary file is
    removed afterwards whether or not processing succeeded.

    Args:
        file: The uploaded audio file from the multipart request.

    Returns:
        A ``JSONResponse`` carrying the transcription, or a status-500
        ``JSONResponse`` of the form ``{"error": <message>}`` when the upload
        cannot be handled or ``process_audio`` reports a failure.
    """
    temp_file_path = None
    try:
        # The client may omit the filename; fall back to no suffix.
        suffix = os.path.splitext(file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Record the path first so cleanup still happens if the write fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        result = process_audio(temp_file_path)

        # process_audio reports failures as {"error": ...}; surface them with
        # the same status code as the error path of this handler.
        failed = isinstance(result, dict) and "error" in result
        return JSONResponse(status_code=500 if failed else 200, content=result)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
    finally:
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not replace
                # the response that was already produced.
                pass
| |
|
| | |
def gradio_transcribe(audio):
    """Gradio callback: transcribe the uploaded file or report that none was given.

    Args:
        audio: Path of the uploaded audio file, or ``None`` when nothing was
            uploaded.

    Returns:
        The result of ``process_audio(audio)``, or ``{"error": <message>}``
        when ``audio`` is ``None``.
    """
    if audio is None:
        # The output component is gr.JSON, which expects a dict/list or a
        # string containing valid JSON; a plain sentence is neither. Use the
        # same error shape that process_audio returns.
        return {"error": "Please upload an audio file"}

    return process_audio(audio)
| |
|
| | |
# Input and output widgets of the web UI.
_audio_input = gr.Audio(type="filepath", label="Upload Audio (MP3 or WAV)")
_json_output = gr.JSON(label="Transcription with Timestamps")

# Gradio front end: one audio upload in, the transcription JSON out.
demo = gr.Interface(
    fn=gradio_transcribe,
    inputs=_audio_input,
    outputs=_json_output,
    title="CrisperWhisper Audio Transcription",
    description="Upload an audio file to get transcription with word-level timestamps",
    examples=[],
    allow_flagging="never",
)

# Serve the Gradio UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    # Listen on all interfaces, port 7860, when run as a script.
    uvicorn.run(app, host="0.0.0.0", port=7860)