Spaces:
Sleeping
Sleeping
"""
Speechlib REST API - HuggingFace Spaces (ECAPA-TDNN version)
Speaker diarization + speaker identification + STT
"""
| import os | |
| import tempfile | |
| import json | |
| import numpy as np | |
| import shutil | |
| from typing import List, Dict, Optional | |
| from contextlib import asynccontextmanager | |
# Environment setup: turn off Hugging Face Hub symlink usage (and its warning).
# NOTE(review): presumably set here, before torch/pyannote are imported, so
# that huggingface_hub sees them when it is imported transitively — confirm.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
import torch
# PyTorch compatibility patch (branches on the installed torch version).
# `add_safe_globals` does not exist on older torch releases, hence the hasattr
# guard. When it exists, TorchVersion and the pyannote classes below are
# registered as allowed globals for torch's restricted (weights_only) unpickler.
if hasattr(torch.serialization, 'add_safe_globals'):
    torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
    from pyannote.audio.core import task as pyannote_task
    from pyannote.audio.core.io import Audio
    torch.serialization.add_safe_globals([
        pyannote_task.Specifications,
        pyannote_task.Problem,
        pyannote_task.Resolution,
        Audio
    ])
# weights_only=False patch: replace torch.load process-wide so that every call
# that does not pass `weights_only` explicitly runs with weights_only=False;
# callers that do pass it keep their own value.
# SECURITY(review): weights_only=False performs full unpickling, which can
# execute arbitrary code embedded in a checkpoint. Because this patch affects
# every torch.load in the process (presumably including checkpoints loaded by
# pyannote / speechbrain below), it is only acceptable for trusted model
# sources. For those calls it also makes the allow-list registered above moot.
# NOTE(review): on torch releases without a `weights_only` parameter this
# wrapper would pass an unexpected keyword — verify the minimum torch version.
original_load = torch.load
def patched_load(*args, **kwargs):
    """Call the original torch.load, defaulting `weights_only` to False."""
    if 'weights_only' not in kwargs:
        kwargs['weights_only'] = False
    return original_load(*args, **kwargs)
