# NOTE: the lines below were scraped from a Hugging Face Spaces page
# (status banner: "Spaces: Sleeping"); this header is not part of the app.
import os
import shutil
import tempfile
from typing import Dict

import torch
import torchaudio
import torchaudio.transforms as T
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from pyannote.core import Segment
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
# =========================================================
# 1. LOAD MODELS & CONFIG
# =========================================================
# All models are loaded once at import time and shared as module-level
# state by the request handlers below.
print("Loading models...")

# Use the GPU when available; every model below follows DEVICE.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Sample rate the pipeline normalizes audio to before processing.
TARGET_SAMPLE_RATE = 16000
# Whisper checkpoint size — larger is more accurate but slower.
TRANSCRIPTION_MODEL_SIZE = "medium"

# Diarization (pyannote speaker-diarization 3.1)
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
diarization_pipeline.to(torch.device(DEVICE))

# Transcription (faster-whisper)
# Models should be preloaded in Docker build to ~/.cache/huggingface
transcription_model = WhisperModel(
    TRANSCRIPTION_MODEL_SIZE,
    device=DEVICE,
    compute_type="int8"  # int8 quantization to keep the memory footprint low
)

# Summarization (DistilBART checkpoint; name suggests CNN-trained — the
# 1024-token truncation in process_audio_file matches BART's encoder limit)
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6").to(DEVICE)

# Sentiment analysis (DistilBERT fine-tuned on SST-2, per the model id)
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0 if DEVICE == "cuda" else -1  # transformers wants an int device id
)

print("Models loaded successfully.")
| # ========================================================= | |
| # 2. AUDIO PROCESSING | |
| # ========================================================= | |
def process_audio_file(file_path: str) -> Dict:
    """Run the full pipeline on one audio file.

    Steps: transcription (faster-whisper), speaker diarization (pyannote),
    per-segment sentiment, and an overall DistilBART summary.

    Args:
        file_path: Path to an audio file readable by torchaudio.

    Returns:
        Dict with:
          - "segments": list of {start, end, speaker, text, sentiment}
          - "summary": str ("" when no speech was transcribed)
    """
    segments_with_sentiment = []
    full_text = []

    waveform, sample_rate = torchaudio.load(file_path)

    temp_wav = None  # set only when we create a resampled copy
    if sample_rate != TARGET_SAMPLE_RATE:
        # Resample on CPU. torchaudio.load returns a CPU tensor, so moving
        # only the resampler to CUDA (as the original did) crashes with a
        # device mismatch on GPU hosts; CPU resampling is cheap and keeps
        # the waveform device-consistent for diarization below.
        waveform = T.Resample(sample_rate, TARGET_SAMPLE_RATE)(waveform)
        sample_rate = TARGET_SAMPLE_RATE
        # Unique temp file per call — a fixed name like temp_input.wav would
        # collide (and get deleted underneath us) on concurrent requests.
        fd, temp_wav = tempfile.mkstemp(suffix=".wav", dir="/tmp")
        os.close(fd)  # torchaudio.save reopens by path
        torchaudio.save(temp_wav, waveform.cpu(), sample_rate)
        file_processing_path = temp_wav
    else:
        file_processing_path = file_path

    try:
        # 1. Transcribe the whole file at once so Whisper keeps context.
        print("Starting transcription...")
        segments, _ = transcription_model.transcribe(file_processing_path, beam_size=5)
        segments = list(segments)  # materialize the generator

        # 2. Diarization. Pass the pre-loaded waveform to avoid pyannote's
        # internal audio loading (missing torchcodec).
        print("Starting diarization...")
        diarization = diarization_pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )
        # Compatibility with pyannote.audio 4.x, which wraps the annotation.
        if hasattr(diarization, "speaker_diarization"):
            diarization = diarization.speaker_diarization

        # 3. Align each transcript segment with its dominant speaker.
        print("Aligning...")
        for seg in segments:
            speaker = _dominant_speaker(diarization, seg.start, seg.end)
            text = seg.text.strip()
            # Empty text would make the sentiment pipeline error out.
            sentiment = sentiment_pipeline(text)[0] if text else {"label": "NEUTRAL"}
            segments_with_sentiment.append({
                "start": round(seg.start, 2),
                "end": round(seg.end, 2),
                "speaker": speaker,
                "text": text,
                "sentiment": sentiment["label"],
            })
            if text:
                full_text.append(f"{speaker}: {text}")

        return {
            "segments": segments_with_sentiment,
            "summary": _summarize("\n".join(full_text)),
        }
    finally:
        # Remove only the temp file *we* created — never the caller's input.
        # (The original keyed this on startswith("/tmp"), which could also
        # match a caller-owned path.)
        if temp_wav is not None and os.path.exists(temp_wav):
            os.remove(temp_wav)


def _dominant_speaker(diarization, start: float, end: float) -> str:
    """Return the label with the most overlap inside [start, end], or "UNKNOWN"."""
    overlap = {}
    # crop() intersects the annotation with the transcript segment; sum the
    # overlapping duration per speaker label.
    for segment, _, label in diarization.crop(Segment(start, end)).itertracks(yield_label=True):
        overlap[label] = overlap.get(label, 0.0) + segment.duration
    return max(overlap, key=overlap.get) if overlap else "UNKNOWN"


def _summarize(text: str) -> str:
    """Summarize the speaker-tagged transcript; returns "" for empty input."""
    if not text:
        return ""
    # BART's encoder is capped at 1024 tokens — truncate rather than fail.
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
    summary_ids = summarization_model.generate(
        inputs["input_ids"],
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# =========================================================
# 3. FASTAPI APP
# =========================================================
# Single application instance; the models loaded above are shared
# module-level state across all requests.
app = FastAPI()
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the bundled UI, or a minimal hint page when index.html is absent.

    NOTE(review): the handler had no route decorator in the original, so it
    was never registered with FastAPI; @app.get("/") restores the route
    (possibly lost in a copy/paste — confirm against the deployed version).
    """
    if not os.path.exists("index.html"):
        return HTMLResponse(content="<h2>Upload audio via /transcribe endpoint</h2>", status_code=200)
    # Wrap the file contents in HTMLResponse so the client receives
    # text/html rather than a JSON-encoded string.
    with open("index.html") as f:
        return HTMLResponse(content=f.read())
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Accept an uploaded audio file and return its segments and summary.

    NOTE(review): the handler had no route decorator in the original, so the
    endpoint was never registered; @app.post("/transcribe") restores it.

    Raises:
        HTTPException(500): with the underlying error message on any failure.
    """
    # basename() strips client-supplied directory components: a filename
    # like "a/../../etc/x" must not let the upload escape /tmp.
    safe_name = os.path.basename(file.filename or "upload")
    temp_path = os.path.join("/tmp", f"temp_{safe_name}")
    try:
        with open(temp_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        return process_audio_file(temp_path)
    except HTTPException:
        raise  # pass through deliberate HTTP errors unwrapped
    except Exception as e:
        print("ERROR:", e)
        # Surface the failure reason to the client as a clean 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)