# Source: Hugging Face upload (commit 536e766, "Upload 3 files" by cheesecz).
import os

# Configure the Hugging Face cache location BEFORE importing transformers:
# the library resolves its default cache directory at import time, so
# setting TRANSFORMERS_CACHE after the import has no effect.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.makedirs("/tmp/hf_cache", exist_ok=True)

import torch
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers.pipelines.audio_utils import ffmpeg_read

MODEL_NAME = "openai/whisper-small"

# Run on GPU when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# .to() returns the module itself, so `model` is already on `device`.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
def adjust_pauses_for_hf_pipeline_output(segments, split_threshold=0.12):
    """Shrink the silent gap between consecutive segments in place.

    For each positive pause between segment i and i+1, at most
    ``split_threshold`` seconds of the pause are redistributed: half is
    appended to segment i's end and half is subtracted from segment i+1's
    start. Overlapping/contiguous segments (pause <= 0) are left untouched.

    Accepts both segment shapes used in this file:
      * ``{"start": float, "end": float, ...}``
      * ``{"timestamp": [start, end], ...}`` (as built by /speech2text)

    Args:
        segments: list of segment dicts, mutated in place.
        split_threshold: maximum total seconds of pause to redistribute.

    Returns:
        The same ``segments`` list, for convenience.
    """
    for i in range(len(segments) - 1):
        cur, nxt = segments[i], segments[i + 1]
        ts_shape = "timestamp" in cur  # which key style this pair uses
        if ts_shape:
            cur_end, nxt_start = cur["timestamp"][1], nxt["timestamp"][0]
        else:
            cur_end, nxt_start = cur["end"], nxt["start"]
        # Whisper can emit None for a truncated final timestamp — skip those.
        if cur_end is None or nxt_start is None:
            continue
        pause = nxt_start - cur_end
        # Skip overlapping or contiguous segments; the original formula would
        # have shifted them in the wrong direction for negative pauses.
        if pause <= 0:
            continue
        # min(pause, threshold) / 2 == min(pause/2, threshold/2)
        distribute = min(pause, split_threshold) / 2
        if ts_shape:
            cur["timestamp"][1] = cur_end + distribute
            nxt["timestamp"][0] = nxt_start - distribute
        else:
            cur["end"] = cur_end + distribute
            nxt["start"] = nxt_start - distribute
    return segments
app = FastAPI()

# MIME types accepted by the transcription endpoint.
_ACCEPTED_TYPES = {"audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3"}

@app.post("/speech2text")
async def speech2text(request: Request):
    """Transcribe a WAV/MP3 request body with Whisper.

    Returns JSON of the form
    ``{"text": <full transcript>, "chunks": [{"text", "timestamp": [start, end]}, ...]}``.

    Raises:
        HTTPException 400: unsupported Content-Type.
        HTTPException 500: decoding or transcription failure.
    """
    # Strip media-type parameters such as "; charset=..." before comparing.
    content_type = request.headers.get("Content-Type", "").split(";")[0].strip()
    if content_type not in _ACCEPTED_TYPES:
        raise HTTPException(status_code=400, detail="Only WAV or MP3 accepted.")
    try:
        body = await request.body()
        # ffmpeg decodes/resamples the raw bytes to 16 kHz mono float32.
        array = ffmpeg_read(body, sampling_rate=16000)
        inputs = processor(array, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to(device)
        # return_timestamps belongs on generate() (it makes the model emit
        # timestamp tokens); the original passed it to the processor and
        # fed the features under the wrong `input_ids` kwarg.
        with torch.no_grad():
            generated_ids = model.generate(input_features, return_timestamps=True)
        # For Whisper, output_offsets=True makes batch_decode return
        # {"text": ..., "offsets": [{"text", "timestamp": (start, end)}, ...]}
        # per sequence (output_word_offsets is a Wav2Vec2-only argument).
        decoded = processor.batch_decode(
            generated_ids, skip_special_tokens=True, output_offsets=True
        )[0]
        # Use start/end keys while adjusting pauses, then convert back to the
        # [start, end] "timestamp" shape for the response payload.
        segments = [
            {"text": seg["text"], "start": seg["timestamp"][0], "end": seg["timestamp"][1]}
            for seg in decoded.get("offsets", [])
        ]
        adjusted = adjust_pauses_for_hf_pipeline_output(segments)
        chunks = [
            {"text": seg["text"], "timestamp": [seg["start"], seg["end"]]}
            for seg in adjusted
        ]
        return JSONResponse(content={"text": decoded["text"], "chunks": chunks})
    except HTTPException:
        # Never re-wrap deliberate HTTP errors as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Transcription error: {str(e)}")