"""
Speechlib REST API - HuggingFace Spaces (ECAPA-TDNN 버전)
ν™”μž 뢄리 + ν™”μž 식별 + STT
"""
import os
import tempfile
import json
import numpy as np
import shutil
from typing import List, Dict, Optional
from contextlib import asynccontextmanager
# Environment setup
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
import torch
# PyTorch compatibility patch (branch by version)
if hasattr(torch.serialization, 'add_safe_globals'):
torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
from pyannote.audio.core import task as pyannote_task
from pyannote.audio.core.io import Audio
torch.serialization.add_safe_globals([
pyannote_task.Specifications,
pyannote_task.Problem,
pyannote_task.Resolution,
Audio
])
# Patch torch.load to default to weights_only=False
original_load = torch.load
def patched_load(*args, **kwargs):
if 'weights_only' not in kwargs:
kwargs['weights_only'] = False
return original_load(*args, **kwargs)
torch.load = patched_load
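# Note: forcing weights_only=False restores the pre-PyTorch-2.6 default;
# pyannote checkpoints contain pickled Python objects that the safe loader
# rejects. Only do this when the checkpoint source is trusted.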
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
import uvicorn
import torchaudio
from pydub import AudioSegment
class SpeakerPipelineECAPA:
"""
ECAPA-TDNN μž„λ² λ”©μ„ μ‚¬μš©ν•œ ν™”μž 식별 νŒŒμ΄ν”„λΌμΈ
"""
def __init__(
self,
hf_token: str,
whisper_model: str = "large-v3-turbo",
similarity_threshold: float = 0.25,
device: str = None
):
self.hf_token = hf_token
self.whisper_model_size = whisper_model
self.similarity_threshold = similarity_threshold
        # Use GPU if available, otherwise CPU
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.registered_speakers: Dict[str, np.ndarray] = {}
# λͺ¨λΈλ“€ (lazy loading)
self._diarization_pipeline = None
self._ecapa_model = None
self._whisper_model = None
print(f"[SpeakerPipeline ECAPA-TDNN] μ΄ˆκΈ°ν™”")
print(f" - Device: {self.device}")
print(f" - μž„κ³„κ°’: {similarity_threshold}")
@property
def diarization_pipeline(self):
if self._diarization_pipeline is None:
print("[λ‘œλ”©] pyannote/speaker-diarization-3.1...")
from pyannote.audio import Pipeline
self._diarization_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=self.hf_token
)
self._diarization_pipeline.to(torch.device(self.device))
return self._diarization_pipeline
@property
def ecapa_model(self):
if self._ecapa_model is None:
print("[λ‘œλ”©] speechbrain ECAPA-TDNN...")
from speechbrain.inference.speaker import EncoderClassifier
self._ecapa_model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir="pretrained_models/spkrec-ecapa-voxceleb",
run_opts={"device": self.device}
)
return self._ecapa_model
@property
def whisper_model(self):
if self._whisper_model is None:
print(f"[λ‘œλ”©] faster-whisper {self.whisper_model_size}...")
from faster_whisper import WhisperModel
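            # float16 on GPU for speed/VRAM; int8 quantization keeps CPU inference tractable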
compute_type = "float16" if self.device == "cuda" else "int8"
self._whisper_model = WhisperModel(
self.whisper_model_size,
device=self.device,
compute_type=compute_type
)
return self._whisper_model
def _load_audio(self, audio_path: str) -> tuple:
"""μ˜€λ””μ˜€ λ‘œλ“œ 및 μ „μ²˜λ¦¬"""
ext = os.path.splitext(audio_path)[1].lower()
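        # Compressed/container formats are decoded via pydub (requires ffmpeg)
        # and converted to mono 16 kHz WAV before loading with torchaudio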
if ext in ['.m4a', '.mp4', '.aac', '.ogg', '.flac', '.mp3']:
audio = AudioSegment.from_file(audio_path)
audio = audio.set_channels(1).set_frame_rate(16000)
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
tmp_path = tmp.name
audio.export(tmp_path, format='wav')
waveform, sample_rate = torchaudio.load(tmp_path)
os.unlink(tmp_path)
else:
waveform, sample_rate = torchaudio.load(audio_path)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
sample_rate = 16000
return waveform, sample_rate
def get_embedding_ecapa(self, waveform: torch.Tensor) -> np.ndarray:
"""ECAPA-TDNN으둜 μž„λ² λ”© μΆ”μΆœ"""
if waveform.dim() == 2:
waveform = waveform.squeeze(0)
waveform = waveform.to(self.device)
with torch.no_grad():
embedding = self.ecapa_model.encode_batch(waveform.unsqueeze(0))
return embedding.squeeze().cpu().numpy()
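    # The callers below L2-normalize these embeddings, so a plain dot product
    # equals cosine similarity. Minimal sketch (hypothetical arrays emb_a, emb_b):
    #   a = emb_a / np.linalg.norm(emb_a)
    #   b = emb_b / np.linalg.norm(emb_b)
    #   cos_sim = float(np.dot(a, b))  # in [-1, 1], compared against similarity_threshold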
def register_speaker(self, name: str, audio_paths: List[str]) -> None:
"""ν™”μž 등둝"""
print(f"\n[ν™”μž 등둝] {name} ({len(audio_paths)}개 μƒ˜ν”Œ)")
embeddings = []
for path in audio_paths:
if not os.path.exists(path):
continue
try:
waveform, sr = self._load_audio(path)
emb = self.get_embedding_ecapa(waveform)
emb = emb / np.linalg.norm(emb)
embeddings.append(emb)
print(f" βœ“ {os.path.basename(path)}")
except Exception as e:
print(f" βœ— μ—λŸ¬({os.path.basename(path)}): {e}")
if not embeddings:
print(f" [κ²½κ³ ] μœ νš¨ν•œ μƒ˜ν”Œμ΄ μ—†μŠ΅λ‹ˆλ‹€!")
return
avg_embedding = np.mean(embeddings, axis=0)
avg_embedding = avg_embedding / np.linalg.norm(avg_embedding)
self.registered_speakers[name] = avg_embedding
print(f"[ν™”μž 등둝] {name} μ™„λ£Œ!")
def process(self, audio_path: str, language: str = "ko") -> List[Dict]:
"""메인 처리 ν•¨μˆ˜"""
print(f"\n[처리 μ‹œμž‘] {os.path.basename(audio_path)}")
waveform, sample_rate = self._load_audio(audio_path)
audio_dict = {"waveform": waveform, "sample_rate": sample_rate}
        # 1. Speaker diarization
        print("[1/3] Running speaker diarization...")
raw_diarization = self.diarization_pipeline(audio_dict)
        diarization = None
        if hasattr(raw_diarization, "itertracks"):
            diarization = raw_diarization
        else:
            # Some pyannote versions wrap the annotation in a result object;
            # scan its public attributes for one exposing itertracks()
            for attr in dir(raw_diarization):
                if attr.startswith("_"):
                    continue
                try:
                    val = getattr(raw_diarization, attr)
                    if hasattr(val, "itertracks"):
                        diarization = val
                        break
                except Exception:
                    pass
        if diarization is None:
            raise RuntimeError("Could not parse the diarization result.")
segments = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
segments.append({
"start": turn.start,
"end": turn.end,
"diarization_speaker": speaker
})
print(f" β†’ {len(segments)}개 μ„Έκ·Έλ¨ΌνŠΈ 감지")
        # 2. Speaker identification (ECAPA-TDNN)
        if self.registered_speakers:
            print("[2/3] Identifying speakers (ECAPA-TDNN)...")
speaker_embeddings = {}
speakers_found = set(seg["diarization_speaker"] for seg in segments)
for spk in speakers_found:
spk_embs = []
for seg in segments:
if seg["diarization_speaker"] != spk:
continue
duration = seg["end"] - seg["start"]
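                    # Turns shorter than 0.5 s rarely yield stable embeddings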
if duration < 0.5:
continue
try:
start_sample = int(seg["start"] * sample_rate)
end_sample = int(seg["end"] * sample_rate)
end_sample = min(end_sample, waveform.shape[1])
seg_waveform = waveform[:, start_sample:end_sample]
if seg_waveform.shape[1] < sample_rate * 0.3:
continue
emb = self.get_embedding_ecapa(seg_waveform)
emb = emb / np.linalg.norm(emb)
spk_embs.append(emb)
                except Exception:
                    pass
if spk_embs:
speaker_embeddings[spk] = spk_embs
            # Map diarization labels to registered speakers
speaker_mapping = {}
speaker_scores = {}
for spk, embs in speaker_embeddings.items():
avg_emb = np.mean(embs, axis=0)
avg_emb = avg_emb / np.linalg.norm(avg_emb)
speaker_scores[spk] = {}
for name, ref_emb in self.registered_speakers.items():
sim = np.dot(avg_emb, ref_emb)
speaker_scores[spk][name] = sim
            # Competitive matching: each registered name claims the
            # best-scoring diarization speaker that is still unassigned
            for reg_name in self.registered_speakers.keys():
                best_spk = None
                best_sim = -1
                for spk in speaker_scores.keys():
                    if spk in speaker_mapping:
                        continue
sim = speaker_scores[spk].get(reg_name, -1)
if sim > best_sim:
best_sim = sim
best_spk = spk
if best_spk and best_sim >= self.similarity_threshold:
speaker_mapping[best_spk] = (reg_name, best_sim)
for spk in speaker_scores.keys():
if spk not in speaker_mapping:
speaker_mapping[spk] = (spk, 0.0)
for seg in segments:
d_spk = seg["diarization_speaker"]
if d_spk in speaker_mapping:
seg["speaker"], seg["similarity"] = speaker_mapping[d_spk]
else:
seg["speaker"] = d_spk
seg["similarity"] = 0.0
else:
for seg in segments:
seg["speaker"] = seg["diarization_speaker"]
seg["similarity"] = 0.0
        # 3. STT
        print("[3/3] Running speech recognition (STT)...")
whisper_segs, _ = self.whisper_model.transcribe(
audio_path, language=language, beam_size=5, vad_filter=True
)
whisper_results = [{"start": s.start, "end": s.end, "text": s.text.strip()} for s in whisper_segs]
        # 4. Merge: assign each Whisper segment the diarization speaker with maximal time overlap
final_results = []
for w_seg in whisper_results:
best_speaker = "Unknown"
best_overlap = 0
best_sim = 0.0
for d_seg in segments:
overlap = max(0, min(w_seg["end"], d_seg["end"]) - max(w_seg["start"], d_seg["start"]))
if overlap > best_overlap:
best_overlap = overlap
best_speaker = d_seg["speaker"]
best_sim = d_seg.get("similarity", 0.0)
final_results.append({
"start": w_seg["start"],
"end": w_seg["end"],
"text": w_seg["text"],
"speaker": best_speaker,
"similarity": round(best_sim * 100, 1)
})
return final_results
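# Each item returned by SpeakerPipelineECAPA.process() has this shape
# (values are illustrative only):
#   {"start": 0.0, "end": 2.5, "text": "...", "speaker": "alice", "similarity": 87.3}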
# Global pipeline instance
_pipeline: Optional[SpeakerPipelineECAPA] = None
def get_pipeline(hf_token: str) -> SpeakerPipelineECAPA:
"""νŒŒμ΄ν”„λΌμΈ 싱글톀 μΈμŠ€ν„΄μŠ€ λ°˜ν™˜"""
global _pipeline
if _pipeline is None:
_pipeline = SpeakerPipelineECAPA(hf_token=hf_token)
return _pipeline
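# Note: the singleton is constructed with the hf_token of the first request;
# tokens supplied on later requests are ignored until the process restarts.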
# FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    # On startup
    print("πŸš€ Speechlib API server starting")
    yield
    # On shutdown
    print("πŸ‘‹ Speechlib API server shutting down")
app = FastAPI(
title="Speechlib API",
    description="Speaker diarization + speaker identification + STT REST API (ECAPA-TDNN)",
version="1.0.0",
lifespan=lifespan
)
@app.get("/")
async def root():
"""API μƒνƒœ 확인"""
return {
"status": "ok",
"message": "Speechlib API (ECAPA-TDNN)",
"endpoints": {
"/transcribe": "POST - λ‹¨μˆœ STT + ν™”μž 뢄리",
"/process": "POST - 전체 κΈ°λŠ₯ (ν™”μž 식별 포함)"
}
}
@app.get("/health")
async def health_check():
"""ν—¬μŠ€ 체크"""
return {"status": "healthy", "device": "cuda" if torch.cuda.is_available() else "cpu"}
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(..., description="Audio file"),
    language: str = Form(default="ko", description="Language code (ko, en, ja, zh)"),
    hf_token: str = Form(..., description="HuggingFace token")
):
"""
λ‹¨μˆœ STT + ν™”μž 뢄리 (ν™”μž 식별 μ—†μŒ)
"""
temp_dir = None
try:
        # Save upload to a temp file
temp_dir = tempfile.mkdtemp()
audio_path = os.path.join(temp_dir, audio.filename)
with open(audio_path, "wb") as f:
content = await audio.read()
f.write(content)
# νŒŒμ΄ν”„λΌμΈ μ‹€ν–‰
pipeline = get_pipeline(hf_token)
pipeline.registered_speakers.clear() # ν™”μž 식별 μ—†μŒ
results = pipeline.process(audio_path, language=language)
        # Format results
segments = []
speaker_stats = {}
for r in results:
segments.append({
"start": round(r["start"], 2),
"end": round(r["end"], 2),
"text": r["text"],
"speaker": r["speaker"]
})
speaker = r["speaker"]
if speaker not in speaker_stats:
speaker_stats[speaker] = {"count": 0, "duration": 0}
speaker_stats[speaker]["count"] += 1
speaker_stats[speaker]["duration"] += r["end"] - r["start"]
for speaker in speaker_stats:
speaker_stats[speaker]["duration"] = round(speaker_stats[speaker]["duration"], 2)
return JSONResponse(content={
"success": True,
"segments": segments,
"speaker_stats": speaker_stats,
"total_segments": len(segments)
})
except Exception as e:
import traceback
return JSONResponse(
status_code=500,
content={
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}
)
finally:
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
@app.post("/process")
async def process_audio(
    audio: UploadFile = File(..., description="Audio file to analyze"),
    voice_sample: UploadFile = File(default=None, description="Speaker sample file (optional)"),
    speaker_name: str = Form(default="speaker", description="Name of the speaker to identify"),
    language: str = Form(default="ko", description="Language code (ko, en, ja, zh)"),
    hf_token: str = Form(..., description="HuggingFace token")
):
"""
전체 κΈ°λŠ₯: ν™”μž 뢄리 + ν™”μž 식별 + STT
"""
temp_dir = None
try:
        # Create a temp directory
temp_dir = tempfile.mkdtemp()
        # Save the main audio file
audio_path = os.path.join(temp_dir, audio.filename)
with open(audio_path, "wb") as f:
content = await audio.read()
f.write(content)
# νŒŒμ΄ν”„λΌμΈ κ°€μ Έμ˜€κΈ°
pipeline = get_pipeline(hf_token)
pipeline.registered_speakers.clear()
        # Register the speaker if a voice sample was provided
if voice_sample and voice_sample.filename:
sample_path = os.path.join(temp_dir, voice_sample.filename)
with open(sample_path, "wb") as f:
sample_content = await voice_sample.read()
f.write(sample_content)
pipeline.register_speaker(speaker_name, [sample_path])
        # Process
results = pipeline.process(audio_path, language=language)
        # Format results
segments = []
speaker_stats = {}
for r in results:
segments.append({
"start": round(r["start"], 2),
"end": round(r["end"], 2),
"text": r["text"],
"speaker": r["speaker"],
"similarity": r["similarity"]
})
speaker = r["speaker"]
if speaker not in speaker_stats:
speaker_stats[speaker] = {"count": 0, "duration": 0}
speaker_stats[speaker]["count"] += 1
speaker_stats[speaker]["duration"] += r["end"] - r["start"]
for speaker in speaker_stats:
speaker_stats[speaker]["duration"] = round(speaker_stats[speaker]["duration"], 2)
return JSONResponse(content={
"success": True,
"segments": segments,
"speaker_stats": speaker_stats,
"total_segments": len(segments)
})
except Exception as e:
import traceback
return JSONResponse(
status_code=500,
content={
"success": False,
"error": str(e),
"traceback": traceback.format_exc()
}
)
finally:
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)