# Audio analysis service: diarization + transcription + sentiment + summarization.
# (Hugging Face Spaces page residue — commit hashes and line-number gutter — removed.)
import os
import torch
import torchaudio
import torchaudio.transforms as T
from pyannote.audio import Pipeline
from faster_whisper import WhisperModel
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
import shutil
from typing import Dict
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
# =========================================================
# 1. LOAD MODELS & CONFIG
# =========================================================
print("Loading models...")

# Run all models on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Sample rate the pipeline resamples all input audio to (see process_audio_file).
TARGET_SAMPLE_RATE = 16000
# faster-whisper checkpoint size name ("tiny" ... "large-v3").
TRANSCRIPTION_MODEL_SIZE = "medium"

# Diarization: pyannote speaker-diarization pipeline, moved onto DEVICE.
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
diarization_pipeline.to(torch.device(DEVICE))

# Transcription
# Models should be preloaded in Docker build to ~/.cache/huggingface
transcription_model = WhisperModel(
    TRANSCRIPTION_MODEL_SIZE,
    device=DEVICE,
    compute_type="int8"  # int8 quantization keeps memory footprint low
)

# Summarization: DistilBART seq2seq model + matching tokenizer.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6").to(DEVICE)

# Sentiment analysis (SST-2 fine-tuned DistilBERT; labels POSITIVE/NEGATIVE).
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    # HF pipeline device convention: 0 = first GPU, -1 = CPU.
    device=0 if DEVICE == "cuda" else -1
)
print("Models loaded successfully.")
# =========================================================
# 2. AUDIO PROCESSING
# =========================================================
def process_audio_file(file_path: str) -> Dict:
    """Transcribe, diarize, sentiment-tag and summarize one audio file.

    Args:
        file_path: Path to an audio file readable by torchaudio.

    Returns:
        Dict with:
          "segments": list of {"start", "end", "speaker", "text", "sentiment"}
                      entries, one per transcription segment.
          "summary":  abstractive summary of the full transcript
                      ("" when there was no speech).
    """
    # Hoisted out of the alignment loop (was imported once per segment).
    from pyannote.core import Segment

    segments_with_sentiment = []
    full_text = []

    # torchaudio.load returns a CPU tensor (channels, samples) plus its rate.
    waveform, sample_rate = torchaudio.load(file_path)

    # Track whether WE created a temp file, so cleanup never deletes the
    # caller's input (the /transcribe endpoint also stores uploads in /tmp,
    # and the old startswith("/tmp") check would remove them here).
    created_temp = False
    if sample_rate != TARGET_SAMPLE_RATE:
        # BUGFIX: the resampler was previously moved to DEVICE while the
        # waveform stayed on CPU, which raises a device-mismatch error on
        # CUDA hosts. Resample on the tensor's own (CPU) device instead.
        resampler = T.Resample(sample_rate, TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)
        sample_rate = TARGET_SAMPLE_RATE
        # Persist the resampled audio so faster-whisper can read it from disk.
        temp_wav = os.path.join("/tmp", "temp_input.wav")
        torchaudio.save(temp_wav, waveform.cpu(), sample_rate)
        file_processing_path = temp_wav
        created_temp = True
    else:
        file_processing_path = file_path

    try:
        # 1. Transcribe the whole file at once (preserves linguistic context).
        print("Starting transcription...")
        segments, _ = transcription_model.transcribe(file_processing_path, beam_size=5)
        # Materialize the generator so it can be iterated safely below.
        segments = list(segments)

        # 2. Diarization. Pass the pre-loaded waveform to avoid pyannote's
        # internal loading path (which requires torchcodec).
        print("Starting diarization...")
        diarization_input = {"waveform": waveform, "sample_rate": sample_rate}
        diarization = diarization_pipeline(diarization_input)
        # pyannote.audio 4.x wraps the annotation in a result object.
        if hasattr(diarization, "speaker_diarization"):
            diarization = diarization.speaker_diarization

        # 3. Assign each transcript segment the speaker with maximum overlap.
        print("Aligning...")
        for seg in segments:
            seg_start = seg.start
            seg_end = seg.end

            # Sum overlap duration per speaker label within this segment.
            overlap_counts = {}
            sub_ann = diarization.crop(Segment(seg_start, seg_end))
            for segment, _, label in sub_ann.itertracks(yield_label=True):
                overlap_counts[label] = overlap_counts.get(label, 0.0) + segment.duration

            # Speaker with most overlap; "UNKNOWN" when diarization found none.
            if overlap_counts:
                speaker = max(overlap_counts, key=overlap_counts.get)
            else:
                speaker = "UNKNOWN"

            text = seg.text.strip()
            # Sentiment per segment; empty text gets a neutral placeholder
            # (the HF pipeline is not called on empty strings).
            sentiment = sentiment_pipeline(text)[0] if text else {"label": "NEUTRAL"}
            segments_with_sentiment.append({
                "start": round(seg_start, 2),
                "end": round(seg_end, 2),
                "speaker": speaker,
                "text": text,
                "sentiment": sentiment["label"]
            })
            if text:
                full_text.append(f"{speaker}: {text}")

        # 4. Summarization. BART-family models accept at most 1024 tokens,
        # so the tokenizer truncates the combined transcript.
        summary_result = ""
        combined_text = "\n".join(full_text)
        if combined_text:
            inputs = tokenizer(combined_text, return_tensors="pt", max_length=1024, truncation=True).to(DEVICE)
            summary_ids = summarization_model.generate(
                inputs["input_ids"],
                max_length=150,
                min_length=30,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True
            )
            summary_result = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return {
            "segments": segments_with_sentiment,
            "summary": summary_result
        }
    finally:
        # Remove only the resampled temp file this function created.
        if created_temp and os.path.exists(file_processing_path):
            os.remove(file_processing_path)
# =========================================================
# 3. FASTAPI APP
# =========================================================
app = FastAPI()


@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve index.html if present, else a minimal fallback page.

    Returns:
        The contents of index.html, or an HTMLResponse pointing users at
        the /transcribe endpoint when no UI file is deployed.
    """
    # EAFP: open directly instead of exists()+open — avoids the race where
    # the file disappears between the check and the read, and pins encoding.
    try:
        with open("index.html", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return HTMLResponse(content="<h2>Upload audio via /transcribe endpoint</h2>", status_code=200)
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Accept an uploaded audio file and return transcript segments + summary.

    Args:
        file: Multipart upload; any format torchaudio/faster-whisper can read.

    Returns:
        The dict produced by process_audio_file ({"segments": [...], "summary": ...}).

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    # SECURITY: the client controls file.filename — strip any directory
    # components so a crafted name like "../../etc/x" cannot escape /tmp.
    safe_name = os.path.basename(file.filename or "upload")
    temp_path = os.path.join("/tmp", f"temp_{safe_name}")
    try:
        # Stream the upload to disk (the models read from a file path).
        with open(temp_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        return process_audio_file(temp_path)
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as opaque 500s.
        raise
    except Exception as e:
        print("ERROR:", e)
        # Surface a clean error message to the client.
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)