torch.load = patched_load
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.responses import JSONResponse | |
| import uvicorn | |
| import torchaudio | |
| from pydub import AudioSegment | |
class SpeakerPipelineECAPA:
    """
    Speaker-identification pipeline based on ECAPA-TDNN embeddings.

    ``process`` runs, in order:
      1. pyannote speaker diarization,
      2. (only when speakers have been registered) matching of diarization
         speakers to registered reference voices by cosine similarity of
         L2-normalised ECAPA-TDNN embeddings,
      3. faster-whisper transcription,
      4. assignment of every transcribed segment to the diarization segment
         with the largest time overlap.

    The three models are loaded lazily on first access of the
    ``diarization_pipeline``, ``ecapa_model`` and ``whisper_model``
    properties and then cached on the instance.
    """

    # Diarization segments shorter than this (seconds) are not embedded.
    MIN_SEGMENT_SECONDS = 0.5
    # After clipping to the waveform, a segment needs at least this much audio.
    MIN_EMBED_SECONDS = 0.3

    def __init__(
        self,
        hf_token: str,
        whisper_model: str = "large-v3-turbo",
        similarity_threshold: float = 0.25,
        device: Optional[str] = None
    ):
        """
        Args:
            hf_token: HuggingFace token used when loading the pyannote pipeline.
            whisper_model: faster-whisper model size/name.
            similarity_threshold: minimum cosine similarity for a diarization
                speaker to be renamed to a registered speaker.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.hf_token = hf_token
        self.whisper_model_size = whisper_model
        self.similarity_threshold = similarity_threshold
        # Use the GPU when available, otherwise fall back to the CPU.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        # Registered name -> L2-normalised reference embedding.
        self.registered_speakers: Dict[str, np.ndarray] = {}
        # Models (lazy loading; see the properties below).
        self._diarization_pipeline = None
        self._ecapa_model = None
        self._whisper_model = None
        print(f"[SpeakerPipeline ECAPA-TDNN] μ΄κΈ°ν")
        print(f" - Device: {self.device}")
        print(f" - μκ³κ°: {similarity_threshold}")

    @property
    def diarization_pipeline(self):
        """pyannote speaker-diarization-3.1 pipeline, loaded on first access.

        Raises:
            RuntimeError: if pyannote could not load the pipeline (it returns
                None, e.g. for an invalid token or unaccepted model terms).
        """
        if self._diarization_pipeline is None:
            print("[λ‘λ©] pyannote/speaker-diarization-3.1...")
            from pyannote.audio import Pipeline
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=self.hf_token
            )
            if pipeline is None:
                raise RuntimeError(
                    "Could not load pyannote/speaker-diarization-3.1: check the "
                    "HuggingFace token and that the model's terms were accepted."
                )
            pipeline.to(torch.device(self.device))
            # Cache only once fully initialised, so a failure is retried.
            self._diarization_pipeline = pipeline
        return self._diarization_pipeline

    @property
    def ecapa_model(self):
        """speechbrain ECAPA-TDNN speaker encoder, loaded on first access."""
        if self._ecapa_model is None:
            print("[λ‘λ©] speechbrain ECAPA-TDNN...")
            from speechbrain.inference.speaker import EncoderClassifier
            self._ecapa_model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir="pretrained_models/spkrec-ecapa-voxceleb",
                run_opts={"device": self.device}
            )
        return self._ecapa_model

    @property
    def whisper_model(self):
        """faster-whisper model, loaded on first access.

        Uses float16 compute on CUDA and int8 on CPU.
        """
        if self._whisper_model is None:
            print(f"[λ‘λ©] faster-whisper {self.whisper_model_size}...")
            from faster_whisper import WhisperModel
            compute_type = "float16" if self.device == "cuda" else "int8"
            self._whisper_model = WhisperModel(
                self.whisper_model_size,
                device=self.device,
                compute_type=compute_type
            )
        return self._whisper_model

    @staticmethod
    def _normalize(vec: np.ndarray) -> np.ndarray:
        """Return ``vec`` scaled to unit L2 norm; a zero vector is returned as-is.

        The zero guard avoids producing NaNs, which would poison averages and
        make every similarity comparison false.
        """
        norm = np.linalg.norm(vec)
        if norm == 0:
            return vec
        return vec / norm

    def _load_audio(self, audio_path: str) -> tuple:
        """
        Load an audio file as a mono 16 kHz waveform.

        The compressed formats listed below are decoded with pydub and
        re-encoded to a temporary WAV first; other files go straight to
        torchaudio. Multi-channel audio is averaged to one channel and
        resampled to 16 kHz if needed.

        Returns:
            (waveform, sample_rate): a single-channel tensor of shape
            (1, num_samples) and the integer 16000.
        """
        ext = os.path.splitext(audio_path)[1].lower()
        if ext in ('.m4a', '.mp4', '.aac', '.ogg', '.flac', '.mp3'):
            audio = AudioSegment.from_file(audio_path)
            audio = audio.set_channels(1).set_frame_rate(16000)
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
                tmp_path = tmp.name
            try:
                audio.export(tmp_path, format='wav')
                waveform, sample_rate = torchaudio.load(tmp_path)
            finally:
                # Remove the intermediate WAV even if export/load failed.
                os.unlink(tmp_path)
        else:
            waveform, sample_rate = torchaudio.load(audio_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        return waveform, sample_rate

    def get_embedding_ecapa(self, waveform: torch.Tensor) -> np.ndarray:
        """Extract an ECAPA-TDNN embedding (not normalised) for one waveform.

        A 2-D input is squeezed on dim 0 (expected to be a single channel),
        then passed to the encoder as a batch of one.
        """
        if waveform.dim() == 2:
            waveform = waveform.squeeze(0)
        waveform = waveform.to(self.device)
        with torch.no_grad():
            embedding = self.ecapa_model.encode_batch(waveform.unsqueeze(0))
        return embedding.squeeze().cpu().numpy()

    def register_speaker(self, name: str, audio_paths: List[str]) -> None:
        """
        Register ``name`` with the mean of the normalised embeddings of its
        sample files (the mean is normalised again).

        Missing files are skipped; files that fail to load/embed are reported
        and skipped. If no sample is usable, nothing is registered.
        """
        print(f"\n[νμ λ±λ‘] {name} ({len(audio_paths)}κ° μν)")
        embeddings = []
        for path in audio_paths:
            if not os.path.exists(path):
                continue
            try:
                waveform, _ = self._load_audio(path)
                embeddings.append(self._normalize(self.get_embedding_ecapa(waveform)))
                print(f" β {os.path.basename(path)}")
            except Exception as e:
                print(f" β μλ¬({os.path.basename(path)}): {e}")
        if not embeddings:
            print(f" [κ²½κ³ ] μ ν¨ν μνμ΄ μμ΅λλ€!")
            return
        self.registered_speakers[name] = self._normalize(np.mean(embeddings, axis=0))
        print(f"[νμ λ±λ‘] {name} μλ£!")

    def process(self, audio_path: str, language: str = "ko") -> List[Dict]:
        """
        Run diarization, optional speaker identification and STT on one file.

        Returns:
            One dict per transcribed segment with keys ``start``, ``end``,
            ``text``, ``speaker`` (registered name, diarization label, or
            "Unknown" when no diarization segment overlaps) and ``similarity``
            (cosine similarity * 100, rounded to 1 decimal; 0.0 if unmatched).

        Raises:
            RuntimeError: if the diarization output cannot be interpreted.
        """
        print(f"\n[μ²λ¦¬ μμ] {os.path.basename(audio_path)}")
        waveform, sample_rate = self._load_audio(audio_path)

        # 1. Speaker diarization
        print("[1/3] νμ λΆλ¦¬ μ€...")
        segments = self._diarize(waveform, sample_rate)
        print(f" β {len(segments)}κ° μΈκ·Έλ¨ΌνΈ κ°μ§")

        # 2. Speaker identification (ECAPA-TDNN), only if voices are registered
        speaker_mapping: Dict[str, tuple] = {}
        if self.registered_speakers:
            print("[2/3] νμ μλ³ μ€ (ECAPA-TDNN)...")
            speaker_mapping = self._identify_speakers(segments, waveform, sample_rate)
        for seg in segments:
            d_spk = seg["diarization_speaker"]
            seg["speaker"], seg["similarity"] = speaker_mapping.get(d_spk, (d_spk, 0.0))

        # 3. STT
        print("[3/3] μμ± μΈμ(STT) μ€...")
        whisper_segs, _ = self.whisper_model.transcribe(
            audio_path, language=language, beam_size=5, vad_filter=True
        )
        whisper_results = [
            {"start": s.start, "end": s.end, "text": s.text.strip()}
            for s in whisper_segs
        ]

        # 4. Merge transcription with speaker labels
        return self._assign_speakers(whisper_results, segments)

    def _diarize(self, waveform: torch.Tensor, sample_rate: int) -> List[Dict]:
        """Run the diarization pipeline; return [{start, end, diarization_speaker}]."""
        raw_diarization = self.diarization_pipeline(
            {"waveform": waveform, "sample_rate": sample_rate}
        )
        diarization = self._find_annotation(raw_diarization)
        if diarization is None:
            raise RuntimeError("νμ λΆλ¦¬ κ²°κ³Όλ₯Ό νμ±ν μ μμ΅λλ€.")
        return [
            {"start": turn.start, "end": turn.end, "diarization_speaker": speaker}
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]

    @staticmethod
    def _find_annotation(raw):
        """
        Return the object exposing ``itertracks`` from a diarization result.

        Accepts either the annotation itself or a wrapper whose public
        attribute holds it (checked in ``dir()`` order); returns None if none
        is found.
        """
        if hasattr(raw, "itertracks"):
            return raw
        for attr in dir(raw):
            if attr.startswith("_"):
                continue
            try:
                val = getattr(raw, attr)
                if hasattr(val, "itertracks"):
                    return val
            except Exception:
                # Attributes that fail on access cannot be the annotation.
                continue
        return None

    def _identify_speakers(
        self, segments: List[Dict], waveform: torch.Tensor, sample_rate: int
    ) -> Dict[str, tuple]:
        """
        Map each diarization speaker to (registered name, similarity) or to
        (its own label, 0.0) if it matched no registered speaker.

        Each speaker's embedding is the normalised mean of the normalised
        embeddings of its segments that are long enough to embed.
        """
        speaker_embeddings: Dict[str, List[np.ndarray]] = {}
        num_samples = waveform.shape[1]
        for seg in segments:
            if seg["end"] - seg["start"] < self.MIN_SEGMENT_SECONDS:
                continue
            start_sample = int(seg["start"] * sample_rate)
            end_sample = min(int(seg["end"] * sample_rate), num_samples)
            seg_waveform = waveform[:, start_sample:end_sample]
            if seg_waveform.shape[1] < sample_rate * self.MIN_EMBED_SECONDS:
                continue
            try:
                emb = self.get_embedding_ecapa(seg_waveform)
            except Exception as e:
                # Best effort: one bad segment must not abort identification.
                print(
                    f"  [warn] embedding failed for {seg['diarization_speaker']} "
                    f"({seg['start']:.2f}-{seg['end']:.2f}s): {e}"
                )
                continue
            speaker_embeddings.setdefault(seg["diarization_speaker"], []).append(
                self._normalize(emb)
            )

        speaker_scores: Dict[str, Dict[str, float]] = {}
        for spk, embs in speaker_embeddings.items():
            avg_emb = self._normalize(np.mean(embs, axis=0))
            # float() keeps results JSON-serialisable (np.dot may give np.float32).
            speaker_scores[spk] = {
                name: float(np.dot(avg_emb, ref_emb))
                for name, ref_emb in self.registered_speakers.items()
            }
        return self._match_speakers(
            speaker_scores, list(self.registered_speakers), self.similarity_threshold
        )

    @staticmethod
    def _match_speakers(
        speaker_scores: Dict[str, Dict[str, float]],
        registered_names: List[str],
        threshold: float,
    ) -> Dict[str, tuple]:
        """
        Greedy competitive matching of registered names to diarization speakers.

        Names are processed in the given order. Each picks the highest-scoring
        diarization speaker not already claimed by an earlier name and keeps it
        only if the score is >= ``threshold``. Speakers left unclaimed map to
        (their own label, 0.0).
        """
        mapping: Dict[str, tuple] = {}
        for reg_name in registered_names:
            best_spk = None
            best_sim = -1.0
            for spk, scores in speaker_scores.items():
                if spk in mapping:
                    continue  # already claimed by an earlier registered name
                sim = scores.get(reg_name, -1.0)
                if sim > best_sim:
                    best_sim, best_spk = sim, spk
            if best_spk is not None and best_sim >= threshold:
                mapping[best_spk] = (reg_name, float(best_sim))
        for spk in speaker_scores:
            mapping.setdefault(spk, (spk, 0.0))
        return mapping

    @staticmethod
    def _assign_speakers(whisper_results: List[Dict], segments: List[Dict]) -> List[Dict]:
        """
        Label each transcribed segment with the speaker of the diarization
        segment that overlaps it the most ("Unknown" if none overlaps).
        """
        final_results = []
        for w_seg in whisper_results:
            best_speaker = "Unknown"
            best_overlap = 0
            best_sim = 0.0
            for d_seg in segments:
                overlap = max(
                    0,
                    min(w_seg["end"], d_seg["end"]) - max(w_seg["start"], d_seg["start"]),
                )
                if overlap > best_overlap:
                    best_overlap = overlap
                    best_speaker = d_seg["speaker"]
                    best_sim = d_seg.get("similarity", 0.0)
            final_results.append({
                "start": w_seg["start"],
                "end": w_seg["end"],
                "text": w_seg["text"],
                "speaker": best_speaker,
                # float() so the value survives JSON encoding.
                "similarity": round(float(best_sim) * 100, 1),
            })
        return final_results
# Module-level pipeline instance shared by all requests (lazy singleton).
_pipeline: Optional[SpeakerPipelineECAPA] = None


def get_pipeline(hf_token: str) -> SpeakerPipelineECAPA:
    """
    Return the shared SpeakerPipelineECAPA, creating it on first use.

    The HuggingFace token is refreshed on every call. The pipeline only reads
    it when the diarization model is first loaded. Refreshing it therefore
    stops an early request with an invalid token from leaving that token in
    place and breaking model loading for all later requests. Once the model
    has loaded, the token has no further effect.

    NOTE(review): after the first successful load, any caller-supplied token
    is accepted; the token is not re-validated per request.
    """
    global _pipeline
    if _pipeline is None:
        _pipeline = SpeakerPipelineECAPA(hf_token=hf_token)
    else:
        _pipeline.hf_token = hf_token
    return _pipeline
# FastAPI app


@asynccontextmanager
async def lifespan(app: "FastAPI"):
    """
    Application lifespan hook: print a message on startup and on shutdown.

    FastAPI expects ``lifespan`` to return an async context manager. A bare
    async generator function is only accepted through Starlette's deprecated
    fallback, so the function is wrapped with ``asynccontextmanager``.
    """
    # On startup
    print("π Speechlib API μλ² μμ")
    yield
    # On shutdown
    print("π Speechlib API μλ² μ’ λ£")
# The ASGI application; `lifespan` prints the startup/shutdown messages.
app = FastAPI(
    lifespan=lifespan,
    title="Speechlib API",
    version="1.0.0",
    description="νμ λΆλ¦¬ + νμ μλ³ + STT REST API (ECAPA-TDNN)",
)
@app.get("/")
async def root():
    """
    API status check.

    Returns a static payload with the service status and a short description
    of the two processing endpoints.
    """
    return {
        "status": "ok",
        "message": "Speechlib API (ECAPA-TDNN)",
        "endpoints": {
            "/transcribe": "POST - λ¨μ STT + νμ λΆλ¦¬",
            "/process": "POST - μ 체 κΈ°λ₯ (νμ μλ³ ν¬ν¨)"
        }
    }
@app.get("/health")
async def health_check():
    """Health check: report liveness and whether a CUDA device is available."""
    return {"status": "healthy", "device": "cuda" if torch.cuda.is_available() else "cpu"}
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(..., description="μ€λμ€ νμΌ"),
    language: str = Form(default="ko", description="μΈμ΄ μ½λ (ko, en, ja, zh)"),
    hf_token: str = Form(..., description="HuggingFace ν ν°")
):
    """
    Plain STT + speaker diarization (no speaker identification).

    Returns JSON with:
      - the transcribed segments (start/end rounded to 2 decimals, text,
        diarization speaker label),
      - per-speaker segment count and total duration,
      - the total number of segments.
    Any exception is returned as HTTP 500 with the error message and traceback.
    """
    temp_dir = None
    try:
        # Save the upload into a private temporary directory.
        temp_dir = tempfile.mkdtemp()
        # SECURITY: the client controls `audio.filename`. Keep only its final
        # path component so "../x" or "/abs/path" cannot escape temp_dir. A
        # missing name is replaced by a fallback. The extension is kept
        # because the pipeline's loader chooses a decoder based on it.
        upload_name = os.path.basename(audio.filename or "")
        if upload_name in ("", ".", ".."):
            upload_name = "audio"
        audio_path = os.path.join(temp_dir, upload_name)
        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        # Run the pipeline with no registered voices -> diarization labels only.
        # NOTE(review): process() is synchronous and blocks the event loop.
        # That blocking also serialises requests, which the shared singleton
        # (whose registered_speakers is cleared per request) relies on.
        pipeline = get_pipeline(hf_token)
        pipeline.registered_speakers.clear()
        results = pipeline.process(audio_path, language=language)

        # Format segments and accumulate per-speaker statistics.
        segments = []
        speaker_stats = {}
        for r in results:
            segments.append({
                "start": round(r["start"], 2),
                "end": round(r["end"], 2),
                "text": r["text"],
                "speaker": r["speaker"]
            })
            stats = speaker_stats.setdefault(r["speaker"], {"count": 0, "duration": 0})
            stats["count"] += 1
            stats["duration"] += r["end"] - r["start"]
        for stats in speaker_stats.values():
            stats["duration"] = round(stats["duration"], 2)
        return JSONResponse(content={
            "success": True,
            "segments": segments,
            "speaker_stats": speaker_stats,
            "total_segments": len(segments)
        })
    except Exception as e:
        # NOTE(review): returning the traceback exposes server internals to
        # clients; kept because it is part of the current response contract.
        import traceback
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc()
            }
        )
    finally:
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
@app.post("/process")
async def process_audio(
    audio: UploadFile = File(..., description="λΆμν μ€λμ€ νμΌ"),
    voice_sample: UploadFile = File(default=None, description="νμ μν νμΌ (μ ν)"),
    speaker_name: str = Form(default="speaker", description="μλ³ν νμ μ΄λ¦"),
    language: str = Form(default="ko", description="μΈμ΄ μ½λ (ko, en, ja, zh)"),
    hf_token: str = Form(..., description="HuggingFace ν ν°")
):
    """
    Full pipeline: speaker diarization + speaker identification + STT.

    If ``voice_sample`` is provided, it is registered under ``speaker_name``.
    Diarization speakers that match it are then relabelled with that name.

    Returns the same JSON shape as /transcribe, plus a per-segment
    ``similarity`` (cosine similarity * 100; 0.0 when unmatched). Any
    exception is returned as HTTP 500 with the error message and traceback.
    """
    temp_dir = None
    try:
        # Private temporary directory for both uploads.
        temp_dir = tempfile.mkdtemp()

        # SECURITY: client filenames are reduced to their final path component
        # so they cannot escape temp_dir. Distinct prefixes keep a voice sample
        # with the same filename from overwriting the main audio. Extensions
        # are preserved because the pipeline's loader dispatches on them.
        audio_name = os.path.basename(audio.filename or "")
        if audio_name in ("", ".", ".."):
            audio_name = "audio"
        audio_path = os.path.join(temp_dir, "main_" + audio_name)
        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        # Get the shared pipeline and drop voices registered by earlier requests.
        # NOTE(review): process() is synchronous and blocks the event loop.
        # That blocking also serialises requests, which this shared mutable
        # state relies on.
        pipeline = get_pipeline(hf_token)
        pipeline.registered_speakers.clear()

        # Register the reference voice, if one was uploaded.
        if voice_sample and voice_sample.filename:
            sample_name = os.path.basename(voice_sample.filename)
            if sample_name in ("", ".", ".."):
                sample_name = "sample"
            sample_path = os.path.join(temp_dir, "sample_" + sample_name)
            with open(sample_path, "wb") as f:
                f.write(await voice_sample.read())
            pipeline.register_speaker(speaker_name, [sample_path])

        # Run diarization + identification + STT.
        results = pipeline.process(audio_path, language=language)

        # Format segments and accumulate per-speaker statistics.
        segments = []
        speaker_stats = {}
        for r in results:
            segments.append({
                "start": round(r["start"], 2),
                "end": round(r["end"], 2),
                "text": r["text"],
                "speaker": r["speaker"],
                "similarity": r["similarity"]
            })
            stats = speaker_stats.setdefault(r["speaker"], {"count": 0, "duration": 0})
            stats["count"] += 1
            stats["duration"] += r["end"] - r["start"]
        for stats in speaker_stats.values():
            stats["duration"] = round(stats["duration"], 2)
        return JSONResponse(content={
            "success": True,
            "segments": segments,
            "speaker_stats": speaker_stats,
            "total_segments": len(segments)
        })
    except Exception as e:
        # NOTE(review): returning the traceback exposes server internals to
        # clients; kept because it is part of the current response contract.
        import traceback
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc()
            }
        )
    finally:
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == "__main__":
    # Serve on every interface; 7860 is presumably the port the hosting Space expects.
    uvicorn.run(app, port=7860, host="0.0.0.0")