""" Speechlib REST API - HuggingFace Spaces (ECAPA-TDNN 버전) 화자 분리 + 화자 식별 + STT """ import os import tempfile import json import numpy as np import shutil from typing import List, Dict, Optional from contextlib import asynccontextmanager # 환경 설정 os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1" os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1" import torch # PyTorch 호환성 패치 (버전에 따라 분기) if hasattr(torch.serialization, 'add_safe_globals'): torch.serialization.add_safe_globals([torch.torch_version.TorchVersion]) from pyannote.audio.core import task as pyannote_task from pyannote.audio.core.io import Audio torch.serialization.add_safe_globals([ pyannote_task.Specifications, pyannote_task.Problem, pyannote_task.Resolution, Audio ]) # weights_only=False 패치 original_load = torch.load def patched_load(*args, **kwargs): if 'weights_only' not in kwargs: kwargs['weights_only'] = False return original_load(*args, **kwargs) torch.load = patched_load from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi.responses import JSONResponse import uvicorn import torchaudio from pydub import AudioSegment class SpeakerPipelineECAPA: """ ECAPA-TDNN 임베딩을 사용한 화자 식별 파이프라인 """ def __init__( self, hf_token: str, whisper_model: str = "large-v3-turbo", similarity_threshold: float = 0.25, device: str = None ): self.hf_token = hf_token self.whisper_model_size = whisper_model self.similarity_threshold = similarity_threshold # GPU 사용 가능하면 GPU, 아니면 CPU if device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: self.device = device self.registered_speakers: Dict[str, np.ndarray] = {} # 모델들 (lazy loading) self._diarization_pipeline = None self._ecapa_model = None self._whisper_model = None print(f"[SpeakerPipeline ECAPA-TDNN] 초기화") print(f" - Device: {self.device}") print(f" - 임계값: {similarity_threshold}") @property def diarization_pipeline(self): if self._diarization_pipeline is None: print("[로딩] pyannote/speaker-diarization-3.1...") from pyannote.audio import Pipeline self._diarization_pipeline = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", use_auth_token=self.hf_token ) self._diarization_pipeline.to(torch.device(self.device)) return self._diarization_pipeline @property def ecapa_model(self): if self._ecapa_model is None: print("[로딩] speechbrain ECAPA-TDNN...") from speechbrain.inference.speaker import EncoderClassifier self._ecapa_model = EncoderClassifier.from_hparams( source="speechbrain/spkrec-ecapa-voxceleb", savedir="pretrained_models/spkrec-ecapa-voxceleb", run_opts={"device": self.device} ) return self._ecapa_model @property def whisper_model(self): if self._whisper_model is None: print(f"[로딩] faster-whisper {self.whisper_model_size}...") from faster_whisper import WhisperModel compute_type = "float16" if self.device == "cuda" else "int8" self._whisper_model = WhisperModel( self.whisper_model_size, device=self.device, compute_type=compute_type ) return self._whisper_model def _load_audio(self, audio_path: str) -> tuple: """오디오 로드 및 전처리""" ext = os.path.splitext(audio_path)[1].lower() if ext in ['.m4a', '.mp4', '.aac', '.ogg', '.flac', '.mp3']: audio = AudioSegment.from_file(audio_path) audio = audio.set_channels(1).set_frame_rate(16000) with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp: tmp_path = tmp.name audio.export(tmp_path, format='wav') waveform, sample_rate = torchaudio.load(tmp_path) os.unlink(tmp_path) else: waveform, sample_rate = torchaudio.load(audio_path) if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, 


class SpeakerPipelineECAPA:
    """Speaker identification pipeline built on ECAPA-TDNN embeddings."""

    def __init__(
        self,
        hf_token: str,
        whisper_model: str = "large-v3-turbo",
        similarity_threshold: float = 0.25,
        device: str = None
    ):
        self.hf_token = hf_token
        self.whisper_model_size = whisper_model
        self.similarity_threshold = similarity_threshold

        # Use the GPU when available, otherwise fall back to CPU
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.registered_speakers: Dict[str, np.ndarray] = {}

        # Models (lazy loading)
        self._diarization_pipeline = None
        self._ecapa_model = None
        self._whisper_model = None

        print("[SpeakerPipeline ECAPA-TDNN] initialized")
        print(f" - Device: {self.device}")
        print(f" - Threshold: {similarity_threshold}")

    @property
    def diarization_pipeline(self):
        if self._diarization_pipeline is None:
            print("[loading] pyannote/speaker-diarization-3.1...")
            from pyannote.audio import Pipeline
            self._diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=self.hf_token
            )
            self._diarization_pipeline.to(torch.device(self.device))
        return self._diarization_pipeline

    @property
    def ecapa_model(self):
        if self._ecapa_model is None:
            print("[loading] speechbrain ECAPA-TDNN...")
            from speechbrain.inference.speaker import EncoderClassifier
            self._ecapa_model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="pretrained_models/spkrec-ecapa-voxceleb",
                run_opts={"device": self.device}
            )
        return self._ecapa_model

    @property
    def whisper_model(self):
        if self._whisper_model is None:
            print(f"[loading] faster-whisper {self.whisper_model_size}...")
            from faster_whisper import WhisperModel
            compute_type = "float16" if self.device == "cuda" else "int8"
            self._whisper_model = WhisperModel(
                self.whisper_model_size,
                device=self.device,
                compute_type=compute_type
            )
        return self._whisper_model

    def _load_audio(self, audio_path: str) -> tuple:
        """Load audio and preprocess it to mono 16 kHz."""
        ext = os.path.splitext(audio_path)[1].lower()
        if ext in ['.m4a', '.mp4', '.aac', '.ogg', '.flac', '.mp3']:
            # Decode compressed formats via pydub (ffmpeg), then reload as WAV
            audio = AudioSegment.from_file(audio_path)
            audio = audio.set_channels(1).set_frame_rate(16000)
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
                tmp_path = tmp.name
            audio.export(tmp_path, format='wav')
            waveform, sample_rate = torchaudio.load(tmp_path)
            os.unlink(tmp_path)
        else:
            waveform, sample_rate = torchaudio.load(audio_path)

        # Downmix to mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to 16 kHz
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        return waveform, sample_rate

    def get_embedding_ecapa(self, waveform: torch.Tensor) -> np.ndarray:
        """Extract a speaker embedding with ECAPA-TDNN."""
        if waveform.dim() == 2:
            waveform = waveform.squeeze(0)
        waveform = waveform.to(self.device)
        with torch.no_grad():
            embedding = self.ecapa_model.encode_batch(waveform.unsqueeze(0))
        return embedding.squeeze().cpu().numpy()
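
    # Illustrative use of the embedding helper (file name is hypothetical):
    #
    #   pipe = SpeakerPipelineECAPA(hf_token="hf_...")
    #   wav, sr = pipe._load_audio("sample.wav")
    #   emb = pipe.get_embedding_ecapa(wav)
    #   # emb is a 1-D numpy vector (192-dim for spkrec-ecapa-voxceleb)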

    def register_speaker(self, name: str, audio_paths: List[str]) -> None:
        """Register a speaker from one or more reference samples."""
        print(f"\n[register speaker] {name} ({len(audio_paths)} samples)")
        embeddings = []
        for path in audio_paths:
            if not os.path.exists(path):
                continue
            try:
                waveform, sr = self._load_audio(path)
                emb = self.get_embedding_ecapa(waveform)
                emb = emb / np.linalg.norm(emb)
                embeddings.append(emb)
                print(f" ✓ {os.path.basename(path)}")
            except Exception as e:
                print(f" ✗ error ({os.path.basename(path)}): {e}")

        if not embeddings:
            print(" [warning] no valid samples!")
            return

        # Average the normalized embeddings, then re-normalize to unit length
        avg_embedding = np.mean(embeddings, axis=0)
        avg_embedding = avg_embedding / np.linalg.norm(avg_embedding)
        self.registered_speakers[name] = avg_embedding
        print(f"[register speaker] {name} done!")

    def process(self, audio_path: str, language: str = "ko") -> List[Dict]:
        """Main entry point: diarize, identify, transcribe, then merge."""
        print(f"\n[processing] {os.path.basename(audio_path)}")
        waveform, sample_rate = self._load_audio(audio_path)
        audio_dict = {"waveform": waveform, "sample_rate": sample_rate}

        # 1. Speaker diarization
        print("[1/3] diarization...")
        raw_diarization = self.diarization_pipeline(audio_dict)
        diarization = None
        if hasattr(raw_diarization, "itertracks"):
            diarization = raw_diarization
        else:
            # Some pipeline versions wrap the annotation in a result object;
            # search its public attributes for one that exposes itertracks
            for attr in dir(raw_diarization):
                if attr.startswith("_"):
                    continue
                try:
                    val = getattr(raw_diarization, attr)
                    if hasattr(val, "itertracks"):
                        diarization = val
                        break
                except Exception:
                    pass
        if diarization is None:
            raise RuntimeError("Could not parse the diarization result.")

        segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segments.append({
                "start": turn.start,
                "end": turn.end,
                "diarization_speaker": speaker
            })
        print(f" → {len(segments)} segments detected")

        # 2. Speaker identification (ECAPA-TDNN)
        if self.registered_speakers:
            print("[2/3] speaker identification (ECAPA-TDNN)...")
            speaker_embeddings = {}
            speakers_found = set(seg["diarization_speaker"] for seg in segments)
            for spk in speakers_found:
                spk_embs = []
                for seg in segments:
                    if seg["diarization_speaker"] != spk:
                        continue
                    # Skip segments too short to embed reliably
                    duration = seg["end"] - seg["start"]
                    if duration < 0.5:
                        continue
                    try:
                        start_sample = int(seg["start"] * sample_rate)
                        end_sample = int(seg["end"] * sample_rate)
                        end_sample = min(end_sample, waveform.shape[1])
                        seg_waveform = waveform[:, start_sample:end_sample]
                        if seg_waveform.shape[1] < sample_rate * 0.3:
                            continue
                        emb = self.get_embedding_ecapa(seg_waveform)
                        emb = emb / np.linalg.norm(emb)
                        spk_embs.append(emb)
                    except Exception:
                        pass
                if spk_embs:
                    speaker_embeddings[spk] = spk_embs

            # Score every diarization cluster against every registered speaker
            speaker_mapping = {}
            speaker_scores = {}
            for spk, embs in speaker_embeddings.items():
                avg_emb = np.mean(embs, axis=0)
                avg_emb = avg_emb / np.linalg.norm(avg_emb)
                speaker_scores[spk] = {}
                for name, ref_emb in self.registered_speakers.items():
                    sim = np.dot(avg_emb, ref_emb)
                    speaker_scores[spk][name] = sim

            # Competitive matching: each registered speaker claims the
            # best-scoring cluster that has not been claimed yet
            for reg_name in self.registered_speakers.keys():
                best_spk = None
                best_sim = -1
                for spk in speaker_scores.keys():
                    if spk in speaker_mapping:
                        continue
                    sim = speaker_scores[spk].get(reg_name, -1)
                    if sim > best_sim:
                        best_sim = sim
                        best_spk = spk
                if best_spk and best_sim >= self.similarity_threshold:
                    speaker_mapping[best_spk] = (reg_name, best_sim)

            # Unclaimed clusters keep their diarization label
            for spk in speaker_scores.keys():
                if spk not in speaker_mapping:
                    speaker_mapping[spk] = (spk, 0.0)

            for seg in segments:
                d_spk = seg["diarization_speaker"]
                if d_spk in speaker_mapping:
                    seg["speaker"], seg["similarity"] = speaker_mapping[d_spk]
                else:
                    seg["speaker"] = d_spk
                    seg["similarity"] = 0.0
        else:
            for seg in segments:
                seg["speaker"] = seg["diarization_speaker"]
                seg["similarity"] = 0.0

        # 3. STT
        print("[3/3] speech recognition (STT)...")
        whisper_segs, _ = self.whisper_model.transcribe(
            audio_path,
            language=language,
            beam_size=5,
            vad_filter=True
        )
        whisper_results = [
            {"start": s.start, "end": s.end, "text": s.text.strip()}
            for s in whisper_segs
        ]

        # 4. Merge: give each Whisper segment the speaker whose diarization
        # turn overlaps it the most
        final_results = []
        for w_seg in whisper_results:
            best_speaker = "Unknown"
            best_overlap = 0
            best_sim = 0.0
            for d_seg in segments:
                overlap = max(0, min(w_seg["end"], d_seg["end"]) - max(w_seg["start"], d_seg["start"]))
                if overlap > best_overlap:
                    best_overlap = overlap
                    best_speaker = d_seg["speaker"]
                    best_sim = d_seg.get("similarity", 0.0)
            final_results.append({
                "start": w_seg["start"],
                "end": w_seg["end"],
                "text": w_seg["text"],
                "speaker": best_speaker,
                "similarity": round(best_sim * 100, 1)
            })
        return final_results
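

# A small, self-contained illustration (not used by the app) of the overlap
# rule in process() step 4. The segment boundaries here are hypothetical.
def _overlap_example() -> float:
    w = {"start": 1.0, "end": 4.0}   # toy Whisper segment
    d = {"start": 2.5, "end": 6.0}   # toy diarization turn
    # Same formula as above: intersection length, clamped at zero
    return max(0, min(w["end"], d["end"]) - max(w["start"], d["start"]))  # 1.5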


# Global pipeline instance
_pipeline: Optional[SpeakerPipelineECAPA] = None

def get_pipeline(hf_token: str) -> SpeakerPipelineECAPA:
    """Return the singleton pipeline instance."""
    global _pipeline
    if _pipeline is None:
        _pipeline = SpeakerPipelineECAPA(hf_token=hf_token)
    return _pipeline

# FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    print("🚀 Speechlib API server starting")
    yield
    # Shutdown
    print("👋 Speechlib API server shutting down")

app = FastAPI(
    title="Speechlib API",
    description="Speaker diarization + speaker identification + STT REST API (ECAPA-TDNN)",
    version="1.0.0",
    lifespan=lifespan
)

@app.get("/")
async def root():
    """API status check."""
    return {
        "status": "ok",
        "message": "Speechlib API (ECAPA-TDNN)",
        "endpoints": {
            "/transcribe": "POST - plain STT + diarization",
            "/process": "POST - full pipeline (incl. speaker identification)"
        }
    }

@app.get("/health")
async def health_check():
    """Health check."""
    return {"status": "healthy", "device": "cuda" if torch.cuda.is_available() else "cpu"}

@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(..., description="audio file"),
    language: str = Form(default="ko", description="language code (ko, en, ja, zh)"),
    hf_token: str = Form(..., description="HuggingFace token")
):
    """Plain STT + speaker diarization (no speaker identification)."""
    temp_dir = None
    try:
        # Save the upload to a temporary file
        temp_dir = tempfile.mkdtemp()
        audio_path = os.path.join(temp_dir, audio.filename)
        with open(audio_path, "wb") as f:
            content = await audio.read()
            f.write(content)

        # Run the pipeline
        pipeline = get_pipeline(hf_token)
        pipeline.registered_speakers.clear()  # no speaker identification
        results = pipeline.process(audio_path, language=language)

        # Format the results
        segments = []
        speaker_stats = {}
        for r in results:
            segments.append({
                "start": round(r["start"], 2),
                "end": round(r["end"], 2),
                "text": r["text"],
                "speaker": r["speaker"]
            })
            speaker = r["speaker"]
            if speaker not in speaker_stats:
                speaker_stats[speaker] = {"count": 0, "duration": 0}
            speaker_stats[speaker]["count"] += 1
            speaker_stats[speaker]["duration"] += r["end"] - r["start"]
        for speaker in speaker_stats:
            speaker_stats[speaker]["duration"] = round(speaker_stats[speaker]["duration"], 2)

        return JSONResponse(content={
            "success": True,
            "segments": segments,
            "speaker_stats": speaker_stats,
            "total_segments": len(segments)
        })
    except Exception as e:
        import traceback
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc()
            }
        )
    finally:
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
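
# Illustrative client call for /transcribe (host, file name, and token are
# hypothetical; requires the `requests` package):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/transcribe",
#       files={"audio": open("meeting.m4a", "rb")},
#       data={"language": "ko", "hf_token": "hf_..."},
#   )
#   print(resp.json()["speaker_stats"])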

@app.post("/process")
async def process_audio(
    audio: UploadFile = File(..., description="audio file to analyze"),
    voice_sample: UploadFile = File(default=None, description="speaker sample file (optional)"),
    speaker_name: str = Form(default="speaker", description="name of the speaker to identify"),
    language: str = Form(default="ko", description="language code (ko, en, ja, zh)"),
    hf_token: str = Form(..., description="HuggingFace token")
):
    """Full pipeline: speaker diarization + speaker identification + STT."""
    temp_dir = None
    try:
        # Create a temporary directory
        temp_dir = tempfile.mkdtemp()

        # Save the main audio
        audio_path = os.path.join(temp_dir, audio.filename)
        with open(audio_path, "wb") as f:
            content = await audio.read()
            f.write(content)

        # Get the pipeline
        pipeline = get_pipeline(hf_token)
        pipeline.registered_speakers.clear()

        # Register the voice sample, if one was provided
        if voice_sample and voice_sample.filename:
            sample_path = os.path.join(temp_dir, voice_sample.filename)
            with open(sample_path, "wb") as f:
                sample_content = await voice_sample.read()
                f.write(sample_content)
            pipeline.register_speaker(speaker_name, [sample_path])

        # Process
        results = pipeline.process(audio_path, language=language)

        # Format the results
        segments = []
        speaker_stats = {}
        for r in results:
            segments.append({
                "start": round(r["start"], 2),
                "end": round(r["end"], 2),
                "text": r["text"],
                "speaker": r["speaker"],
                "similarity": r["similarity"]
            })
            speaker = r["speaker"]
            if speaker not in speaker_stats:
                speaker_stats[speaker] = {"count": 0, "duration": 0}
            speaker_stats[speaker]["count"] += 1
            speaker_stats[speaker]["duration"] += r["end"] - r["start"]
        for speaker in speaker_stats:
            speaker_stats[speaker]["duration"] = round(speaker_stats[speaker]["duration"], 2)

        return JSONResponse(content={
            "success": True,
            "segments": segments,
            "speaker_stats": speaker_stats,
            "total_segments": len(segments)
        })
    except Exception as e:
        import traceback
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc()
            }
        )
    finally:
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
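
# Illustrative client call for /process with a reference voice sample (host,
# file names, and token are hypothetical; requires the `requests` package):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/process",
#       files={
#           "audio": open("meeting.m4a", "rb"),
#           "voice_sample": open("alice_sample.wav", "rb"),
#       },
#       data={"speaker_name": "alice", "language": "ko", "hf_token": "hf_..."},
#   )
#   print(resp.json()["segments"])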