# transcription/api.py
# Author: AI Assistant
# Optimize deployment: preload models and disable experimental transfer
# commit 5e6994d
import os
import shutil
import tempfile
import uuid
from typing import Dict

import torch
import torchaudio
import torchaudio.transforms as T
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import HTMLResponse
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# =========================================================
# 1. LOAD MODELS & CONFIG
# =========================================================
print("Loading models...")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TARGET_SAMPLE_RATE = 16000
TRANSCRIPTION_MODEL_SIZE = "medium"
# Single source of truth for the summarization checkpoint (used for both
# the tokenizer and the model below).
SUMMARIZATION_MODEL_NAME = "sshleifer/distilbart-cnn-12-6"
# Diarization (pyannote 3.1 pipeline; weights expected in the HF cache)
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
diarization_pipeline.to(torch.device(DEVICE))
# Transcription
# Models should be preloaded in Docker build to ~/.cache/huggingface;
# int8 keeps the memory footprint low on both CPU and GPU.
transcription_model = WhisperModel(
    TRANSCRIPTION_MODEL_SIZE,
    device=DEVICE,
    compute_type="int8",
)
# Summarization
tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL_NAME)
summarization_model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZATION_MODEL_NAME).to(DEVICE)
# Sentiment analysis (HF pipeline takes a device index: 0 = first GPU, -1 = CPU)
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0 if DEVICE == "cuda" else -1,
)
print("Models loaded successfully.")
# =========================================================
# 2. AUDIO PROCESSING
# =========================================================
def process_audio_file(file_path: str) -> Dict:
    """Transcribe, diarize, sentiment-tag and summarize a single audio file.

    Returns a dict with:
      - "segments": list of {start, end, speaker, text, sentiment}
      - "summary": abstractive summary of the speaker-attributed transcript
    """
    from pyannote.core import Segment  # hoisted out of the alignment loop

    segments_with_sentiment = []
    full_text = []
    # Load audio (torchaudio.load returns a CPU tensor).
    waveform, sample_rate = torchaudio.load(file_path)
    temp_wav = None  # set only if we create a resampled copy below
    if sample_rate != TARGET_SAMPLE_RATE:
        # Resample on CPU: the loaded waveform lives on the CPU, so the
        # transform must stay there too. (Moving only the resampler to CUDA,
        # as before, raised a device-mismatch error when DEVICE == "cuda".)
        resampler = T.Resample(sample_rate, TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)
        sample_rate = TARGET_SAMPLE_RATE
        # Unique temp file per call so concurrent requests don't clobber
        # each other (the previous fixed name /tmp/temp_input.wav did).
        fd, temp_wav = tempfile.mkstemp(suffix=".wav", dir="/tmp")
        os.close(fd)
        torchaudio.save(temp_wav, waveform.cpu(), sample_rate)
        file_processing_path = temp_wav
    else:
        file_processing_path = file_path
    try:
        # 1. Transcribe the full file at once (context preserving).
        print("Starting transcription...")
        segments, _ = transcription_model.transcribe(file_processing_path, beam_size=5)
        # Materialize the generator so it can be iterated below.
        segments = list(segments)
        # 2. Diarization. Pass the pre-loaded waveform to avoid pyannote's
        # internal loading path (which can require torchcodec).
        print("Starting diarization...")
        diarization_input = {"waveform": waveform, "sample_rate": sample_rate}
        diarization = diarization_pipeline(diarization_input)
        # Compatibility with pyannote.audio 4.x, where the pipeline returns a
        # result object wrapping the annotation.
        if hasattr(diarization, "speaker_diarization"):
            diarization = diarization.speaker_diarization
        # 3. Align transcription with diarization: for each transcribed
        # segment, attribute it to the speaker with the largest overlap.
        print("Aligning...")
        for seg in segments:
            overlap_counts = {}
            # crop() returns the annotation restricted to this time window;
            # sum overlap duration per speaker label.
            sub_ann = diarization.crop(Segment(seg.start, seg.end))
            for turn, _, label in sub_ann.itertracks(yield_label=True):
                overlap_counts[label] = overlap_counts.get(label, 0.0) + turn.duration
            # Speaker with max overlap, or "UNKNOWN" if diarization found none.
            speaker = max(overlap_counts, key=overlap_counts.get) if overlap_counts else "UNKNOWN"
            text = seg.text.strip()
            # Per-segment sentiment; skip the model call for empty text.
            sentiment = sentiment_pipeline(text)[0] if text else {"label": "NEUTRAL"}
            segments_with_sentiment.append({
                "start": round(seg.start, 2),
                "end": round(seg.end, 2),
                "speaker": speaker,
                "text": text,
                "sentiment": sentiment["label"],
            })
            if text:
                full_text.append(f"{speaker}: {text}")
        # 4. Summarization of the speaker-attributed transcript.
        summary_result = ""
        combined_text = "\n".join(full_text)
        if combined_text:
            # distilBART accepts at most 1024 tokens; truncate in the tokenizer
            # to avoid a length error on long transcripts.
            inputs = tokenizer(combined_text, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
            summary_ids = summarization_model.generate(
                inputs["input_ids"],
                max_length=150,
                min_length=30,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
            )
            summary_result = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        return {
            "segments": segments_with_sentiment,
            "summary": summary_result,
        }
    finally:
        # Remove only the temp file this function created; the caller owns
        # file_path. (The old startswith("/tmp") check also deleted the
        # caller's uploaded file.)
        if temp_wav is not None and os.path.exists(temp_wav):
            os.remove(temp_wav)
# =========================================================
# 3. FASTAPI APP
# =========================================================
app = FastAPI()  # ASGI application instance; serve with e.g. uvicorn
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the upload page, or a fallback message if index.html is absent."""
    if not os.path.exists("index.html"):
        return HTMLResponse(content="<h2>Upload audio via /transcribe endpoint</h2>", status_code=200)
    # Explicit encoding avoids locale-dependent decoding of the HTML file.
    with open("index.html", encoding="utf-8") as f:
        return f.read()
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Accept an uploaded audio file, return diarized transcript + summary.

    Raises HTTPException(500) with the underlying error message on failure.
    """
    # Sanitize the client-supplied filename: basename strips any path
    # components (e.g. "../../etc/passwd") that would otherwise let the
    # upload escape /tmp, and the uuid prefix prevents collisions between
    # concurrent uploads.
    safe_name = os.path.basename(file.filename or "upload")
    temp_path = os.path.join("/tmp", f"temp_{uuid.uuid4().hex}_{safe_name}")
    try:
        with open(temp_path, "wb") as out:
            shutil.copyfileobj(file.file, out)
        return process_audio_file(temp_path)
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors raised downstream.
        raise
    except Exception as e:
        print("ERROR:", e)
        # Clean error message
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)