import os
import shutil
import tempfile

import numpy as np
import soundfile as sf
import torch
import uvicorn
import whisperx
from fastapi import FastAPI, UploadFile, File, Form
from speechbrain.pretrained import EncoderClassifier


# ASGI application object; routes are registered on it via decorators.
app = FastAPI()

# Both models below are placed on this device.
device = "cpu"

# whisperx.load_model defaults to compute_type="float16", which the
# CTranslate2 backend refuses to run on a CPU device. "int8" is the setting
# the WhisperX README recommends for CPU-only inference.
asr_model = whisperx.load_model("small", device, compute_type="int8")

# ECAPA-TDNN speaker-embedding model, fetched from the given source the first
# time it is needed and then loaded onto `device`.
speaker_model = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    run_opts={"device": device},
)


# Sample rate (Hz) the upload is decoded to. The same mono waveform is fed to
# both the ASR model and the speaker encoder, so segment timestamps (seconds)
# convert to sample indices with this single rate.
_SAMPLE_RATE = 16000

# Segments shorter than this many seconds are skipped: they are too short to
# embed, and (as before) they are left out of the response entirely.
_MIN_SEGMENT_SECONDS = 0.5

# Upper bound on the number of distinct speakers that will be enrolled.
_MAX_SPEAKERS = 2

# Cosine similarity at or above which a segment is treated as an already
# enrolled speaker. 0.25 mirrors the default decision threshold of
# SpeechBrain's SpeakerRecognition.verify_batch; tune it for your data.
_SAME_SPEAKER_THRESHOLD = 0.25


def _cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D vectors (0.0 if either is all-zero)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)


def _assign_speaker(emb, known_embeddings, known_labels):
    """Return the speaker label for `emb`, enrolling a new speaker if needed.

    `known_embeddings` and `known_labels` are parallel lists that are appended
    to in place when a new speaker is enrolled. A new speaker is enrolled only
    while fewer than `_MAX_SPEAKERS` are known AND `emb` does not resemble any
    of them; otherwise the most similar known speaker is returned.
    """
    sims = [_cosine_similarity(emb, ref) for ref in known_embeddings]
    best = int(np.argmax(sims)) if sims else -1

    is_unfamiliar = best < 0 or sims[best] < _SAME_SPEAKER_THRESHOLD
    if is_unfamiliar and len(known_embeddings) < _MAX_SPEAKERS:
        label = f"SPEAKER_{len(known_embeddings) + 1}"
        known_embeddings.append(emb)
        known_labels.append(label)
        return label

    return known_labels[best]


@app.post("/transcribe")
async def transcribe(audio: UploadFile = File(...), lang: str = Form("en")):
    """Transcribe an uploaded audio file and tag each segment with a speaker.

    Args:
        audio: The uploaded audio file (multipart form field ``audio``).
        lang: Language code passed to the ASR model (form field ``lang``,
            default ``"en"``).

    Returns:
        ``{"segments": [...]}`` where each item has ``speaker``
        (``"SPEAKER_<n>"``), ``start`` and ``end`` (seconds, rounded to two
        decimals) and ``text``. Segments shorter than
        ``_MIN_SEGMENT_SECONDS`` are omitted.
    """
    # NOTE(review): decoding and inference below are synchronous calls inside
    # an async handler, so they run on the event loop while a request is
    # being processed.

    # delete=False so the file can be closed before it is re-opened by path
    # for decoding; it is removed explicitly in the `finally` below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as buffer:
        shutil.copyfileobj(audio.file, buffer)
        audio_path = buffer.name

    try:
        audio_data = whisperx.load_audio(audio_path, sr=_SAMPLE_RATE)
    finally:
        # The decoded waveform is all that is needed from here on.
        os.remove(audio_path)

    result = asr_model.transcribe(audio_data, language=lang)

    min_samples = int(_MIN_SEGMENT_SECONDS * _SAMPLE_RATE)
    speaker_embeddings = []
    speaker_labels = []
    final_segments = []

    for seg in result["segments"]:
        start = int(seg["start"] * _SAMPLE_RATE)
        end = int(seg["end"] * _SAMPLE_RATE)

        # Reuse the waveform that was already decoded for ASR instead of
        # re-reading the file, so the encoder gets the same mono signal at
        # _SAMPLE_RATE regardless of the upload's channel count or rate.
        chunk = audio_data[start:end]
        if len(chunk) < min_samples:
            continue

        # Shape (1, samples): a batch containing one waveform.
        chunk_tensor = torch.as_tensor(chunk, dtype=torch.float32).unsqueeze(0)

        with torch.no_grad():
            emb = speaker_model.encode_batch(chunk_tensor)
        emb = emb.squeeze().detach().cpu().numpy()

        speaker_id = _assign_speaker(emb, speaker_embeddings, speaker_labels)

        final_segments.append({
            "speaker": speaker_id,
            "start": round(seg["start"], 2),
            "end": round(seg["end"], 2),
            "text": seg["text"],
        })

    return {"segments": final_segments}


if __name__ == "__main__":
    # Started directly as a script: serve the app on every network interface.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)