"""Audio intelligence service: diarization, transcription, sentiment, summary.

Exposes a FastAPI app (section 3) that accepts an uploaded audio file and
returns speaker-labelled, sentiment-tagged transcript segments plus an
abstractive summary. All models are loaded once at import time (section 1).
"""
import os
import torch
import torchaudio
import torchaudio.transforms as T
from pyannote.audio import Pipeline
from faster_whisper import WhisperModel
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
import shutil
from typing import Dict
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer

# =========================================================
# 1. LOAD MODELS & CONFIG
# =========================================================
print("Loading models...")

# Single device choice shared by every model below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# pyannote's diarization models expect 16 kHz input — resample to this rate.
TARGET_SAMPLE_RATE = 16000
TRANSCRIPTION_MODEL_SIZE = "medium"

# Diarization: pyannote speaker-diarization pipeline, moved to DEVICE.
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
diarization_pipeline.to(torch.device(DEVICE))

# Transcription: faster-whisper with int8 quantization to cut memory use.
# Models should be preloaded in Docker build to ~/.cache/huggingface.
transcription_model = WhisperModel(
    TRANSCRIPTION_MODEL_SIZE, device=DEVICE, compute_type="int8"
)

# Summarization: distilled BART fine-tuned on CNN/DailyMail.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6").to(DEVICE)

# Sentiment analysis: binary (POSITIVE/NEGATIVE) SST-2 classifier.
# transformers pipelines take a CUDA index (0) or -1 for CPU.
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0 if DEVICE == "cuda" else -1
)

print("Models loaded successfully.")
# =========================================================
# 2. AUDIO PROCESSING
# =========================================================
def _dominant_speaker(diarization, window) -> str:
    """Return the diarization label with the most overlap inside *window*.

    Args:
        diarization: pyannote Annotation of speaker turns.
        window: pyannote ``Segment`` covering one transcription segment.

    Returns:
        The label whose turns overlap *window* the longest, or ``"UNKNOWN"``
        when no turn overlaps at all.
    """
    overlap = {}
    # crop() intersects the annotation with the window; sum duration per label.
    for turn, _, label in diarization.crop(window).itertracks(yield_label=True):
        overlap[label] = overlap.get(label, 0.0) + turn.duration
    return max(overlap, key=overlap.get) if overlap else "UNKNOWN"


def _summarize(text: str) -> str:
    """Abstractive summary of *text* via distilbart; empty input -> ""."""
    if not text:
        return ""
    # BART accepts at most 1024 tokens — truncate inside the tokenizer.
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
    summary_ids = summarization_model.generate(
        inputs["input_ids"],
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def process_audio_file(file_path: str) -> Dict:
    """Transcribe, diarize, sentiment-tag and summarize one audio file.

    Args:
        file_path: Path to an audio file readable by torchaudio.

    Returns:
        Dict with:
          "segments": list of {start, end, speaker, text, sentiment} dicts,
          "summary":  summary of the speaker-labelled transcript
                      (empty string when nothing was transcribed).
    """
    import tempfile  # stdlib; local import keeps the top-of-file imports untouched
    from pyannote.core import Segment  # hoisted out of the per-segment loop

    segments_with_sentiment = []
    full_text = []

    # Load audio as a CPU tensor; also reused as pyannote's input below.
    waveform, sample_rate = torchaudio.load(file_path)

    if sample_rate != TARGET_SAMPLE_RATE:
        # BUG FIX: the resampler must live on the same device as the waveform.
        # torchaudio.load returns a CPU tensor, so the previous
        # ``.to(DEVICE)`` on the resampler crashed with a device-mismatch
        # error on CUDA hosts. Resampling on CPU is correct on every host.
        resampler = T.Resample(sample_rate, TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)
        sample_rate = TARGET_SAMPLE_RATE
        # BUG FIX: a fixed "/tmp/temp_input.wav" raced between concurrent
        # requests — use a unique scratch file instead.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        torchaudio.save(tmp.name, waveform.cpu(), sample_rate)
        file_processing_path = tmp.name
    else:
        file_processing_path = file_path

    try:
        # 1. Transcribe the full file in one pass (preserves context).
        print("Starting transcription...")
        segments, _ = transcription_model.transcribe(file_processing_path, beam_size=5)
        # Materialize the generator so it can be iterated (and counted) freely.
        segments = list(segments)

        # 2. Diarization. Pass the pre-loaded waveform to avoid pyannote's
        # internal loading path (missing torchcodec in some environments).
        print("Starting diarization...")
        diarization_input = {"waveform": waveform, "sample_rate": sample_rate}
        diarization = diarization_pipeline(diarization_input)
        # pyannote.audio 4.x wraps the Annotation in a result object.
        if hasattr(diarization, "speaker_diarization"):
            diarization = diarization.speaker_diarization

        # 3. Align transcription segments with diarization turns.
        print("Aligning...")
        for seg in segments:
            speaker = _dominant_speaker(diarization, Segment(seg.start, seg.end))
            text = seg.text.strip()
            # truncation=True guards against segments longer than the
            # classifier's 512-token input limit.
            sentiment = sentiment_pipeline(text, truncation=True)[0] if text else {"label": "NEUTRAL"}
            segments_with_sentiment.append({
                "start": round(seg.start, 2),
                "end": round(seg.end, 2),
                "speaker": speaker,
                "text": text,
                "sentiment": sentiment["label"],
            })
            if text:
                full_text.append(f"{speaker}: {text}")

        return {
            "segments": segments_with_sentiment,
            "summary": _summarize("\n".join(full_text)),
        }
    finally:
        # BUG FIX: only delete the scratch file WE created. The old
        # ``startswith("/tmp")`` check also deleted caller-owned inputs
        # that happened to live under /tmp.
        if file_processing_path != file_path and os.path.exists(file_processing_path):
            os.remove(file_processing_path)
# =========================================================
# 3. FASTAPI APP
# =========================================================
app = FastAPI()


@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve ./index.html when present, otherwise a minimal usage hint."""
    if not os.path.exists("index.html"):
        # NOTE(review): the original fallback markup was garbled in this copy
        # of the file; reconstructed as a minimal valid page carrying the
        # same visible text — confirm against the intended index.html.
        return HTMLResponse(
            content="<html><body><p>Upload audio via /transcribe endpoint</p></body></html>",
            status_code=200,
        )
    with open("index.html", encoding="utf-8") as f:
        return f.read()


@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Accept an uploaded audio file; return diarized segments + summary.

    Raises:
        HTTPException(500): any processing failure, with the error message
        as the detail string.
    """
    import tempfile  # stdlib; local import keeps the top-of-file imports untouched

    # SECURITY FIX: the client-controlled filename was joined straight into
    # /tmp ("temp_{file.filename}"), allowing path traversal via "../" and
    # collisions between concurrent uploads (filename may also be None).
    # Keep only the extension for the decoder and let tempfile pick a
    # unique, safe path.
    safe_name = os.path.basename(file.filename or "upload")
    suffix = os.path.splitext(safe_name)[1]
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    temp_path = tmp.name
    try:
        with tmp:
            shutil.copyfileobj(file.file, tmp)
        return process_audio_file(temp_path)
    except HTTPException:
        raise  # don't re-wrap deliberate HTTP errors in a 500
    except Exception as e:
        print("ERROR:", e)
        # Clean error message
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)