File size: 6,761 Bytes
b40215d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e6994d
b40215d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50108e8
b40215d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50108e8
 
b40215d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50108e8
 
b40215d
 
 
 
 
 
50108e8
b40215d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import torch
import torchaudio
import torchaudio.transforms as T
from pyannote.audio import Pipeline
from faster_whisper import WhisperModel
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse
import shutil
from typing import Dict
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer

# =========================================================
# 1. LOAD MODELS & CONFIG
# =========================================================

print("Loading models...")

# Run everything on GPU when available; "cpu" otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Sample rate expected by both Whisper and pyannote front-ends.
TARGET_SAMPLE_RATE = 16000
# faster-whisper checkpoint size (tiny/base/small/medium/large-v3).
TRANSCRIPTION_MODEL_SIZE = "medium"

# Diarization
# NOTE(review): pyannote gated models normally need an HF auth token;
# presumably the token/cache is baked into the image — confirm.
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
diarization_pipeline.to(torch.device(DEVICE))

# Transcription
# Models should be preloaded in Docker build to ~/.cache/huggingface
# int8 quantization trades a little accuracy for much lower memory use.
transcription_model = WhisperModel(
    TRANSCRIPTION_MODEL_SIZE,
    device=DEVICE,
    compute_type="int8"
)

# Summarization (abstractive, DistilBART fine-tuned on CNN/DailyMail).
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
summarization_model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6").to(DEVICE)

# Sentiment analysis
# HF pipelines take a device *index*: 0 = first GPU, -1 = CPU.
sentiment_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0 if DEVICE == "cuda" else -1
)

print("Models loaded successfully.")

# =========================================================
# 2. AUDIO PROCESSING
# =========================================================


def _dominant_speaker(diarization, window) -> str:
    """Return the diarization label with the largest overlap inside *window*.

    *window* is a pyannote.core.Segment; returns "UNKNOWN" when no speaker
    turn intersects it.
    """
    overlap = {}
    # crop() yields only the parts of each speaker turn inside the window,
    # so summing turn durations per label gives per-speaker overlap time.
    for turn, _, label in diarization.crop(window).itertracks(yield_label=True):
        overlap[label] = overlap.get(label, 0.0) + turn.duration
    return max(overlap, key=overlap.get) if overlap else "UNKNOWN"


def process_audio_file(file_path: str) -> Dict:
    """Run the full pipeline on one audio file.

    Steps: transcribe (faster-whisper) -> diarize (pyannote) -> align the two
    -> per-segment sentiment -> abstractive summary of the tagged transcript.

    Parameters
    ----------
    file_path : str
        Path to an audio file readable by torchaudio. The file is NOT
        deleted by this function (only the internally-created resampled
        copy is cleaned up).

    Returns
    -------
    Dict with keys:
        "segments": list of {"start", "end", "speaker", "text", "sentiment"}
        "summary":  summary string ("" when there was no transcribed text).
    """
    # Hoisted out of the per-segment loop (was re-imported every iteration).
    from pyannote.core import Segment

    segments_with_sentiment = []
    full_text = []

    # torchaudio.load returns a CPU tensor: (channels, samples).
    waveform, sample_rate = torchaudio.load(file_path)

    temp_wav = None  # resampled copy WE created — the only file we may delete
    if sample_rate != TARGET_SAMPLE_RATE:
        # BUGFIX: the resampler must live on the same device as the waveform.
        # It was previously moved to DEVICE while the waveform stayed on CPU,
        # which raises a device-mismatch error on CUDA hosts.
        resampler = T.Resample(sample_rate, TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)
        sample_rate = TARGET_SAMPLE_RATE
        # Persist the resampled audio for faster-whisper, which reads a path.
        temp_wav = os.path.join("/tmp", "temp_input.wav")
        torchaudio.save(temp_wav, waveform.cpu(), sample_rate)
        file_processing_path = temp_wav
    else:
        file_processing_path = file_path

    try:
        # 1. Transcribe the whole file at once (preserves context across segments).
        print("Starting transcription...")
        segments, _ = transcription_model.transcribe(file_processing_path, beam_size=5)
        segments = list(segments)  # materialize the generator

        # 2. Diarization. Pass the pre-loaded waveform so pyannote skips its
        # internal audio loading (which can require torchcodec).
        print("Starting diarization...")
        diarization = diarization_pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )
        # pyannote.audio 4.x wraps the annotation in a result object.
        if hasattr(diarization, "speaker_diarization"):
            diarization = diarization.speaker_diarization

        # 3. Align transcription with diarization and score sentiment.
        print("Aligning...")
        for seg in segments:
            speaker = _dominant_speaker(diarization, Segment(seg.start, seg.end))
            text = seg.text.strip()

            # Empty segments get a neutral default instead of calling the model.
            sentiment = sentiment_pipeline(text)[0] if text else {"label": "NEUTRAL"}

            segments_with_sentiment.append({
                "start": round(seg.start, 2),
                "end": round(seg.end, 2),
                "speaker": speaker,
                "text": text,
                "sentiment": sentiment["label"]
            })

            if text:
                full_text.append(f"{speaker}: {text}")

        # --- Summarization ---
        summary_result = ""
        combined_text = "\n".join(full_text)
        if combined_text:
            # DistilBART caps input at 1024 tokens; truncate in the tokenizer
            # rather than erroring on long transcripts.
            inputs = tokenizer(
                combined_text, return_tensors="pt", max_length=1024, truncation=True
            ).to(DEVICE)
            summary_ids = summarization_model.generate(
                inputs["input_ids"],
                max_length=150,
                min_length=30,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True
            )
            summary_result = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return {
            "segments": segments_with_sentiment,
            "summary": summary_result
        }

    finally:
        # BUGFIX: only delete the resampled copy we created. The old check
        # (path startswith "/tmp") also deleted the CALLER's input file,
        # since the upload endpoint stores files under /tmp too.
        if temp_wav is not None and os.path.exists(temp_wav):
            os.remove(temp_wav)


# =========================================================
# 3. FASTAPI APP
# =========================================================

# Single ASGI application instance; routes are registered via decorators below.
app = FastAPI()

@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the static upload page, or a minimal fallback if it is absent."""
    if not os.path.exists("index.html"):
        return HTMLResponse(content="<h2>Upload audio via /transcribe endpoint</h2>", status_code=200)
    # FIX: read with an explicit encoding — the platform default may not be
    # UTF-8, which would garble non-ASCII characters in the page.
    with open("index.html", encoding="utf-8") as f:
        return f.read()

@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Accept an uploaded audio file, run the pipeline, and return its JSON result.

    Raises HTTPException(500) with the underlying error message on failure.
    The temporary copy of the upload is always removed.
    """
    # SECURITY FIX: the client-supplied filename is attacker-controlled and was
    # joined into the path unsanitized (e.g. "../../etc/x" escapes /tmp).
    # basename() strips any path components; fall back when no name was sent.
    safe_name = os.path.basename(file.filename or "upload")
    temp_path = os.path.join("/tmp", f"temp_{safe_name}")
    try:
        with open(temp_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
        return process_audio_file(temp_path)
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as generic 500s.
        raise
    except Exception as e:
        print("ERROR:", e)
        # Surface a clean error message to the client.
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)