KU_SW_Academy / models /audio_encoder.py
heybaeheef's picture
Upload 3 files
0c3b738 verified
"""
Audio Encoder for MagicPath Server
===================================
Extracts feature vectors from audio files using the CLAP model.
Uses the same encoder as the DiffVox LLM.
"""
import torch
import numpy as np
from typing import List, Optional
import warnings
# Silence every warning for the whole process, not just this module: the
# "ignore" filter is installed globally as soon as this file is imported.
warnings.filterwarnings("ignore")
class AudioEncoder:
    """CLAP-based audio encoder.

    Wraps a Hugging Face ``ClapModel`` / ``ClapProcessor`` pair and turns an
    audio file (or a text prompt) into a feature vector of ``output_dim``
    floats.

    Model loading is best-effort: if ``transformers`` is not installed or the
    checkpoint cannot be loaded, a message is printed, and the feature
    methods return an empty list while ``self.model`` is ``None``.
    """

    # Sample rate the audio is resampled to on load and declared to the
    # CLAP processor.
    SAMPLE_RATE = 48000

    def __init__(
        self,
        output_dim: int = 64,
        reduction_method: str = "pool",
        model_name: str = "laion/larger_clap_general"
    ):
        """Initialise the encoder and immediately try to load the model.

        Args:
            output_dim: Length of the returned feature vectors (default 64).
            reduction_method: How the CLAP embedding is resized to
                ``output_dim``. ``"pool"`` averages contiguous groups of
                components (or zero-pads a shorter vector). ``"linear"``
                applies a ``torch.nn.Linear`` layer created at load time;
                nothing in this class loads trained weights for it, so it
                keeps whatever initialisation ``torch.nn.Linear`` gives it.
                Any other value (including ``"pca"``, which has no
                implementation here) keeps the first ``output_dim``
                components, zero-padding if the vector is shorter.
            model_name: Name of the CLAP checkpoint to load.
        """
        self.output_dim = output_dim
        self.reduction_method = reduction_method
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.processor = None
        self.projection = None
        self._load_model()

    def _load_model(self):
        """Load the CLAP processor and model onto ``self.device``.

        Failures are reported with ``print`` and swallowed on purpose so
        that constructing the encoder never raises.
        """
        try:
            # Imported lazily so the module can be imported without
            # transformers installed.
            from transformers import ClapModel, ClapProcessor
            print(f"[AudioEncoder] CLAP ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘: {self.model_name}")
            self.processor = ClapProcessor.from_pretrained(self.model_name)
            self.model = ClapModel.from_pretrained(self.model_name)
            self.model = self.model.to(self.device)
            self.model.eval()
            # Size of the embedding produced by the CLAP projection head.
            clap_dim = self.model.config.projection_dim
            print(f"[AudioEncoder] CLAP ์ถœ๋ ฅ ์ฐจ์›: {clap_dim}")
            # A projection layer is only needed for the "linear" method and
            # only when the sizes actually differ.
            if self.reduction_method == "linear" and clap_dim != self.output_dim:
                self.projection = torch.nn.Linear(clap_dim, self.output_dim)
                self.projection = self.projection.to(self.device)
                print(f"[AudioEncoder] Linear projection: {clap_dim} โ†’ {self.output_dim}")
            print("[AudioEncoder] โœ… ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ")
        except ImportError:
            print("[AudioEncoder] โŒ transformers ๋ฏธ์„ค์น˜")
            print(" pip install transformers")
        except Exception as e:
            print(f"[AudioEncoder] โŒ ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")

    def get_audio_features(self, audio_path: str) -> List[float]:
        """Extract a feature vector from an audio file.

        Args:
            audio_path: Path of the audio file to encode.

        Returns:
            The feature vector as a list of floats, or an empty list when
            the model is not loaded or any step fails (the error and its
            traceback are printed).
        """
        if self.model is None:
            print("[AudioEncoder] ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์Œ")
            return []
        try:
            import librosa
            # Load as mono and resample to the rate passed to the processor.
            audio, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE, mono=True)
            inputs = self.processor(
                audios=audio,
                sampling_rate=self.SAMPLE_RATE,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                audio_features = self.model.get_audio_features(**inputs)
            return self._embedding_to_list(audio_features)
        except Exception as e:
            print(f"[AudioEncoder] ํŠน์ง• ์ถ”์ถœ ์‹คํŒจ: {e}")
            import traceback
            traceback.print_exc()
            return []

    def _embedding_to_list(self, embedding) -> List[float]:
        """Convert a model embedding tensor into a resized list of floats."""
        # Drop singleton dimensions and move to host memory before the
        # NumPy-based dimension reduction.
        vector = embedding.squeeze().cpu().numpy()
        return self._reduce_dimension(vector).tolist()

    def _pad_to_output_dim(self, features: np.ndarray) -> np.ndarray:
        """Return ``features`` followed by zeros up to ``output_dim``."""
        padded = np.zeros(self.output_dim)
        padded[:len(features)] = features
        return padded

    def _reduce_dimension(self, features: np.ndarray) -> np.ndarray:
        """Resize a feature vector to ``output_dim`` components.

        Args:
            features: 1-D feature vector.

        Returns:
            ``features`` itself when it already has ``output_dim``
            components; otherwise a new array produced by the configured
            reduction method (see ``__init__``).
        """
        current_dim = len(features)
        if current_dim == self.output_dim:
            return features

        if self.reduction_method == "pool":
            if current_dim < self.output_dim:
                return self._pad_to_output_dim(features)
            # Average pooling over contiguous groups. array_split gives the
            # first (current_dim % output_dim) groups one extra element, so
            # every component is used exactly once.
            groups = np.array_split(features, self.output_dim)
            return np.array([group.mean() for group in groups])

        if self.reduction_method == "linear" and self.projection is not None:
            with torch.no_grad():
                features_tensor = torch.tensor(features, dtype=torch.float32).to(self.device)
                projected = self.projection(features_tensor)
            return projected.cpu().numpy()

        # Fallback for every other method: keep the leading components.
        # A shorter vector is zero-padded so the result always has
        # output_dim components, matching the "pool" branch.
        if current_dim < self.output_dim:
            return self._pad_to_output_dim(features)
        return features[:self.output_dim]

    def get_text_features(self, text: str) -> List[float]:
        """Extract a feature vector from text with the CLAP text encoder.

        Args:
            text: Input text.

        Returns:
            The feature vector as a list of floats, or an empty list when
            the model is not loaded or extraction fails.
        """
        if self.model is None:
            return []
        try:
            inputs = self.processor(
                text=text,
                return_tensors="pt",
                padding=True
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
            return self._embedding_to_list(text_features)
        except Exception as e:
            print(f"[AudioEncoder] ํ…์ŠคํŠธ ํŠน์ง• ์ถ”์ถœ ์‹คํŒจ: {e}")
            return []