"""
AI Effector - DiffVox LLM 기반 μ΄νŽ™νŠΈ νŒŒλΌλ―Έν„° 예츑
===================================================
V4: κ·Όλ³Έ 원인 ν•΄κ²°
- sigmoid λ³€ν™˜ (delay.feedback, delay.mix, distortion_amount)
- parametrizations.X.original ν‚€ μ •κ·œν™”
- delay.delay_time은 ν•™μŠ΅ μ•ˆλ¨ β†’ 프리셋 보완
- λ™μ˜μ–΄ λ§€ν•‘
"""
import os
import json
import re
import math
import torch
import numpy as np
from typing import Dict, List, Optional, Any
from pathlib import Path
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")
def sigmoid(x: float) -> float:
"""μ‹œκ·Έλͺ¨μ΄λ“œ ν•¨μˆ˜"""
try:
return 1 / (1 + math.exp(-x))
except OverflowError:
return 0.0 if x < 0 else 1.0
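# Illustrative values: sigmoid(0.0) == 0.5, and a stored raw value of -0.8473
# maps to roughly 0.30. This is how raw nn.Parameter values from training
# become usable 0..1 effect amounts in AIEffector._convert_raw_to_actual below.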
# Default parameters (used when the model fails to load)
DEFAULT_PARAMETERS = {
"eq_peak1.params.freq": 1000.0,
"eq_peak1.params.gain": 0.0,
"eq_peak1.params.Q": 1.0,
"eq_peak2.params.freq": 4000.0,
"eq_peak2.params.gain": 0.0,
"eq_peak2.params.Q": 1.0,
"eq_lowshelf.params.freq": 200.0,
"eq_lowshelf.params.gain": 0.0,
"eq_highshelf.params.freq": 8000.0,
"eq_highshelf.params.gain": 0.0,
"distortion_amount": 0.0,
"delay.delay_time": 0.02,
"delay.feedback": 0.3,
"delay.mix": 0.2,
"final_wet_mix": 0.5
}
# νŒŒλΌλ―Έν„° λ²”μœ„ μ œν•œ (λ³€ν™˜ ν›„ μ‹€μ œ κ°’ κΈ°μ€€)
PARAM_RANGES = {
"eq_peak1.params.freq": (20.0, 20000.0),
"eq_peak1.params.gain": (-12.0, 12.0),
"eq_peak1.params.Q": (0.1, 10.0),
"eq_peak2.params.freq": (20.0, 20000.0),
"eq_peak2.params.gain": (-12.0, 12.0),
"eq_peak2.params.Q": (0.1, 10.0),
"eq_lowshelf.params.freq": (20.0, 2000.0),
"eq_lowshelf.params.gain": (-12.0, 12.0),
"eq_highshelf.params.freq": (1000.0, 20000.0),
"eq_highshelf.params.gain": (-12.0, 12.0),
"distortion_amount": (0.0, 0.1), # sigmoid * 0.1 ν›„
"delay.delay_time": (0.01, 1.0),
"delay.feedback": (0.0, 0.9), # sigmoid ν›„
"delay.mix": (0.0, 1.0), # sigmoid ν›„
"final_wet_mix": (0.0, 1.0), # sigmoid ν›„
}
# λ™μ˜μ–΄ λ§€ν•‘ (λ―Έν•™μŠ΅ 단어 β†’ ν•™μŠ΅λœ 단어)
SYNONYM_MAP = {
"calm": "warm soft",
"relaxed": "warm soft",
"chill": "warm soft",
"smooth": "warm",
"mellow": "warm soft",
"breezy": "bright spacious",
"airy": "bright spacious",
"light": "bright",
"crisp": "bright",
"clean": "bright",
"dreamy": "warm spacious",
"ethereal": "bright spacious",
"atmospheric": "spacious",
"ambient": "spacious warm",
"aggressive": "saturated bright",
"powerful": "saturated",
"punchy": "saturated bright",
"hard": "saturated",
"gritty": "saturated dark",
"soft": "warm",
"harsh": "bright saturated",
"muddy": "dark",
"thin": "bright",
"thick": "warm dark",
"full": "warm",
"reverb": "spacious",
"echo": "spacious",
"wet": "spacious",
}
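# Note: every mapping target above is drawn from the trained vocabulary, which
# matches the STYLE_PRESETS keys below: warm, bright, spacious, dark,
# saturated, soft.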
# μŠ€νƒ€μΌ 프리셋 (delay.delay_time λ³΄μ™„μš©)
STYLE_PRESETS = {
"warm": {
"eq_lowshelf.params.gain": 3.0,
"eq_highshelf.params.gain": -1.0,
},
"bright": {
"eq_highshelf.params.gain": 4.0,
"eq_peak2.params.gain": 2.0,
"eq_lowshelf.params.gain": -1.0,
},
"spacious": {
"delay.delay_time": 0.05, # ν•™μŠ΅ μ•ˆλœ νŒŒλΌλ―Έν„° 보완
},
"dark": {
"eq_highshelf.params.gain": -4.0,
"eq_lowshelf.params.gain": 2.0,
},
"saturated": {},
"soft": {
"eq_highshelf.params.gain": -2.0,
"eq_lowshelf.params.gain": 1.0,
},
}
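# Presets combine: a prompt such as "warm spacious" merges the "warm" shelf
# gains with the "spacious" delay_time override (see AIEffector._apply_preset).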
class CLAPAudioEncoder:
"""CLAP 기반 μ˜€λ””μ˜€ 인코더 (ν•™μŠ΅ μ‹œμ™€ 동일)"""
def __init__(self, output_dim: int = 64, model_name: str = "laion/larger_clap_music"):
self.output_dim = output_dim
self.model_name = model_name
self.target_sr = 48000
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = None
self.processor = None
self._load_model()
def _load_model(self):
try:
from transformers import ClapModel, ClapProcessor
print(f"[CLAPEncoder] CLAP λͺ¨λΈ λ‘œλ”© 쀑: {self.model_name}")
self.processor = ClapProcessor.from_pretrained(self.model_name)
self.model = ClapModel.from_pretrained(self.model_name)
self.model = self.model.to(self.device)
self.model.eval()
print(f"[CLAPEncoder] βœ… CLAP λͺ¨λΈ λ‘œλ“œ μ™„λ£Œ (512β†’{self.output_dim} pooling)")
except ImportError:
print("[CLAPEncoder] ❌ transformers λ―Έμ„€μΉ˜")
except Exception as e:
print(f"[CLAPEncoder] ❌ λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨: {e}")
def get_audio_features(self, audio_path: str) -> List[float]:
if self.model is None:
return [0.0] * self.output_dim
try:
import librosa
audio, sr = librosa.load(audio_path, sr=self.target_sr, mono=True)
inputs = self.processor(
audios=audio,
sampling_rate=self.target_sr,
return_tensors="pt",
padding=True
).to(self.device)
with torch.no_grad():
outputs = self.model.get_audio_features(**inputs)
features_512 = outputs[0].cpu().numpy()
features_64 = self._reduce_dimension(features_512)
return features_64.tolist()
except Exception as e:
print(f"[CLAPEncoder] νŠΉμ§• μΆ”μΆœ μ‹€νŒ¨: {e}")
return [0.0] * self.output_dim
def _reduce_dimension(self, features: np.ndarray) -> np.ndarray:
current_dim = len(features)
if current_dim == self.output_dim:
return features
pool_size = current_dim // self.output_dim
remainder = current_dim % self.output_dim
pooled = []
idx = 0
for i in range(self.output_dim):
size = pool_size + (1 if i < remainder else 0)
pooled.append(np.mean(features[idx:idx+size]))
idx += size
return np.array(pooled)
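    # Shape note (illustrative): for CLAP's 512-dim embedding and output_dim=64,
    # pool_size is 8 with remainder 0, so each of the 64 outputs above is the
    # mean of 8 consecutive embedding values (simple average pooling).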
def is_loaded(self) -> bool:
return self.model is not None
class AIEffector:
"""AI 기반 μ΄νŽ™ν„° νŒŒλΌλ―Έν„° 예츑 (V4)"""
def __init__(
self,
model_repo_id: str = "heybaeheef/KU_SW_Academy",
model_subfolder: str = "checkpoints",
base_model_name: str = "Qwen/Qwen3-8B",
audio_feature_dim: int = 64,
use_huggingface: bool = True
):
self.model_repo_id = model_repo_id
self.model_subfolder = model_subfolder
self.base_model_name = base_model_name
self.audio_feature_dim = audio_feature_dim
self.use_huggingface = use_huggingface
self.model = None
self.tokenizer = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[AIEffector] CLAP μ˜€λ””μ˜€ 인코더 μ΄ˆκΈ°ν™” 쀑...")
self.audio_encoder = CLAPAudioEncoder(output_dim=audio_feature_dim)
self.request_count = 0
self._load_model()
def _load_model(self):
try:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
print(f"[AIEffector] 베이슀 λͺ¨λΈ λ‘œλ”© 쀑: {self.base_model_name}")
if torch.cuda.is_available():
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
base_model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True
)
else:
base_model = AutoModelForCausalLM.from_pretrained(
self.base_model_name,
torch_dtype=torch.float32,
device_map="auto",
trust_remote_code=True
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.base_model_name,
trust_remote_code=True
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
print(f"[AIEffector] LoRA μ–΄λŒ‘ν„° λ‘œλ”© 쀑...")
if self.use_huggingface:
self.model = PeftModel.from_pretrained(
base_model,
self.model_repo_id,
subfolder=self.model_subfolder,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
else:
local_path = os.path.join(self.model_repo_id, self.model_subfolder)
self.model = PeftModel.from_pretrained(
base_model,
local_path,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
self.model.eval()
print(f"[AIEffector] βœ… λͺ¨λΈ λ‘œλ“œ 성곡!")
except Exception as e:
print(f"[AIEffector] ❌ λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨: {e}")
import traceback
traceback.print_exc()
self.model = None
self.tokenizer = None
def is_loaded(self) -> bool:
return self.model is not None
    def _preprocess_text(self, text: str) -> str:
        """Synonym mapping (whole words, single pass, so replacements are never re-replaced)"""
        text_lower = text.lower()
        # A single regex pass avoids the chained-replacement bug where e.g.
        # "calm" → "warm soft" would later be rewritten again by "soft" → "warm"
        pattern = re.compile(r'\b(' + '|'.join(map(re.escape, SYNONYM_MAP)) + r')\b')

        def _substitute(match: re.Match) -> str:
            replacement = SYNONYM_MAP[match.group(1)]
            print(f" [Synonym] '{match.group(1)}' → '{replacement}'")
            return replacement

        return pattern.sub(_substitute, text_lower)
def _apply_preset(self, prompt: str) -> Dict[str, float]:
"""프리셋 λ§€μΉ­ (delay.delay_time λ³΄μ™„μš©)"""
params = {}
prompt_lower = prompt.lower()
matched = []
for style_name, style_params in STYLE_PRESETS.items():
if style_name in prompt_lower:
params.update(style_params)
matched.append(style_name)
if matched:
print(f" [Preset] λ§€μΉ­: {matched}")
return params
def _format_prompt(self, text_prompt: str, audio_features: List[float]) -> str:
"""ν•™μŠ΅ μ‹œμ™€ λ™μΌν•œ ν”„λ‘¬ν”„νŠΈ"""
audio_state_str = json.dumps(audio_features)
return f"""Task: Convert text to audio parameters.
Audio: {audio_state_str}
Text: {text_prompt}
Parameters:"""
    def _preprocess_json(self, json_str: str) -> str:
        """JSON preprocessing"""
        # Remove digit-separator underscores (0.30_299 → 0.30299); lookarounds
        # also handle consecutive separators such as 1_2_3
        json_str = re.sub(r'(?<=\d)_(?=\d)', '', json_str)
        # Remove trailing commas
        json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
        # NaN, Infinity (replace -Infinity before Infinity so the sign survives)
        json_str = re.sub(r'\bNaN\b', '0', json_str)
        json_str = re.sub(r'-Infinity\b', '-999999', json_str)
        json_str = re.sub(r'\bInfinity\b', '999999', json_str)
        return json_str
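    # Illustrative repairs performed by _preprocess_json:
    #   '{"a": 0.30_299,}'  ->  '{"a": 0.30299}'   (digit underscore + trailing comma)
    #   '{"b": NaN}'        ->  '{"b": 0}'
    #   '{"c": -Infinity}'  ->  '{"c": -999999}'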
    def _normalize_key(self, key: str) -> str:
        """
        Normalize parameter keys:
        eq_peak1.params.parametrizations.freq.original → eq_peak1.params.freq
        """
        # parametrizations.X.original → X
        key = re.sub(r'\.parametrizations\.(\w+)\.original', r'.\1', key)
        # Q stays uppercase here; it is lowercased later in _convert_to_effect_chain_format
        return key
def _extract_json_object(self, text: str) -> Optional[str]:
"""JSON 객체 μΆ”μΆœ"""
start = text.find('{')
if start == -1:
return None
depth = 0
for i, char in enumerate(text[start:], start):
if char == '{':
depth += 1
elif char == '}':
depth -= 1
if depth == 0:
return text[start:i+1]
return None
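    # Example: 'noise {"a": {"b": 1}} trailing' -> '{"a": {"b": 1}}'
    # (the brace-depth scan keeps nested objects intact and ignores surrounding text).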
def _convert_raw_to_actual(self, params: Dict[str, float]) -> Dict[str, float]:
"""
β˜…β˜…β˜… 핡심: ν•™μŠ΅ λ°μ΄ν„°μ˜ raw 값을 μ‹€μ œ κ°’μœΌλ‘œ λ³€ν™˜ β˜…β˜…β˜…
ν•™μŠ΅ λ°μ΄ν„°λŠ” nn.Parameter의 raw 값을 μ €μž₯함.
μ‹€μ œ μ‚¬μš© μ‹œ sigmoid λ“± λ³€ν™˜μ΄ 적용됨.
"""
result = params.copy()
# 1. delay.feedback: sigmoid λ³€ν™˜
if 'delay.feedback' in result:
raw = result['delay.feedback']
actual = sigmoid(raw)
print(f" [Convert] delay.feedback: {raw:.4f} β†’ sigmoid β†’ {actual:.4f}")
result['delay.feedback'] = actual
# 2. delay.mix: sigmoid λ³€ν™˜
if 'delay.mix' in result:
raw = result['delay.mix']
actual = sigmoid(raw)
print(f" [Convert] delay.mix: {raw:.4f} β†’ sigmoid β†’ {actual:.4f}")
result['delay.mix'] = actual
# 3. distortion_amount: sigmoid * 0.1
if 'distortion_amount' in result:
raw = result['distortion_amount']
actual = sigmoid(raw) * 0.1
print(f" [Convert] distortion_amount: {raw:.4f} β†’ sigmoid*0.1 β†’ {actual:.4f}")
result['distortion_amount'] = actual
# 4. final_wet_mix: sigmoid λ³€ν™˜
if 'final_wet_mix' in result:
raw = result['final_wet_mix']
actual = sigmoid(raw)
print(f" [Convert] final_wet_mix: {raw:.4f} β†’ sigmoid β†’ {actual:.4f}")
result['final_wet_mix'] = actual
return result
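    # Worked example (illustrative): a raw delay.feedback of 0.0 becomes
    # sigmoid(0.0) = 0.5; a raw distortion_amount of 2.0 becomes
    # sigmoid(2.0) * 0.1 ≈ 0.088.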
def _clamp_values(self, params: Dict[str, float]) -> Dict[str, float]:
"""κ°’ λ²”μœ„ μ œν•œ"""
result = params.copy()
for key, (min_val, max_val) in PARAM_RANGES.items():
if key in result:
original = result[key]
clamped = max(min_val, min(max_val, original))
if clamped != original:
print(f" [Clamp] {key}: {original:.4f} β†’ {clamped:.4f}")
result[key] = clamped
return result
    def _parse_output(self, output_text: str) -> Dict[str, float]:
        """Parse the LLM output"""
        print(f" [Parse] Raw output length: {len(output_text)} chars")
        json_str = None
        try:
            text = output_text
            # Strip <think> tags
            text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
            # Extract from a code block, if present
            code_match = re.search(r'```(?:json)?\s*([\s\S]*?)```', text)
            if code_match:
                text = code_match.group(1)
            # Extract the JSON object
            json_str = self._extract_json_object(text)
            if json_str:
                print(f" [Parse] JSON found (length: {len(json_str)})")
                # Preprocess
                json_str = self._preprocess_json(json_str)
                # Parse
                raw_params = json.loads(json_str)
                # Map onto the defaults
                result = DEFAULT_PARAMETERS.copy()
                parsed_count = 0
                for key, value in raw_params.items():
                    try:
                        # Normalize the key
                        norm_key = self._normalize_key(key)
                        float_val = float(value)
                        # Find the matching default key
                        matched_key = None
                        for default_key in DEFAULT_PARAMETERS.keys():
                            # Exact match
                            if norm_key == default_key:
                                matched_key = default_key
                                break
                            # Partial match (same first segment and same last segment)
                            if norm_key.endswith(default_key.split('.')[-1]) and \
                               norm_key.split('.')[0] == default_key.split('.')[0]:
                                matched_key = default_key
                                break
                        if matched_key:
                            result[matched_key] = float_val
                            parsed_count += 1
                        else:
                            print(f" [Parse] No match: {key} → {norm_key}")
                    except (ValueError, TypeError) as e:
                        print(f" [Parse] Conversion failed: {key}={value} ({e})")
                print(f" [Parse] ✅ {parsed_count} parameters mapped")
                return result
        except json.JSONDecodeError as e:
            print(f" [Parse] ❌ JSON error: {e}")
            if json_str:
                pos = getattr(e, 'pos', 0)
                print(f" [Parse] Position: ...{json_str[max(0, pos - 20):pos + 20]}...")
        except Exception as e:
            print(f" [Parse] ❌ Exception: {e}")
        print(" [Parse] ⚠️ Falling back to defaults")
        return DEFAULT_PARAMETERS.copy()
def predict(self, audio_path: str, text_prompt: str = "") -> Dict[str, float]:
"""νŒŒλΌλ―Έν„° 예츑"""
self.request_count += 1
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"\n{'='*60}")
print(f"[AIEffector V4] 🎡 μš”μ²­ #{self.request_count} - {timestamp}")
print(f"{'='*60}")
print(f" πŸ“‚ μ˜€λ””μ˜€: {Path(audio_path).name}")
print(f" πŸ’¬ 원본 ν”„λ‘¬ν”„νŠΈ: '{text_prompt}'")
# λ™μ˜μ–΄ λ³€ν™˜
processed_prompt = self._preprocess_text(text_prompt)
if processed_prompt != text_prompt.lower():
print(f" πŸ’¬ λ³€ν™˜ ν”„λ‘¬ν”„νŠΈ: '{processed_prompt}'")
print(f" πŸ€– λͺ¨λΈ: {'AI' if self.is_loaded() else '프리셋'}")
# λͺ¨λΈ μ—†μœΌλ©΄ 프리셋
if not self.is_loaded():
print(f"\n ⚠️ AI λͺ¨λΈ λ―Έλ‘œλ“œ")
params = DEFAULT_PARAMETERS.copy()
params.update(self._apply_preset(processed_prompt))
self._log_parameters(params)
return self._convert_to_effect_chain_format(params)
try:
# 1. CLAP νŠΉμ§• μΆ”μΆœ
print(f"\n πŸ“Š [Step 1] CLAP νŠΉμ§• μΆ”μΆœ...")
audio_features = self.audio_encoder.get_audio_features(audio_path)
if not audio_features or all(f == 0 for f in audio_features):
print(f" ⚠️ μ‹€νŒ¨, 프리셋 폴백")
params = DEFAULT_PARAMETERS.copy()
params.update(self._apply_preset(processed_prompt))
self._log_parameters(params)
return self._convert_to_effect_chain_format(params)
print(f" βœ… {len(audio_features)}차원")
# 2. ν”„λ‘¬ν”„νŠΈ 생성
print(f"\n πŸ”€ [Step 2] ν”„λ‘¬ν”„νŠΈ 생성...")
prompt = self._format_prompt(processed_prompt, audio_features)
# 3. 토큰화
print(f"\n πŸ”’ [Step 3] 토큰화...")
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=1500
).to(self.device)
print(f" 토큰 수: {inputs['input_ids'].shape[1]}")
# 4. LLM 생성
print(f"\n 🧠 [Step 4] LLM μΆ”λ‘ ...")
import time
start = time.time()
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=500,
do_sample=False,
temperature=0.1,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
print(f" μΆ”λ‘  μ‹œκ°„: {time.time()-start:.2f}초")
# 5. λ””μ½”λ”©
print(f"\n πŸ“ [Step 5] λ””μ½”λ”©...")
gen_tokens = outputs[0][inputs['input_ids'].shape[1]:]
output_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
print(f" 좜λ ₯ (처음 400자):\n{output_text[:400]}")
# 6. νŒŒμ‹±
print(f"\n πŸ”§ [Step 6] νŒŒμ‹±...")
raw_params = self._parse_output(output_text)
# 7. β˜…β˜…β˜… Raw β†’ Actual λ³€ν™˜ β˜…β˜…β˜…
print(f"\n πŸ”„ [Step 7] Raw β†’ Actual λ³€ν™˜...")
actual_params = self._convert_raw_to_actual(raw_params)
# 8. κ°’ ν΄λž¨ν•‘
print(f"\n πŸ“ [Step 8] κ°’ ν΄λž¨ν•‘...")
clamped_params = self._clamp_values(actual_params)
# 9. 프리셋 보완 (delay.delay_time은 ν•™μŠ΅ μ•ˆλ¨)
print(f"\n πŸŽ›οΈ [Step 9] 프리셋 보완...")
preset = self._apply_preset(processed_prompt)
if 'delay.delay_time' in preset:
clamped_params['delay.delay_time'] = preset['delay.delay_time']
print(f" delay.delay_time: {preset['delay.delay_time']} (프리셋)")
# 10. λ‘œκΉ…
self._log_parameters(clamped_params)
print(f"\n βœ… μ™„λ£Œ!")
print(f"{'='*60}\n")
return self._convert_to_effect_chain_format(clamped_params)
except Exception as e:
print(f"\n ❌ μ‹€νŒ¨: {e}")
import traceback
traceback.print_exc()
params = DEFAULT_PARAMETERS.copy()
params.update(self._apply_preset(processed_prompt))
self._log_parameters(params)
return self._convert_to_effect_chain_format(params)
def _convert_to_effect_chain_format(self, params: Dict[str, float]) -> Dict[str, float]:
"""effect_chain.py ν˜•μ‹μœΌλ‘œ λ³€ν™˜ (Q β†’ q)"""
result = {}
for key, value in params.items():
new_key = key.replace('.Q', '.q')
result[new_key] = value
return result
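    # Example: "eq_peak1.params.Q" → "eq_peak1.params.q"; every other key
    # passes through unchanged.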
def _log_parameters(self, params: Dict[str, float]):
"""νŒŒλΌλ―Έν„° λ‘œκΉ…"""
print(f"\n πŸ“‹ μ΅œμ’… νŒŒλΌλ―Έν„°:")
print(f" [EQ Peak 1] freq={params.get('eq_peak1.params.freq',0):.0f}Hz, gain={params.get('eq_peak1.params.gain',0):.2f}dB")
print(f" [EQ Peak 2] freq={params.get('eq_peak2.params.freq',0):.0f}Hz, gain={params.get('eq_peak2.params.gain',0):.2f}dB")
print(f" [Low Shelf] gain={params.get('eq_lowshelf.params.gain',0):.2f}dB")
print(f" [High Shelf] gain={params.get('eq_highshelf.params.gain',0):.2f}dB")
print(f" [Distortion] {params.get('distortion_amount',0):.4f}")
print(f" [Delay] time={params.get('delay.delay_time',0):.3f}s, fb={params.get('delay.feedback',0):.2f}, mix={params.get('delay.mix',0):.2f}")
print(f" [Wet Mix] {params.get('final_wet_mix',0):.2f}")